ollamadiffuser 1.2.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ollamadiffuser/__init__.py +1 -1
  2. ollamadiffuser/api/server.py +312 -312
  3. ollamadiffuser/cli/config_commands.py +119 -0
  4. ollamadiffuser/cli/lora_commands.py +169 -0
  5. ollamadiffuser/cli/main.py +85 -1233
  6. ollamadiffuser/cli/model_commands.py +664 -0
  7. ollamadiffuser/cli/recommend_command.py +205 -0
  8. ollamadiffuser/cli/registry_commands.py +197 -0
  9. ollamadiffuser/core/config/model_registry.py +562 -11
  10. ollamadiffuser/core/config/settings.py +24 -2
  11. ollamadiffuser/core/inference/__init__.py +5 -0
  12. ollamadiffuser/core/inference/base.py +182 -0
  13. ollamadiffuser/core/inference/engine.py +204 -1405
  14. ollamadiffuser/core/inference/strategies/__init__.py +1 -0
  15. ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
  16. ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
  17. ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
  18. ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
  19. ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
  20. ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
  21. ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
  22. ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
  23. ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
  24. ollamadiffuser/mcp/__init__.py +0 -0
  25. ollamadiffuser/mcp/server.py +184 -0
  26. ollamadiffuser/ui/templates/index.html +62 -1
  27. ollamadiffuser/ui/web.py +116 -54
  28. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +337 -108
  29. ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
  30. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
  31. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
  32. ollamadiffuser/core/models/registry.py +0 -384
  33. ollamadiffuser/ui/samples/.DS_Store +0 -0
  34. ollamadiffuser-1.2.2.dist-info/RECORD +0 -45
  35. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
  36. {ollamadiffuser-1.2.2.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
ollamadiffuser/core/inference/strategies/__init__.py
@@ -0,0 +1 @@
+ """Inference strategies for different model types"""
ollamadiffuser/core/inference/strategies/controlnet_strategy.py
@@ -0,0 +1,170 @@
+ """ControlNet inference strategy"""
+
+ import logging
+ from typing import Optional, Union
+
+ import torch
+ from PIL import Image
+
+ from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
+ from ...config.settings import ModelConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class ControlNetStrategy(InferenceStrategy):
+     """Strategy for ControlNet models (SD 1.5 and SDXL based)"""
+
+     def __init__(self):
+         super().__init__()
+         self.controlnet = None
+         self.is_controlnet_pipeline = True
+
+     def load(self, model_config: ModelConfig, device: str) -> bool:
+         try:
+             from diffusers import (
+                 ControlNetModel,
+                 StableDiffusionControlNetPipeline,
+                 StableDiffusionXLControlNetPipeline,
+             )
+
+             self.device = device
+             self.model_config = model_config
+
+             # Determine if SD15 or SDXL based
+             is_sdxl = model_config.model_type == "controlnet_sdxl"
+             pipeline_class = StableDiffusionXLControlNetPipeline if is_sdxl else StableDiffusionControlNetPipeline
+
+             load_kwargs = {**SAFETY_DISABLED_KWARGS}
+             if device in ("cpu", "mps"):
+                 load_kwargs["torch_dtype"] = torch.float32
+             elif model_config.variant == "fp16":
+                 load_kwargs["torch_dtype"] = torch.float16
+                 load_kwargs["variant"] = "fp16"
+             else:
+                 load_kwargs["torch_dtype"] = self._get_dtype(device)
+
+             # Load ControlNet model
+             logger.info(f"Loading ControlNet model from: {model_config.path}")
+             self.controlnet = ControlNetModel.from_pretrained(
+                 model_config.path,
+                 torch_dtype=load_kwargs.get("torch_dtype", torch.float32),
+             )
+
+             # Get base model path
+             base_model_name = getattr(model_config, "base_model", None)
+             if not base_model_name:
+                 from ...models.manager import model_manager
+                 model_info = model_manager.get_model_info(model_config.name)
+                 if model_info and "base_model" in model_info:
+                     base_model_name = model_info["base_model"]
+                 else:
+                     raise ValueError(f"No base model specified for ControlNet: {model_config.name}")
+
+             from ...models.manager import model_manager
+             if not model_manager.is_model_installed(base_model_name):
+                 raise ValueError(f"Base model '{base_model_name}' not installed")
+
+             from ...config.settings import settings
+             base_config = settings.models[base_model_name]
+
+             # Load pipeline with controlnet
+             self.pipeline = pipeline_class.from_pretrained(
+                 base_config.path,
+                 controlnet=self.controlnet,
+                 **load_kwargs,
+             )
+
+             self._move_to_device(device)
+             self.controlnet = self.controlnet.to(self.device)
+             self._apply_memory_optimizations()
+
+             logger.info(f"ControlNet model {model_config.name} loaded on {self.device}")
+             return True
+         except Exception as e:
+             logger.error(f"Failed to load ControlNet model: {e}")
+             return False
+
+     def generate(
+         self,
+         prompt: str,
+         negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+         num_inference_steps: Optional[int] = None,
+         guidance_scale: Optional[float] = None,
+         width: int = 512,
+         height: int = 512,
+         seed: Optional[int] = None,
+         control_image: Optional[Union[Image.Image, str]] = None,
+         controlnet_conditioning_scale: float = 1.0,
+         control_guidance_start: float = 0.0,
+         control_guidance_end: float = 1.0,
+         **kwargs,
+     ) -> Image.Image:
+         if not self.pipeline:
+             raise RuntimeError("Model not loaded")
+         if control_image is None:
+             raise ValueError("ControlNet requires a control_image")
+
+         params = self.model_config.parameters or {}
+         steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 50)
+         guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.5)
+
+         # Prepare control image
+         control_image = self._prepare_control_image(control_image, width, height)
+
+         generator, used_seed = self._make_generator(seed, self.device)
+
+         gen_kwargs = {
+             "prompt": prompt,
+             "negative_prompt": negative_prompt,
+             "num_inference_steps": steps,
+             "guidance_scale": guidance,
+             "width": width,
+             "height": height,
+             "image": control_image,
+             "controlnet_conditioning_scale": controlnet_conditioning_scale,
+             "control_guidance_start": control_guidance_start,
+             "control_guidance_end": control_guidance_end,
+             "generator": generator,
+         }
+
+         try:
+             logger.info(f"Generating ControlNet image: steps={steps}, guidance={guidance}, seed={used_seed}")
+             output = self.pipeline(**gen_kwargs)
+             return output.images[0]
+         except Exception as e:
+             logger.error(f"ControlNet generation failed: {e}")
+             return self._create_error_image(str(e), prompt)
+
+     def _prepare_control_image(
+         self, control_image: Union[Image.Image, str], width: int, height: int
+     ) -> Image.Image:
+         """Prepare and preprocess control image"""
+         from ...utils.controlnet_preprocessors import controlnet_preprocessor
+
+         if isinstance(control_image, str):
+             control_image = Image.open(control_image).convert("RGB")
+         elif not isinstance(control_image, Image.Image):
+             raise ValueError("control_image must be PIL Image or file path")
+
+         if control_image.mode != "RGB":
+             control_image = control_image.convert("RGB")
+
+         # Get controlnet type from model info
+         from ...models.manager import model_manager
+         model_info = model_manager.get_model_info(self.model_config.name)
+         cn_type = model_info.get("controlnet_type", "canny") if model_info else "canny"
+
+         # Initialize preprocessor if needed
+         if not controlnet_preprocessor.is_initialized():
+             controlnet_preprocessor.initialize()
+
+         processed = controlnet_preprocessor.preprocess(control_image, cn_type)
+         processed = controlnet_preprocessor.resize_for_controlnet(processed, width, height)
+         return processed
+
+     def unload(self) -> None:
+         if self.controlnet:
+             del self.controlnet
+             self.controlnet = None
+         super().unload()
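For orientation, the diffusers-level flow this strategy wraps (load a ControlNet, attach it to a base Stable Diffusion pipeline, condition on a preprocessed control image) looks roughly like the sketch below. The Hub model IDs, the Canny preprocessing, and the file names are illustrative assumptions, not ollamadiffuser registry paths.

    # Minimal sketch of the underlying diffusers ControlNet flow (not ollamadiffuser code).
    import cv2
    import numpy as np
    import torch
    from PIL import Image
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
    ).to("cuda")

    # Build a Canny edge map, roughly what the "canny" preprocessor path above produces.
    source = Image.open("input.png").convert("RGB")
    edges = cv2.Canny(np.array(source), 100, 200)
    control = Image.fromarray(np.stack([edges] * 3, axis=-1))

    image = pipe(
        "a watercolor painting of a house",
        image=control,
        num_inference_steps=50,
        guidance_scale=7.5,
        controlnet_conditioning_scale=1.0,
    ).images[0]
    image.save("controlnet_out.png")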
ollamadiffuser/core/inference/strategies/flux_strategy.py
@@ -0,0 +1,136 @@
+ """FLUX inference strategy"""
+
+ import logging
+ from typing import Optional
+
+ import torch
+ from PIL import Image
+
+ from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
+ from ...config.settings import ModelConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class FluxStrategy(InferenceStrategy):
+     """Strategy for FLUX models (FLUX.1-dev, FLUX.1-schnell)"""
+
+     def load(self, model_config: ModelConfig, device: str) -> bool:
+         try:
+             import diffusers
+
+             self.device = device
+             self.model_config = model_config
+
+             # Resolve pipeline class from parameters, defaulting to FluxPipeline
+             params = model_config.parameters or {}
+             pipeline_class_name = params.get("pipeline_class", "FluxPipeline")
+             pipeline_cls = getattr(diffusers, pipeline_class_name, None)
+             if pipeline_cls is None:
+                 logger.error(
+                     f"Pipeline class '{pipeline_class_name}' not found in diffusers. "
+                     "You may need to upgrade: pip install --upgrade diffusers"
+                 )
+                 return False
+
+             load_kwargs = {**SAFETY_DISABLED_KWARGS}
+
+             if device == "cpu":
+                 load_kwargs["torch_dtype"] = torch.float32
+                 logger.warning("FLUX on CPU will be very slow for this 12B parameter model")
+             else:
+                 load_kwargs["torch_dtype"] = torch.bfloat16
+                 load_kwargs["use_safetensors"] = True
+
+             self.pipeline = pipeline_cls.from_pretrained(
+                 model_config.path, **load_kwargs
+             )
+
+             if device in ("cuda", "mps") and hasattr(self.pipeline, "enable_model_cpu_offload"):
+                 # CPU offloading manages device placement itself; don't call _move_to_device
+                 self.pipeline.enable_model_cpu_offload(device=device)
+                 logger.info(f"Enabled CPU offloading for FLUX on {device}")
+             else:
+                 self._move_to_device(device)
+                 self._apply_memory_optimizations()
+
+             logger.info(f"FLUX model {model_config.name} loaded on {self.device}")
+             return True
+         except Exception as e:
+             logger.error(f"Failed to load FLUX model: {e}")
+             return False
+
+     @property
+     def _is_schnell(self) -> bool:
+         return self.model_config is not None and "schnell" in self.model_config.name.lower()
+
+     def generate(
+         self,
+         prompt: str,
+         negative_prompt: str = "",
+         num_inference_steps: Optional[int] = None,
+         guidance_scale: Optional[float] = None,
+         width: int = 1024,
+         height: int = 1024,
+         seed: Optional[int] = None,
+         **kwargs,
+     ) -> Image.Image:
+         if not self.pipeline:
+             raise RuntimeError("Model not loaded")
+
+         params = self.model_config.parameters or {}
+
+         # Schnell-specific defaults
+         if self._is_schnell:
+             steps = num_inference_steps if num_inference_steps is not None else 4
+             if steps > 4:
+                 logger.info(f"FLUX.1-schnell: reducing steps from {steps} to 4")
+                 steps = 4
+             guidance = 0.0
+             logger.info("FLUX.1-schnell: using 0.0 guidance (distilled model)")
+         else:
+             steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 20)
+             guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 3.5)
+
+         # CPU-specific caps
+         if self.device == "cpu":
+             if not self._is_schnell and steps > 20:
+                 steps = 20
+                 logger.info(f"Reduced steps to {steps} for CPU performance")
+             if not self._is_schnell and guidance > 5.0:
+                 guidance = 5.0
+
+         max_seq_len = kwargs.get("max_sequence_length", params.get("max_sequence_length", 512))
+
+         generator, used_seed = self._make_generator(seed, self.device)
+
+         gen_kwargs = {
+             "prompt": prompt,
+             "num_inference_steps": steps,
+             "guidance_scale": guidance,
+             "width": width,
+             "height": height,
+             "max_sequence_length": max_seq_len,
+             "generator": generator,
+         }
+
+         # Pass through params for Fill/Control pipeline variants
+         for key in ("image", "mask_image", "strength", "control_image",
+                     "controlnet_conditioning_scale"):
+             if key in kwargs:
+                 gen_kwargs[key] = kwargs[key]
+
+         try:
+             logger.info(f"Generating FLUX image: steps={steps}, guidance={guidance}, seed={used_seed}")
+             output = self.pipeline(**gen_kwargs)
+             return output.images[0]
+         except RuntimeError as e:
+             if "CUDA" in str(e) and self.device == "cpu":
+                 logger.warning("Device mismatch, retrying without generator")
+                 gen_kwargs.pop("generator", None)
+                 output = self.pipeline(**gen_kwargs)
+                 return output.images[0]
+             raise
+         except Exception as e:
+             logger.error(f"FLUX generation failed: {e}")
+             return self._create_error_image(str(e), prompt)
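The schnell branch above pins steps to 4 and guidance to 0.0 because FLUX.1-schnell is step- and guidance-distilled. Outside of ollamadiffuser, the equivalent direct diffusers call looks roughly like this sketch; the public Hub checkpoint ID is assumed rather than a locally registered model path.

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    )
    pipe.enable_model_cpu_offload()  # mirrors the strategy's offload path on cuda/mps

    image = pipe(
        "a tiny robot reading a newspaper",
        num_inference_steps=4,    # schnell is step-distilled; more steps add little
        guidance_scale=0.0,       # guidance-distilled, so CFG stays off
        width=1024,
        height=1024,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    image.save("flux_schnell.png")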
ollamadiffuser/core/inference/strategies/generic_strategy.py
@@ -0,0 +1,164 @@
+ """Generic pipeline inference strategy for any diffusers-compatible model."""
+
+ import logging
+ from typing import Optional
+
+ import torch
+ from PIL import Image
+
+ from ..base import InferenceStrategy
+ from ...config.settings import ModelConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Dtype name -> torch dtype mapping
+ _DTYPE_MAP = {
+     "float32": torch.float32,
+     "float16": torch.float16,
+     "bfloat16": torch.bfloat16,
+ }
+
+
+ class GenericPipelineStrategy(InferenceStrategy):
+     """Strategy that dynamically loads any diffusers pipeline class.
+
+     The pipeline class name is read from model_config.parameters["pipeline_class"].
+     This allows adding new model types to the registry without writing new strategy code.
+     """
+
+     def load(self, model_config: ModelConfig, device: str) -> bool:
+         try:
+             import diffusers
+
+             self.device = device
+             self.model_config = model_config
+
+             params = model_config.parameters or {}
+             pipeline_class_name = params.get("pipeline_class")
+             if not pipeline_class_name:
+                 logger.error("GenericPipelineStrategy requires 'pipeline_class' in model parameters")
+                 return False
+
+             pipeline_cls = getattr(diffusers, pipeline_class_name, None)
+             if pipeline_cls is None:
+                 logger.error(
+                     f"Pipeline class '{pipeline_class_name}' not found in diffusers. "
+                     "You may need to upgrade: pip install --upgrade diffusers"
+                 )
+                 return False
+
+             # Resolve dtype from parameters or auto-detect
+             dtype_name = params.get("torch_dtype")
+             if dtype_name and dtype_name in _DTYPE_MAP:
+                 dtype = _DTYPE_MAP[dtype_name]
+             elif device == "cpu":
+                 dtype = torch.float32
+             else:
+                 dtype = torch.bfloat16
+
+             # MPS safety: bfloat16 is not reliably supported on Metal
+             if device == "mps" and dtype == torch.bfloat16:
+                 logger.info("Falling back from bfloat16 to float16 for MPS device compatibility")
+                 dtype = torch.float16
+
+             load_kwargs = {"torch_dtype": dtype, "low_cpu_mem_usage": True}
+             if dtype in (torch.float16, torch.bfloat16):
+                 load_kwargs["use_safetensors"] = True
+
+             logger.info(f"Loading {pipeline_class_name} from {model_config.path} (dtype={dtype})")
+             self.pipeline = pipeline_cls.from_pretrained(
+                 model_config.path, **load_kwargs
+             )
+
+             # Device placement
+             enable_offload = params.get("enable_cpu_offload", False)
+             # Auto-enable CPU offload on MPS to avoid OOM on unified memory
+             if device == "mps":
+                 enable_offload = True
+
+             if enable_offload and device in ("cuda", "mps"):
+                 if device == "mps" and hasattr(self.pipeline, "enable_model_cpu_offload"):
+                     # MPS/unified memory: model-level offload is more effective than
+                     # sequential offload because it fully deallocates entire components
+                     # (T5 encoder, transformer, VAE) between stages, reducing peak
+                     # memory pressure on the MPS allocator.
+                     self.pipeline.enable_model_cpu_offload(device=device)
+                     logger.info(f"Enabled model CPU offloading on {device}")
+                 elif hasattr(self.pipeline, "enable_sequential_cpu_offload"):
+                     # CUDA: sequential offload moves individual layers, lowest VRAM usage
+                     self.pipeline.enable_sequential_cpu_offload(device=device)
+                     logger.info(f"Enabled sequential CPU offloading on {device}")
+                 elif hasattr(self.pipeline, "enable_model_cpu_offload"):
+                     self.pipeline.enable_model_cpu_offload(device=device)
+                     logger.info(f"Enabled model CPU offloading on {device}")
+                 else:
+                     self._move_to_device(device)
+             else:
+                 self._move_to_device(device)
+
+             self._apply_memory_optimizations()
+
+             logger.info(f"{pipeline_class_name} model {model_config.name} loaded on {self.device}")
+             return True
+         except Exception as e:
+             logger.error(f"Failed to load {model_config.name}: {e}")
+             return False
+
+     def generate(
+         self,
+         prompt: str,
+         negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+         num_inference_steps: Optional[int] = None,
+         guidance_scale: Optional[float] = None,
+         width: int = 1024,
+         height: int = 1024,
+         seed: Optional[int] = None,
+         **kwargs,
+     ) -> Image.Image:
+         if not self.pipeline:
+             raise RuntimeError("Model not loaded")
+
+         params = self.model_config.parameters or {}
+         steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 28)
+         guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.0)
+
+         generator, used_seed = self._make_generator(seed, self.device)
+
+         gen_kwargs = {
+             "prompt": prompt,
+             "num_inference_steps": steps,
+             "guidance_scale": guidance,
+             "width": width,
+             "height": height,
+             "generator": generator,
+         }
+
+         # Include negative_prompt only if the pipeline supports it
+         if params.get("supports_negative_prompt", True):
+             gen_kwargs["negative_prompt"] = negative_prompt
+
+         # Pass through image/mask/control params from kwargs
+         for key in ("image", "mask_image", "strength", "control_image",
+                     "controlnet_conditioning_scale", "control_guidance_start",
+                     "control_guidance_end"):
+             if key in kwargs:
+                 gen_kwargs[key] = kwargs[key]
+
+         try:
+             logger.info(
+                 f"Generating with {type(self.pipeline).__name__}: "
+                 f"steps={steps}, guidance={guidance}, seed={used_seed}"
+             )
+             output = self.pipeline(**gen_kwargs)
+             return output.images[0]
+         except TypeError as e:
+             # Some pipelines don't accept all standard params (e.g., width/height)
+             # Retry without optional params
+             logger.warning(f"Pipeline call failed: {e}. Retrying with minimal params.")
+             for key in ("width", "height", "negative_prompt"):
+                 gen_kwargs.pop(key, None)
+             output = self.pipeline(**gen_kwargs)
+             return output.images[0]
+         except Exception as e:
+             logger.error(f"Generation failed: {e}")
+             return self._create_error_image(str(e), prompt)
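Because the pipeline class is resolved by name at load time, adding a new model type is mostly a registry-data change. Below is a hypothetical parameters dict illustrating the keys this strategy actually reads; the real registry schema lives in model_registry.py and is not part of this hunk, so the values are only illustrative.

    # Hypothetical registry-style parameters for the generic strategy; the key names are
    # the ones GenericPipelineStrategy reads above, the values are assumptions.
    params = {
        "pipeline_class": "StableDiffusion3Pipeline",   # resolved via getattr(diffusers, ...)
        "torch_dtype": "bfloat16",                      # looked up in _DTYPE_MAP
        "enable_cpu_offload": True,                     # opt into offload on cuda (forced on mps)
        "supports_negative_prompt": True,               # gates the negative_prompt kwarg
        "num_inference_steps": 28,
        "guidance_scale": 7.0,
    }

    # The resolution step itself is just a dynamic attribute lookup:
    import diffusers
    pipeline_cls = getattr(diffusers, params["pipeline_class"], None)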
ollamadiffuser/core/inference/strategies/gguf_strategy.py
@@ -0,0 +1,113 @@
+ """GGUF quantized model inference strategy"""
+
+ import logging
+ import random
+ from typing import Any, Dict, Optional
+
+ from PIL import Image
+
+ from ..base import InferenceStrategy
+ from ...config.settings import ModelConfig
+
+ logger = logging.getLogger(__name__)
+
+ try:
+     from ...models.gguf_loader import gguf_loader, GGUF_AVAILABLE
+ except ImportError:
+     GGUF_AVAILABLE = False
+     gguf_loader = None
+
+
+ class GGUFStrategy(InferenceStrategy):
+     """Strategy for GGUF quantized models"""
+
+     def load(self, model_config: ModelConfig, device: str) -> bool:
+         if not GGUF_AVAILABLE:
+             logger.error("GGUF support not available. Install with: pip install stable-diffusion-cpp-python gguf")
+             return False
+
+         try:
+             self.device = device
+             self.model_config = model_config
+
+             config_dict = {
+                 "name": model_config.name,
+                 "path": model_config.path,
+                 "variant": model_config.variant,
+                 "model_type": model_config.model_type,
+                 "parameters": model_config.parameters,
+             }
+
+             if gguf_loader.load_model(config_dict):
+                 self.pipeline = None  # GGUF uses its own loader
+                 logger.info(f"GGUF model {model_config.name} loaded")
+                 return True
+
+             logger.error(f"Failed to load GGUF model: {model_config.name}")
+             return False
+         except Exception as e:
+             logger.error(f"GGUF load error: {e}")
+             return False
+
+     @property
+     def is_loaded(self) -> bool:
+         return GGUF_AVAILABLE and gguf_loader is not None and gguf_loader.is_loaded()
+
+     def generate(
+         self,
+         prompt: str,
+         negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+         num_inference_steps: Optional[int] = None,
+         guidance_scale: Optional[float] = None,
+         width: int = 1024,
+         height: int = 1024,
+         seed: Optional[int] = None,
+         **kwargs,
+     ) -> Image.Image:
+         if not self.is_loaded:
+             raise RuntimeError("GGUF model not loaded")
+
+         params = self.model_config.parameters or {}
+         steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 20)
+         guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.5)
+
+         # Resolve seed (GGUF uses integer seed, not torch.Generator)
+         if seed is None:
+             seed = random.randint(0, 2**32 - 1)
+             logger.info(f"Using seed: {seed}")
+
+         gen_kwargs = {
+             "prompt": prompt,
+             "negative_prompt": negative_prompt,
+             "num_inference_steps": steps,
+             "guidance_scale": guidance,
+             "width": width,
+             "height": height,
+             "seed": seed,
+             **kwargs,
+         }
+
+         try:
+             logger.info(f"Generating GGUF image: steps={steps}, guidance={guidance}, seed={seed}")
+             image = gguf_loader.generate_image(**gen_kwargs)
+             if image is None:
+                 return self._create_error_image("GGUF generation returned None", prompt)
+             return image
+         except Exception as e:
+             logger.error(f"GGUF generation failed: {e}")
+             return self._create_error_image(str(e), prompt)
+
+     def unload(self) -> None:
+         if GGUF_AVAILABLE and gguf_loader is not None and gguf_loader.is_loaded():
+             gguf_loader.unload_model()
+         self.model_config = None
+         self.current_lora = None
+         logger.info("GGUF model unloaded")
+
+     def get_info(self) -> Dict[str, Any]:
+         info = super().get_info()
+         if GGUF_AVAILABLE and gguf_loader is not None:
+             info.update(gguf_loader.get_model_info())
+             info["is_gguf"] = True
+         info["gguf_available"] = GGUF_AVAILABLE
+         return info
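A hypothetical direct use of the strategy, for illustration only: in the real package the engine constructs the ModelConfig and drives the strategy, and the constructor fields shown here are assumptions based on the keys the strategy reads (name, path, variant, model_type, parameters).

    from ollamadiffuser.core.config.settings import ModelConfig
    from ollamadiffuser.core.inference.strategies.gguf_strategy import GGUFStrategy

    # Assumed ModelConfig fields; the actual schema is defined in settings.py,
    # which is not shown in this diff.
    cfg = ModelConfig(
        name="flux.1-dev-gguf-q4",
        path="/path/to/flux1-dev-Q4_K_S.gguf",
        variant="q4_k_s",
        model_type="flux_gguf",
        parameters={"num_inference_steps": 20, "guidance_scale": 7.5},
    )

    strategy = GGUFStrategy()
    if strategy.load(cfg, device="cpu"):
        image = strategy.generate("a lighthouse at dusk", seed=42)
        image.save("gguf_out.png")
        strategy.unload()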
ollamadiffuser/core/inference/strategies/hidream_strategy.py
@@ -0,0 +1,104 @@
+ """HiDream inference strategy"""
+
+ import logging
+ from typing import Optional
+
+ import torch
+ from PIL import Image
+
+ from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
+ from ...config.settings import ModelConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Check availability
+ try:
+     from diffusers import HiDreamImagePipeline
+     HIDREAM_AVAILABLE = True
+ except ImportError:
+     HIDREAM_AVAILABLE = False
+
+
+ class HiDreamStrategy(InferenceStrategy):
+     """Strategy for HiDream models"""
+
+     def load(self, model_config: ModelConfig, device: str) -> bool:
+         if not HIDREAM_AVAILABLE:
+             logger.error("HiDreamImagePipeline not available. Install diffusers from source.")
+             return False
+
+         try:
+             self.device = device
+             self.model_config = model_config
+
+             load_kwargs = {**SAFETY_DISABLED_KWARGS}
+             if device == "cpu":
+                 load_kwargs["torch_dtype"] = torch.float32
+             else:
+                 load_kwargs["torch_dtype"] = torch.bfloat16
+
+             self.pipeline = HiDreamImagePipeline.from_pretrained(
+                 model_config.path, **load_kwargs
+             )
+
+             if device in ("cuda", "mps") and hasattr(self.pipeline, "enable_model_cpu_offload"):
+                 # CPU offloading manages device placement itself; don't call _move_to_device
+                 self.pipeline.enable_model_cpu_offload(device=device)
+             else:
+                 self._move_to_device(device)
+
+             if hasattr(self.pipeline, "enable_vae_slicing"):
+                 self.pipeline.enable_vae_slicing()
+             if hasattr(self.pipeline, "enable_vae_tiling"):
+                 self.pipeline.enable_vae_tiling()
+
+             logger.info(f"HiDream model {model_config.name} loaded on {self.device}")
+             return True
+         except Exception as e:
+             logger.error(f"Failed to load HiDream model: {e}")
+             return False
+
+     def generate(
+         self,
+         prompt: str,
+         negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+         num_inference_steps: Optional[int] = None,
+         guidance_scale: Optional[float] = None,
+         width: int = 1024,
+         height: int = 1024,
+         seed: Optional[int] = None,
+         **kwargs,
+     ) -> Image.Image:
+         if not self.pipeline:
+             raise RuntimeError("Model not loaded")
+
+         params = self.model_config.parameters or {}
+         steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 28)
+         guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 5.0)
+         max_seq_len = kwargs.get("max_sequence_length", params.get("max_sequence_length", 128))
+
+         generator, used_seed = self._make_generator(seed, self.device)
+
+         gen_kwargs = {
+             "prompt": prompt,
+             "negative_prompt": negative_prompt,
+             "num_inference_steps": steps,
+             "guidance_scale": guidance,
+             "width": width,
+             "height": height,
+             "max_sequence_length": max_seq_len,
+             "generator": generator,
+         }
+
+         # Support multiple text encoder prompts
+         for key in ("prompt_2", "prompt_3", "prompt_4"):
+             if key in kwargs:
+                 gen_kwargs[key] = kwargs[key]
+
+         try:
+             logger.info(f"Generating HiDream image: steps={steps}, guidance={guidance}, seed={used_seed}")
+             output = self.pipeline(**gen_kwargs)
+             return output.images[0]
+         except Exception as e:
+             logger.error(f"HiDream generation failed: {e}")
+             return self._create_error_image(str(e), prompt)
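Once the pipeline is loaded, the call this strategy issues is the standard diffusers pipeline call with the defaults shown above (28 steps, guidance 5.0, max_sequence_length 128). A minimal sketch, with pipe standing in for the already-loaded HiDreamImagePipeline and the prompt chosen arbitrarily:

    import torch

    # pipe is assumed to be a loaded HiDreamImagePipeline (see the load() path above).
    generator = torch.Generator("cpu").manual_seed(1234)
    image = pipe(
        prompt="an isometric city block in spring",
        negative_prompt="low quality, bad anatomy, worst quality, low resolution",
        num_inference_steps=28,
        guidance_scale=5.0,
        width=1024,
        height=1024,
        max_sequence_length=128,
        generator=generator,
    ).images[0]
    image.save("hidream.png")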