ollamadiffuser-1.2.3-py3-none-any.whl → ollamadiffuser-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/api/server.py +312 -312
- ollamadiffuser/cli/config_commands.py +119 -0
- ollamadiffuser/cli/lora_commands.py +169 -0
- ollamadiffuser/cli/main.py +85 -1233
- ollamadiffuser/cli/model_commands.py +664 -0
- ollamadiffuser/cli/recommend_command.py +205 -0
- ollamadiffuser/cli/registry_commands.py +197 -0
- ollamadiffuser/core/config/model_registry.py +562 -11
- ollamadiffuser/core/config/settings.py +24 -2
- ollamadiffuser/core/inference/__init__.py +5 -0
- ollamadiffuser/core/inference/base.py +182 -0
- ollamadiffuser/core/inference/engine.py +204 -1405
- ollamadiffuser/core/inference/strategies/__init__.py +1 -0
- ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
- ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
- ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
- ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
- ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
- ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
- ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
- ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
- ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
- ollamadiffuser/mcp/__init__.py +0 -0
- ollamadiffuser/mcp/server.py +184 -0
- ollamadiffuser/ui/templates/index.html +62 -1
- ollamadiffuser/ui/web.py +116 -54
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +321 -108
- ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
- ollamadiffuser/core/models/registry.py +0 -384
- ollamadiffuser/ui/samples/.DS_Store +0 -0
- ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
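
The headline change in 2.0.0 is the split of the monolithic `core/inference/engine.py` (1405 lines removed) into per-family strategy modules under `core/inference/strategies/`, plus a shared `base.py`. A minimal sketch of how a registry over these strategies could be wired; the `create_strategy` helper and its string keys are illustrative assumptions, only the module paths and class names come from the file list above:

```python
# Hypothetical sketch of a strategy registry; create_strategy() and its
# string keys are assumptions for illustration. Only the module paths and
# class names below come from the file list in this diff.
from ollamadiffuser.core.inference.strategies.sd15_strategy import SD15Strategy
from ollamadiffuser.core.inference.strategies.sd3_strategy import SD3Strategy
from ollamadiffuser.core.inference.strategies.sdxl_strategy import SDXLStrategy
from ollamadiffuser.core.inference.strategies.video_strategy import VideoStrategy

_STRATEGIES = {
    "sd15": SD15Strategy,
    "sd3": SD3Strategy,
    "sdxl": SDXLStrategy,
    "video": VideoStrategy,
}


def create_strategy(model_type: str):
    """Return a fresh strategy instance for a model family (assumed API)."""
    try:
        return _STRATEGIES[model_type]()
    except KeyError:
        raise ValueError(f"No strategy registered for model type '{model_type}'") from None
```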
ollamadiffuser/core/inference/strategies/sd15_strategy.py
@@ -0,0 +1,134 @@
+"""SD 1.5 inference strategy"""
+
+import logging
+from typing import Optional
+
+import torch
+from PIL import Image
+
+from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
+from ...config.settings import ModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+class SD15Strategy(InferenceStrategy):
+    """Strategy for Stable Diffusion 1.5 models"""
+
+    def load(self, model_config: ModelConfig, device: str) -> bool:
+        try:
+            from diffusers import StableDiffusionPipeline
+
+            self.device = device
+            self.model_config = model_config
+
+            load_kwargs = {**SAFETY_DISABLED_KWARGS}
+            dtype = self._get_dtype(device)
+
+            if model_config.variant == "fp16" and device not in ("cpu", "mps"):
+                load_kwargs["torch_dtype"] = dtype
+                load_kwargs["variant"] = "fp16"
+            else:
+                load_kwargs["torch_dtype"] = torch.float32
+
+            self.pipeline = StableDiffusionPipeline.from_pretrained(
+                model_config.path, **load_kwargs
+            )
+            self._move_to_device(device)
+            self._apply_memory_optimizations()
+
+            logger.info(f"SD 1.5 model {model_config.name} loaded on {self.device}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to load SD 1.5 model: {e}")
+            return False
+
+    def generate(
+        self,
+        prompt: str,
+        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        width: int = 512,
+        height: int = 512,
+        seed: Optional[int] = None,
+        image: Optional[Image.Image] = None,
+        mask_image: Optional[Image.Image] = None,
+        strength: float = 0.75,
+        **kwargs,
+    ) -> Image.Image:
+        if not self.pipeline:
+            raise RuntimeError("Model not loaded")
+
+        params = self.model_config.parameters or {}
+        steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 50)
+        guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.5)
+
+        # Clamp guidance on MPS (known hardware limitation)
+        if self.device == "mps" and guidance > 6.0:
+            logger.info(f"Clamping guidance_scale from {guidance} to 6.0 for MPS stability")
+            guidance = 6.0
+
+        # Warn about non-native resolutions for SD 1.5 (trained on 512x512)
+        if width > 768 or height > 768:
+            logger.warning(
+                f"SD 1.5 was trained on 512x512. Using {width}x{height} may produce artifacts. "
+                "Consider using 512x512 or 768x768 for best results."
+            )
+
+        generator, used_seed = self._make_generator(seed, self.device)
+
+        gen_kwargs = {
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "num_inference_steps": steps,
+            "guidance_scale": guidance,
+            "width": width,
+            "height": height,
+            "generator": generator,
+        }
+
+        try:
+            # Dispatch to inpaint or img2img when an input image is provided
+            if image is not None and mask_image is not None:
+                return self._inpaint(gen_kwargs, image, mask_image, strength)
+            elif image is not None:
+                return self._img2img(gen_kwargs, image, strength)
+
+            logger.info(f"Generating SD 1.5 image: steps={steps}, guidance={guidance}, seed={used_seed}")
+            output = self.pipeline(**gen_kwargs)
+            return output.images[0]
+        except Exception as e:
+            logger.error(f"SD 1.5 generation failed: {e}")
+            return self._create_error_image(str(e), prompt)
+
+    def _img2img(self, gen_kwargs: dict, image: Image.Image, strength: float) -> Image.Image:
+        """Run img2img generation"""
+        from diffusers import StableDiffusionImg2ImgPipeline
+
+        img2img_pipe = StableDiffusionImg2ImgPipeline(**self.pipeline.components)
+        img2img_pipe = img2img_pipe.to(self.device)
+
+        gen_kwargs.pop("width", None)
+        gen_kwargs.pop("height", None)
+        gen_kwargs["image"] = image
+        gen_kwargs["strength"] = strength
+
+        output = img2img_pipe(**gen_kwargs)
+        return output.images[0]
+
+    def _inpaint(self, gen_kwargs: dict, image: Image.Image, mask_image: Image.Image, strength: float) -> Image.Image:
+        """Run inpainting generation"""
+        from diffusers import StableDiffusionInpaintPipeline
+
+        inpaint_pipe = StableDiffusionInpaintPipeline(**self.pipeline.components)
+        inpaint_pipe = inpaint_pipe.to(self.device)
+
+        gen_kwargs.pop("width", None)
+        gen_kwargs.pop("height", None)
+        gen_kwargs["image"] = image
+        gen_kwargs["mask_image"] = mask_image
+        gen_kwargs["strength"] = strength
+
+        output = inpaint_pipe(**gen_kwargs)
+        return output.images[0]
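
A strategy is loaded once and then called per request. A minimal usage sketch; the `ModelConfig` fields used here (name, path, variant, parameters) are inferred from how the strategy reads them, since the real constructor lives in `core/config/settings.py`, which this section does not show:

```python
from ollamadiffuser.core.config.settings import ModelConfig
from ollamadiffuser.core.inference.strategies.sd15_strategy import SD15Strategy

# Assumed ModelConfig fields; the actual constructor is defined in
# core/config/settings.py, not shown in this diff.
config = ModelConfig(
    name="stable-diffusion-1.5",
    path="/models/stable-diffusion-v1-5",
    variant="fp16",
    parameters={"num_inference_steps": 30, "guidance_scale": 7.0},
)

strategy = SD15Strategy()
if strategy.load(config, device="cuda"):
    image = strategy.generate(
        "a lighthouse at dusk, oil painting",
        width=512,
        height=512,
        seed=42,
    )
    image.save("lighthouse.png")
```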
ollamadiffuser/core/inference/strategies/sd3_strategy.py
@@ -0,0 +1,80 @@
+"""SD 3.x inference strategy"""
+
+import logging
+from typing import Optional
+
+import torch
+from PIL import Image
+
+from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
+from ...config.settings import ModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+class SD3Strategy(InferenceStrategy):
+    """Strategy for Stable Diffusion 3.x models"""
+
+    def load(self, model_config: ModelConfig, device: str) -> bool:
+        try:
+            from diffusers import StableDiffusion3Pipeline
+
+            self.device = device
+            self.model_config = model_config
+
+            load_kwargs = {**SAFETY_DISABLED_KWARGS}
+            if model_config.variant == "fp16" and device not in ("cpu", "mps"):
+                load_kwargs["torch_dtype"] = torch.float16
+                load_kwargs["variant"] = "fp16"
+            else:
+                load_kwargs["torch_dtype"] = self._get_dtype(device)
+
+            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+                model_config.path, **load_kwargs
+            )
+            self._move_to_device(device)
+            self._apply_memory_optimizations()
+
+            logger.info(f"SD3 model {model_config.name} loaded on {self.device}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to load SD3 model: {e}")
+            return False
+
+    def generate(
+        self,
+        prompt: str,
+        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        width: int = 1024,
+        height: int = 1024,
+        seed: Optional[int] = None,
+        **kwargs,
+    ) -> Image.Image:
+        if not self.pipeline:
+            raise RuntimeError("Model not loaded")
+
+        params = self.model_config.parameters or {}
+        steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 28)
+        guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 3.5)
+
+        generator, used_seed = self._make_generator(seed, self.device)
+
+        gen_kwargs = {
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "num_inference_steps": steps,
+            "guidance_scale": guidance,
+            "width": width,
+            "height": height,
+            "generator": generator,
+        }
+
+        try:
+            logger.info(f"Generating SD3 image: steps={steps}, guidance={guidance}, seed={used_seed}")
+            output = self.pipeline(**gen_kwargs)
+            return output.images[0]
+        except Exception as e:
+            logger.error(f"SD3 generation failed: {e}")
+            return self._create_error_image(str(e), prompt)
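
Note how `SD3Strategy.load` (like the fp16 branch in `SD15Strategy`) funnels dtype selection through the base class's `_get_dtype`. `base.py` is not shown in this section, so the following is only a plausible sketch of what such a helper does, not the shipped implementation:

```python
import torch

# Plausible shape of the base class helper; base.py is not included in
# this section, so this is an assumption rather than the actual code.
def _get_dtype(device: str) -> torch.dtype:
    if device == "cuda":
        return torch.float16  # halve VRAM on NVIDIA GPUs
    # fp16 has historically been unreliable on Apple MPS, and CPUs gain
    # nothing from it, so both fall back to full precision.
    return torch.float32
```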
ollamadiffuser/core/inference/strategies/sdxl_strategy.py
@@ -0,0 +1,131 @@
+"""SDXL inference strategy"""
+
+import logging
+from typing import Optional
+
+import torch
+from PIL import Image
+
+from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
+from ...config.settings import ModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+class SDXLStrategy(InferenceStrategy):
+    """Strategy for Stable Diffusion XL models"""
+
+    def load(self, model_config: ModelConfig, device: str) -> bool:
+        try:
+            from diffusers import StableDiffusionXLPipeline
+
+            self.device = device
+            self.model_config = model_config
+
+            load_kwargs = {**SAFETY_DISABLED_KWARGS}
+            if model_config.variant == "fp16" and device not in ("cpu", "mps"):
+                load_kwargs["torch_dtype"] = torch.float16
+                load_kwargs["variant"] = "fp16"
+            elif device == "mps":
+                load_kwargs["torch_dtype"] = torch.float32
+            else:
+                load_kwargs["torch_dtype"] = self._get_dtype(device)
+
+            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
+                model_config.path, **load_kwargs
+            )
+
+            # Apply scheduler override if specified in parameters
+            params = model_config.parameters or {}
+            scheduler_class_name = params.get("scheduler_class")
+            if scheduler_class_name:
+                import diffusers
+                scheduler_cls = getattr(diffusers, scheduler_class_name, None)
+                if scheduler_cls is None:
+                    logger.warning(f"Scheduler '{scheduler_class_name}' not found in diffusers, using default")
+                else:
+                    scheduler_kwargs = params.get("scheduler_kwargs", {})
+                    self.pipeline.scheduler = scheduler_cls.from_config(
+                        self.pipeline.scheduler.config, **scheduler_kwargs
+                    )
+                    logger.info(f"Applied scheduler override: {scheduler_class_name}")
+
+            self._move_to_device(device)
+            self._apply_memory_optimizations()
+
+            logger.info(f"SDXL model {model_config.name} loaded on {self.device}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to load SDXL model: {e}")
+            return False
+
+    def generate(
+        self,
+        prompt: str,
+        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        width: int = 1024,
+        height: int = 1024,
+        seed: Optional[int] = None,
+        image: Optional[Image.Image] = None,
+        mask_image: Optional[Image.Image] = None,
+        strength: float = 0.75,
+        **kwargs,
+    ) -> Image.Image:
+        if not self.pipeline:
+            raise RuntimeError("Model not loaded")
+
+        params = self.model_config.parameters or {}
+        steps = num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 50)
+        guidance = guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.5)
+
+        generator, used_seed = self._make_generator(seed, self.device)
+
+        gen_kwargs = {
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "num_inference_steps": steps,
+            "guidance_scale": guidance,
+            "width": width,
+            "height": height,
+            "generator": generator,
+        }
+
+        try:
+            if image is not None and mask_image is not None:
+                return self._inpaint(gen_kwargs, image, mask_image, strength)
+            elif image is not None:
+                return self._img2img(gen_kwargs, image, strength)
+
+            logger.info(f"Generating SDXL image: steps={steps}, guidance={guidance}, seed={used_seed}")
+            output = self.pipeline(**gen_kwargs)
+            return output.images[0]
+        except Exception as e:
+            logger.error(f"SDXL generation failed: {e}")
+            return self._create_error_image(str(e), prompt)
+
+    def _img2img(self, gen_kwargs: dict, image: Image.Image, strength: float) -> Image.Image:
+        from diffusers import StableDiffusionXLImg2ImgPipeline
+
+        pipe = StableDiffusionXLImg2ImgPipeline(**self.pipeline.components)
+        pipe = pipe.to(self.device)
+        gen_kwargs.pop("width", None)
+        gen_kwargs.pop("height", None)
+        gen_kwargs["image"] = image
+        gen_kwargs["strength"] = strength
+        output = pipe(**gen_kwargs)
+        return output.images[0]
+
+    def _inpaint(self, gen_kwargs: dict, image: Image.Image, mask_image: Image.Image, strength: float) -> Image.Image:
+        from diffusers import StableDiffusionXLInpaintPipeline
+
+        pipe = StableDiffusionXLInpaintPipeline(**self.pipeline.components)
+        pipe = pipe.to(self.device)
+        gen_kwargs.pop("width", None)
+        gen_kwargs.pop("height", None)
+        gen_kwargs["image"] = image
+        gen_kwargs["mask_image"] = mask_image
+        gen_kwargs["strength"] = strength
+        output = pipe(**gen_kwargs)
+        return output.images[0]
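
The scheduler override in `SDXLStrategy.load` resolves `scheduler_class` by name against the top-level `diffusers` namespace and rebuilds it with `from_config`. A registry entry that would exercise this path might carry parameters like the following; the specific scheduler choice is illustrative, though `DPMSolverMultistepScheduler` and its `use_karras_sigmas` option are real diffusers APIs:

```python
# Example parameters dict that would trigger the scheduler override above.
# Resolved at load time via getattr(diffusers, "DPMSolverMultistepScheduler").
parameters = {
    "scheduler_class": "DPMSolverMultistepScheduler",
    "scheduler_kwargs": {"use_karras_sigmas": True},
    "num_inference_steps": 30,
    "guidance_scale": 6.0,
}
```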
ollamadiffuser/core/inference/strategies/video_strategy.py
@@ -0,0 +1,108 @@
+"""Video (AnimateDiff) inference strategy"""
+
+import logging
+from typing import Optional
+
+import torch
+from PIL import Image
+
+from ..base import InferenceStrategy, SAFETY_DISABLED_KWARGS
+from ...config.settings import ModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+class VideoStrategy(InferenceStrategy):
+    """Strategy for AnimateDiff video models"""
+
+    def load(self, model_config: ModelConfig, device: str) -> bool:
+        try:
+            from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
+
+            self.device = device
+            self.model_config = model_config
+
+            load_kwargs = {**SAFETY_DISABLED_KWARGS}
+            dtype = torch.float16 if device != "cpu" else torch.float32
+            load_kwargs["torch_dtype"] = dtype
+
+            # Load motion adapter
+            adapter_path = getattr(model_config, "motion_adapter_path", None)
+            if not adapter_path:
+                adapter_path = "guoyww/animatediff-motion-adapter-v1-5-2"
+            motion_adapter = MotionAdapter.from_pretrained(adapter_path, torch_dtype=dtype)
+            load_kwargs["motion_adapter"] = motion_adapter
+
+            self.pipeline = AnimateDiffPipeline.from_pretrained(
+                model_config.path, **load_kwargs
+            )
+
+            # Configure DDIM scheduler for AnimateDiff
+            self.pipeline.scheduler = DDIMScheduler.from_config(
+                self.pipeline.scheduler.config,
+                clip_sample=False,
+                timestep_spacing="linspace",
+                beta_schedule="linear",
+                steps_offset=1,
+            )
+
+            if device in ("cuda", "mps") and hasattr(self.pipeline, "enable_model_cpu_offload"):
+                # CPU offloading manages device placement itself; don't call _move_to_device
+                self.pipeline.enable_model_cpu_offload(device=device)
+            else:
+                self._move_to_device(device)
+
+            if hasattr(self.pipeline, "enable_vae_slicing"):
+                self.pipeline.enable_vae_slicing()
+
+            logger.info(f"Video model {model_config.name} loaded on {self.device}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to load video model: {e}")
+            return False
+
+    def generate(
+        self,
+        prompt: str,
+        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        width: int = 512,
+        height: int = 512,
+        seed: Optional[int] = None,
+        num_frames: int = 16,
+        **kwargs,
+    ) -> Image.Image:
+        if not self.pipeline:
+            raise RuntimeError("Model not loaded")
+
+        params = self.model_config.parameters or {}
+        steps = min(num_inference_steps if num_inference_steps is not None else params.get("num_inference_steps", 25), 25)
+        guidance = min(
+            guidance_scale if guidance_scale is not None else params.get("guidance_scale", 7.5),
+            7.5,
+        )
+
+        generator, used_seed = self._make_generator(seed, self.device)
+
+        gen_kwargs = {
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "num_inference_steps": steps,
+            "guidance_scale": guidance,
+            "width": width,
+            "height": height,
+            "num_frames": num_frames,
+            "generator": generator,
+        }
+
+        try:
+            logger.info(f"Generating video: {num_frames} frames, steps={steps}, seed={used_seed}")
+            output = self.pipeline(**gen_kwargs)
+
+            if hasattr(output, "frames") and len(output.frames) > 0:
+                return output.frames[0]
+            return output.images[0]
+        except Exception as e:
+            logger.error(f"Video generation failed: {e}")
+            return self._create_error_image(str(e), prompt)
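
For AnimateDiff, `output.frames[0]` is a list of per-frame PIL images rather than a single image, despite the `Image.Image` return annotation. Assuming a loaded `VideoStrategy` instance named `strategy`, the frames can be written out with diffusers' stock helper:

```python
from diffusers.utils import export_to_gif

# Assumes a VideoStrategy that has already load()-ed a model; generate()
# hands back AnimateDiff's frame list (output.frames[0]).
frames = strategy.generate("a rocket launching, cinematic", num_frames=16)
export_to_gif(frames, "rocket.gif")
```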
ollamadiffuser/mcp/__init__.py (file without changes)
ollamadiffuser/mcp/server.py
@@ -0,0 +1,184 @@
+"""OllamaDiffuser MCP Server - Model Context Protocol integration."""
+
+import asyncio
+import io
+import logging
+import sys
+from typing import Optional
+
+from ..core.models.manager import model_manager
+
+logger = logging.getLogger(__name__)
+
+
+def _ensure_mcp():
+    """Check that the mcp package is available."""
+    try:
+        import mcp  # noqa: F401
+
+        return True
+    except ImportError:
+        logger.error(
+            "MCP package not installed. Install with: pip install 'ollamadiffuser[mcp]'"
+        )
+        return False
+
+
+def create_mcp_server():
+    """Create and configure the MCP server with all tools."""
+    from mcp.server.fastmcp import FastMCP, Image
+
+    mcp_server = FastMCP(
+        "OllamaDiffuser",
+        instructions=(
+            "Local AI image generation via Stable Diffusion, FLUX, and 30+ models"
+        ),
+    )
+
+    @mcp_server.tool()
+    async def generate_image(
+        prompt: str,
+        model: Optional[str] = None,
+        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+        width: int = 1024,
+        height: int = 1024,
+        steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        seed: Optional[int] = None,
+    ) -> Image:
+        """Generate an image from a text prompt using a local diffusion model.
+
+        Args:
+            prompt: Text description of the desired image
+            model: Model to use (auto-loads if needed). Leave empty to use current model.
+            negative_prompt: What to avoid in the image
+            width: Image width in pixels
+            height: Image height in pixels
+            steps: Number of inference steps (model-specific default if omitted)
+            guidance_scale: Guidance scale (model-specific default if omitted)
+            seed: Random seed for reproducibility
+        """
+        if model and model_manager.get_current_model() != model:
+            if not model_manager.is_model_installed(model):
+                raise ValueError(
+                    f"Model '{model}' is not installed. "
+                    f"Install it first: ollamadiffuser pull {model}"
+                )
+            logger.info(f"Loading model: {model}")
+            success = await asyncio.to_thread(model_manager.load_model, model)
+            if not success:
+                raise RuntimeError(f"Failed to load model '{model}'")
+
+        if not model_manager.is_model_loaded():
+            raise RuntimeError(
+                "No model loaded. Load one with: load_model('model-name') "
+                "or pass model= parameter"
+            )
+
+        engine = model_manager.loaded_model
+        result = await asyncio.to_thread(
+            engine.generate_image,
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            num_inference_steps=steps,
+            guidance_scale=guidance_scale,
+            width=width,
+            height=height,
+            seed=seed,
+        )
+
+        buf = io.BytesIO()
+        result.save(buf, format="PNG")
+        return Image(data=buf.getvalue(), format="png")
+
+    @mcp_server.tool()
+    async def list_models() -> str:
+        """List all available and installed image generation models.
+
+        Returns a formatted list showing which models are available to download
+        and which are already installed locally.
+        """
+        available = model_manager.list_available_models()
+        installed = model_manager.list_installed_models()
+        current = model_manager.get_current_model()
+
+        lines = ["Available models:"]
+        for name in sorted(available):
+            status_parts = []
+            if name in installed:
+                status_parts.append("installed")
+            if name == current:
+                status_parts.append("loaded")
+            suffix = f" ({', '.join(status_parts)})" if status_parts else ""
+            lines.append(f"  - {name}{suffix}")
+
+        lines.append(f"\nInstalled: {len(installed)}/{len(available)}")
+        if current:
+            lines.append(f"Currently loaded: {current}")
+
+        return "\n".join(lines)
+
+    @mcp_server.tool()
+    async def load_model(model_name: str) -> str:
+        """Load a specific image generation model into memory.
+
+        Args:
+            model_name: Name of the model to load (must be installed first)
+        """
+        if not model_manager.is_model_installed(model_name):
+            installed = model_manager.list_installed_models()
+            return (
+                f"Model '{model_name}' is not installed. "
+                f"Installed models: {', '.join(installed) if installed else 'none'}. "
+                f"Use 'ollamadiffuser pull {model_name}' to install it first."
+            )
+
+        success = await asyncio.to_thread(model_manager.load_model, model_name)
+        if success:
+            return f"Model '{model_name}' loaded successfully"
+        return f"Failed to load model '{model_name}'"
+
+    @mcp_server.tool()
+    async def get_status() -> str:
+        """Get the current status of OllamaDiffuser.
+
+        Returns device info, loaded model, and installed model count.
+        """
+        is_loaded = model_manager.is_model_loaded()
+        current = model_manager.get_current_model()
+        installed = model_manager.list_installed_models()
+
+        lines = ["OllamaDiffuser Status:"]
+        lines.append(f"  Model loaded: {'yes' if is_loaded else 'no'}")
+        if current:
+            lines.append(f"  Current model: {current}")
+        lines.append(f"  Installed models: {len(installed)}")
+
+        if is_loaded and model_manager.loaded_model:
+            engine = model_manager.loaded_model
+            info = engine.get_model_info()
+            if info:
+                lines.append(f"  Device: {info.get('device', 'unknown')}")
+                lines.append(f"  Model type: {info.get('type', 'unknown')}")
+
+        return "\n".join(lines)
+
+    return mcp_server
+
+
+def main():
+    """Entry point for the MCP server (stdio transport)."""
+    if not _ensure_mcp():
+        sys.exit(1)
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+    )
+
+    server = create_mcp_server()
+    server.run(transport="stdio")
+
+
+if __name__ == "__main__":
+    main()
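
The extra line in `entry_points.txt` (+1) presumably registers `main()` as a console script. A minimal client-side sketch using the official `mcp` Python SDK's stdio transport; the command name `ollamadiffuser-mcp` is an assumption, substitute whatever the entry point actually registers:

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def demo() -> None:
    # "ollamadiffuser-mcp" is an assumed console-script name; check the
    # wheel's entry_points.txt for the name actually registered.
    params = StdioServerParameters(command="ollamadiffuser-mcp", args=[])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])
            status = await session.call_tool("get_status", {})
            print(status.content)


asyncio.run(demo())
```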