ollamadiffuser 1.2.3-py3-none-any.whl → 2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/api/server.py +312 -312
- ollamadiffuser/cli/config_commands.py +119 -0
- ollamadiffuser/cli/lora_commands.py +169 -0
- ollamadiffuser/cli/main.py +85 -1233
- ollamadiffuser/cli/model_commands.py +664 -0
- ollamadiffuser/cli/recommend_command.py +205 -0
- ollamadiffuser/cli/registry_commands.py +197 -0
- ollamadiffuser/core/config/model_registry.py +562 -11
- ollamadiffuser/core/config/settings.py +24 -2
- ollamadiffuser/core/inference/__init__.py +5 -0
- ollamadiffuser/core/inference/base.py +182 -0
- ollamadiffuser/core/inference/engine.py +204 -1405
- ollamadiffuser/core/inference/strategies/__init__.py +1 -0
- ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
- ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
- ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
- ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
- ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
- ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
- ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
- ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
- ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
- ollamadiffuser/mcp/__init__.py +0 -0
- ollamadiffuser/mcp/server.py +184 -0
- ollamadiffuser/ui/templates/index.html +62 -1
- ollamadiffuser/ui/web.py +116 -54
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +321 -108
- ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
- ollamadiffuser/core/models/registry.py +0 -384
- ollamadiffuser/ui/samples/.DS_Store +0 -0
- ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
ollamadiffuser/core/config/settings.py

@@ -69,7 +69,18 @@ class Settings:
             }
 
             self.current_model = config_data.get('current_model')
-
+
+            # Load custom path overrides
+            if 'paths' in config_data:
+                paths_data = config_data['paths']
+                if 'models_dir' in paths_data:
+                    self.models_dir = Path(paths_data['models_dir'])
+                if 'cache_dir' in paths_data:
+                    self.cache_dir = Path(paths_data['cache_dir'])
+                # Ensure custom directories exist
+                self.models_dir.mkdir(parents=True, exist_ok=True)
+                self.cache_dir.mkdir(parents=True, exist_ok=True)
+
             logger.info(f"Configuration file loaded: {self.config_file}")
 
         except Exception as e:
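For context, the block added above reads overrides from the user's config.json after json.load. A sketch of the structure it expects, written as the parsed Python dict (the model name and paths are hypothetical examples, not defaults):

# Shape of the parsed config.json handled by the new loading block.
# Values here are illustrative assumptions.
config_data = {
    "current_model": "stable-diffusion-1.5",
    "paths": {
        "models_dir": "/mnt/external/ollamadiffuser/models",
        "cache_dir": "/mnt/external/ollamadiffuser/cache",
    },
}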
@@ -99,7 +110,18 @@ class Settings:
             },
             'current_model': self.current_model
         }
-
+
+        # Only persist path overrides when they differ from defaults
+        default_models_dir = self.config_dir / "models"
+        default_cache_dir = self.config_dir / "cache"
+        paths = {}
+        if self.models_dir != default_models_dir:
+            paths['models_dir'] = str(self.models_dir)
+        if self.cache_dir != default_cache_dir:
+            paths['cache_dir'] = str(self.cache_dir)
+        if paths:
+            config_data['paths'] = paths
+
         with open(self.config_file, 'w', encoding='utf-8') as f:
             json.dump(config_data, f, indent=2, ensure_ascii=False)
 
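The save side mirrors the loader: a path is written under the 'paths' key only when it differs from the defaults beneath config_dir, so a stock install keeps its config file minimal. A rough round-trip sketch, assuming load/save entry points named load_config and save_config (this diff shows only their bodies, not their names):

from pathlib import Path

settings = Settings()
settings.models_dir = Path("/mnt/external/models")    # override the default
settings.save_config()   # config.json now carries {"paths": {"models_dir": "..."}}

settings.models_dir = settings.config_dir / "models"  # restore the default
settings.save_config()   # the "paths" key is omitted again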
ollamadiffuser/core/inference/base.py (new file)

@@ -0,0 +1,182 @@
+"""Base inference strategy for OllamaDiffuser"""
+
+import logging
+import random
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+
+import torch
+from PIL import Image
+
+from ..config.settings import ModelConfig
+
+logger = logging.getLogger(__name__)
+
+
+# Unified safety checker kwargs - use these in all from_pretrained calls
+SAFETY_DISABLED_KWARGS = {
+    "safety_checker": None,
+    "requires_safety_checker": False,
+    "feature_extractor": None,
+}
+
+
+class InferenceStrategy(ABC):
+    """Abstract base class for all model inference strategies"""
+
+    def __init__(self):
+        self.pipeline = None
+        self.model_config: Optional[ModelConfig] = None
+        self.device: Optional[str] = None
+        self.current_lora = None
+
+    @abstractmethod
+    def load(self, model_config: ModelConfig, device: str) -> bool:
+        """Load the model pipeline"""
+        pass
+
+    @abstractmethod
+    def generate(
+        self,
+        prompt: str,
+        negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution",
+        num_inference_steps: Optional[int] = None,
+        guidance_scale: Optional[float] = None,
+        width: int = 1024,
+        height: int = 1024,
+        seed: Optional[int] = None,
+        **kwargs,
+    ) -> Image.Image:
+        """Generate an image"""
+        pass
+
+    def unload(self) -> None:
+        """Unload model and free memory"""
+        if self.pipeline:
+            self.pipeline = self.pipeline.to("cpu")
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            del self.pipeline
+        self.pipeline = None
+        self.model_config = None
+        self.current_lora = None
+        logger.info("Model unloaded")
+
+    @property
+    def is_loaded(self) -> bool:
+        return self.pipeline is not None
+
+    def get_info(self) -> Dict[str, Any]:
+        """Return model information"""
+        if not self.model_config:
+            return {}
+        return {
+            "name": self.model_config.name,
+            "type": self.model_config.model_type,
+            "device": self.device,
+            "variant": self.model_config.variant,
+            "parameters": self.model_config.parameters,
+        }
+
+    def _make_generator(self, seed: Optional[int], device: str) -> tuple:
+        """Create a torch Generator with the given or random seed. Returns (generator, seed)."""
+        if seed is None:
+            seed = random.randint(0, 2**32 - 1)
+        if device == "cpu":
+            generator = torch.Generator().manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        logger.info(f"Using seed: {seed}")
+        return generator, seed
+
+    def _get_dtype(self, device: str, prefer_bf16: bool = False) -> torch.dtype:
+        """Get appropriate dtype for the given device"""
+        if device == "cpu":
+            return torch.float32
+        if prefer_bf16:
+            return torch.bfloat16
+        return torch.float16
+
+    def _move_to_device(self, device: str) -> bool:
+        """Move pipeline to device with fallback to CPU"""
+        try:
+            self.pipeline = self.pipeline.to(device)
+            logger.info(f"Pipeline moved to {device}")
+            return True
+        except Exception as e:
+            logger.warning(f"Failed to move pipeline to {device}: {e}")
+            if device != "cpu":
+                logger.info("Falling back to CPU")
+                self.device = "cpu"
+                self.pipeline = self.pipeline.to("cpu")
+                return True
+            raise
+
+    def _apply_memory_optimizations(self):
+        """Apply common memory optimizations"""
+        if hasattr(self.pipeline, "enable_attention_slicing"):
+            self.pipeline.enable_attention_slicing()
+            logger.info("Enabled attention slicing")
+        if hasattr(self.pipeline, "enable_vae_tiling"):
+            self.pipeline.enable_vae_tiling()
+            logger.info("Enabled VAE tiling")
+        if hasattr(self.pipeline, "enable_vae_slicing"):
+            self.pipeline.enable_vae_slicing()
+            logger.info("Enabled VAE slicing")
+
+    def load_lora_runtime(
+        self, repo_id: str, weight_name: str = None, scale: float = 1.0
+    ) -> bool:
+        """Load LoRA weights at runtime"""
+        if not self.pipeline:
+            raise RuntimeError("Model not loaded")
+        try:
+            if weight_name:
+                self.pipeline.load_lora_weights(repo_id, weight_name=weight_name)
+            else:
+                self.pipeline.load_lora_weights(repo_id)
+            if hasattr(self.pipeline, "set_adapters") and scale != 1.0:
+                self.pipeline.set_adapters(["default"], adapter_weights=[scale])
+            self.current_lora = {
+                "repo_id": repo_id,
+                "weight_name": weight_name,
+                "scale": scale,
+                "loaded": True,
+            }
+            logger.info(f"LoRA loaded from {repo_id}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to load LoRA: {e}")
+            return False
+
+    def unload_lora(self) -> bool:
+        """Unload LoRA weights"""
+        if not self.pipeline:
+            return False
+        try:
+            if hasattr(self.pipeline, "unload_lora_weights"):
+                self.pipeline.unload_lora_weights()
+                self.current_lora = None
+                logger.info("LoRA unloaded")
+                return True
+            return False
+        except Exception as e:
+            logger.error(f"Failed to unload LoRA: {e}")
+            return False
+
+    def _create_error_image(self, error_msg: str, prompt: str) -> Image.Image:
+        """Create an error placeholder image"""
+        from PIL import ImageDraw, ImageFont
+
+        img = Image.new("RGB", (512, 512), color=(255, 255, 255))
+        draw = ImageDraw.Draw(img)
+        try:
+            font = ImageFont.load_default()
+        except Exception:
+            font = None
+        draw.text((10, 10), f"Error: {error_msg}", fill=(255, 0, 0), font=font)
+        prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt
+        draw.text(
+            (10, 30), f"Prompt: {prompt_display}", fill=(0, 0, 0), font=font
+        )
+        return img
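The strategy modules listed above (sd15_strategy.py, sdxl_strategy.py, flux_strategy.py, and so on) each subclass this base; their bodies are not shown in this diff. A minimal hypothetical subclass illustrates how the helpers compose — model_config.path and the step/guidance defaults below are assumptions, not taken from the package:

from diffusers import StableDiffusionPipeline

from .base import SAFETY_DISABLED_KWARGS, InferenceStrategy


class MinimalSD15Strategy(InferenceStrategy):
    """Illustrative only; not the shipped sd15_strategy implementation."""

    def load(self, model_config, device):
        dtype = self._get_dtype(device)  # float32 on CPU, float16 otherwise
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            model_config.path,            # assumed attribute with the local model path
            torch_dtype=dtype,
            **SAFETY_DISABLED_KWARGS,     # disable the safety checker uniformly
        )
        self.model_config = model_config
        self.device = device
        self._move_to_device(device)      # falls back to CPU if the move fails
        self._apply_memory_optimizations()
        return True

    def generate(self, prompt,
                 negative_prompt="low quality, bad anatomy, worst quality, low resolution",
                 num_inference_steps=None, guidance_scale=None,
                 width=512, height=512, seed=None, **kwargs):
        generator, seed = self._make_generator(seed, self.device)
        try:
            result = self.pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps or 25,  # assumed default
                guidance_scale=guidance_scale or 7.5,           # assumed default
                width=width,
                height=height,
                generator=generator,
            )
            return result.images[0]
        except Exception as e:
            # The base class supplies a placeholder image on failure
            return self._create_error_image(str(e), prompt)

A loaded strategy can then layer a LoRA on top at runtime, e.g. strategy.load_lora_runtime("some-user/some-lora", scale=0.8) (repo id hypothetical). Note that set_adapters is only called when scale != 1.0, and the adapter name "default" assumes the pipeline registered the LoRA weights under that name.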