ollamadiffuser 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +0 -0
- ollamadiffuser/__main__.py +50 -0
- ollamadiffuser/api/__init__.py +0 -0
- ollamadiffuser/api/server.py +297 -0
- ollamadiffuser/cli/__init__.py +0 -0
- ollamadiffuser/cli/main.py +597 -0
- ollamadiffuser/core/__init__.py +0 -0
- ollamadiffuser/core/config/__init__.py +0 -0
- ollamadiffuser/core/config/settings.py +137 -0
- ollamadiffuser/core/inference/__init__.py +0 -0
- ollamadiffuser/core/inference/engine.py +926 -0
- ollamadiffuser/core/models/__init__.py +0 -0
- ollamadiffuser/core/models/manager.py +436 -0
- ollamadiffuser/core/utils/__init__.py +3 -0
- ollamadiffuser/core/utils/download_utils.py +356 -0
- ollamadiffuser/core/utils/lora_manager.py +390 -0
- ollamadiffuser/ui/__init__.py +0 -0
- ollamadiffuser/ui/templates/index.html +496 -0
- ollamadiffuser/ui/web.py +278 -0
- ollamadiffuser/utils/__init__.py +0 -0
- ollamadiffuser-1.0.0.dist-info/METADATA +493 -0
- ollamadiffuser-1.0.0.dist-info/RECORD +26 -0
- ollamadiffuser-1.0.0.dist-info/WHEEL +5 -0
- ollamadiffuser-1.0.0.dist-info/entry_points.txt +2 -0
- ollamadiffuser-1.0.0.dist-info/licenses/LICENSE +21 -0
- ollamadiffuser-1.0.0.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Dict, List, Optional, Union
|
|
5
|
+
import logging
|
|
6
|
+
import hashlib
|
|
7
|
+
from huggingface_hub import login
|
|
8
|
+
from ..config.settings import settings, ModelConfig
|
|
9
|
+
from ..utils.download_utils import robust_snapshot_download, robust_file_download
|
|
10
|
+
|
|
11
|
+
# Logger named after this module (via ``__name__``) so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
12
|
+
|
|
13
|
+
class ModelManager:
    """Registry-backed manager for downloading, installing and loading models.

    Installed-model state is kept in the shared ``settings`` object
    (``settings.models``, ``settings.current_model``). The inference engine of
    the model currently held in memory is kept on this instance.
    """

    def __init__(self):
        # Inference engine for the model currently held in memory (None when nothing is loaded).
        self.loaded_model: Optional[object] = None
        # Name of the model held in ``loaded_model``; set and cleared together with it.
        self.current_model_name: Optional[str] = None

        # Predefined model registry: model name -> download/config metadata.
        # NOTE(review): the "runwayml/stable-diffusion-v1-5" repo has reportedly been
        # removed from the HuggingFace Hub; confirm the repo_id below still resolves.
        self.model_registry = {
            "flux.1-dev": {
                "repo_id": "black-forest-labs/FLUX.1-dev",
                "model_type": "flux",
                "variant": "bf16",
                "parameters": {
                    "num_inference_steps": 16,
                    "guidance_scale": 2.0,
                    "max_sequence_length": 512
                },
                "hardware_requirements": {
                    "min_vram_gb": 12,
                    "recommended_vram_gb": 16,
                    "min_ram_gb": 24,
                    "recommended_ram_gb": 32,
                    "disk_space_gb": 15,
                    "supported_devices": ["CUDA", "MPS"],
                    "performance_notes": "Requires NVIDIA RTX 4070+ or Apple M2 Pro+. Needs HuggingFace token. Use 'lora pull' to add LoRA styles."
                },
                "license_info": {
                    "type": "FLUX.1-dev Non-Commercial License",
                    "requires_agreement": True,
                    "commercial_use": False
                }
            },

            "flux.1-schnell": {
                "repo_id": "black-forest-labs/FLUX.1-schnell",
                "model_type": "flux",
                "variant": "bf16",
                "parameters": {
                    "num_inference_steps": 4,
                    "guidance_scale": 0.0,
                    "max_sequence_length": 256
                },
                "hardware_requirements": {
                    "min_vram_gb": 12,
                    "recommended_vram_gb": 16,
                    "min_ram_gb": 24,
                    "recommended_ram_gb": 32,
                    "disk_space_gb": 15,
                    "supported_devices": ["CUDA", "MPS"],
                    "performance_notes": "Fast distilled version of FLUX.1-dev. Generates images in ~4 steps. Requires NVIDIA RTX 4070+ or Apple M2 Pro+."
                },
                "license_info": {
                    "type": "Apache 2.0",
                    "requires_agreement": False,
                    "commercial_use": True
                }
            },

            "stable-diffusion-3.5-medium": {
                "repo_id": "stabilityai/stable-diffusion-3.5-medium",
                "model_type": "sd3",
                "variant": "fp16",
                "parameters": {
                    "num_inference_steps": 28,
                    "guidance_scale": 3.5
                },
                "hardware_requirements": {
                    "min_vram_gb": 8,
                    "recommended_vram_gb": 12,
                    "min_ram_gb": 16,
                    "recommended_ram_gb": 32,
                    "disk_space_gb": 10,
                    "supported_devices": ["CUDA", "MPS", "CPU"],
                    "performance_notes": "Best on NVIDIA RTX 3080+ or Apple M2 Pro+"
                }
            },
            "stable-diffusion-xl-base": {
                "repo_id": "stabilityai/stable-diffusion-xl-base-1.0",
                "model_type": "sdxl",
                "variant": "fp16",
                "parameters": {
                    "num_inference_steps": 50,
                    "guidance_scale": 7.5
                },
                "hardware_requirements": {
                    "min_vram_gb": 6,
                    "recommended_vram_gb": 10,
                    "min_ram_gb": 12,
                    "recommended_ram_gb": 24,
                    "disk_space_gb": 7,
                    "supported_devices": ["CUDA", "MPS", "CPU"],
                    "performance_notes": "Good on NVIDIA RTX 3070+ or Apple M1 Pro+"
                }
            },
            "stable-diffusion-1.5": {
                "repo_id": "runwayml/stable-diffusion-v1-5",
                "model_type": "sd15",
                "variant": "fp16",
                "parameters": {
                    "num_inference_steps": 20,
                    "guidance_scale": 7.0
                },
                "hardware_requirements": {
                    "min_vram_gb": 4,
                    "recommended_vram_gb": 6,
                    "min_ram_gb": 8,
                    "recommended_ram_gb": 16,
                    "disk_space_gb": 5,
                    "supported_devices": ["CUDA", "MPS", "CPU"],
                    "performance_notes": "Runs well on most modern GPUs, including GTX 1060+"
                }
            }
        }

    def list_available_models(self) -> List[str]:
        """Return the names of all models in the built-in registry."""
        return list(self.model_registry.keys())

    def list_installed_models(self) -> List[str]:
        """Return the names of all models registered in ``settings.models``."""
        return list(settings.models.keys())

    def is_model_installed(self, model_name: str) -> bool:
        """Return True if ``model_name`` is registered in ``settings.models``."""
        return model_name in settings.models

    def get_model_info(self, model_name: str) -> Optional[Dict]:
        """Return registry metadata for ``model_name`` plus installation details.

        The result is a deep copy, so callers may mutate it (including the nested
        ``parameters`` / ``hardware_requirements`` dicts) without altering the
        registry. It gains an ``installed`` flag and, when installed,
        ``local_path`` and a human-readable ``size``.

        Returns:
            The info dict, or None if ``model_name`` is not in the registry.
        """
        import copy

        if model_name not in self.model_registry:
            return None

        # deepcopy: a shallow .copy() shared the nested dicts with the registry.
        info = copy.deepcopy(self.model_registry[model_name])
        info['installed'] = self.is_model_installed(model_name)
        if info['installed']:
            config = settings.models[model_name]
            info['local_path'] = config.path
            info['size'] = self._get_model_size(config.path)
        return info

    def _get_model_size(self, model_path: str) -> str:
        """Return the size of a file, or the total size of all files under a directory.

        The size is formatted with one decimal and a B/KB/MB/GB/TB/PB unit (base 1024).
        Returns "Unknown" if any filesystem error occurs while measuring.
        """
        try:
            path = Path(model_path)
            if path.is_file():
                size = path.stat().st_size
            else:
                size = sum(f.stat().st_size for f in path.rglob('*') if f.is_file())

            for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
                if size < 1024.0:
                    return f"{size:.1f} {unit}"
                size /= 1024.0
            return f"{size:.1f} PB"
        except Exception:
            # Size is informational only; a stat() failure must not break get_model_info().
            return "Unknown"

    @staticmethod
    def _notify(progress_callback, message: str) -> None:
        """Forward ``message`` to ``progress_callback`` if one was supplied."""
        if progress_callback:
            progress_callback(message)

    def _report_model_info(self, model_name: str, model_info: Dict, progress_callback=None) -> None:
        """Send the repository and license summary for ``model_name`` to ``progress_callback``."""
        if not progress_callback:
            return
        license_info = model_info.get("license_info", {})
        progress_callback(f"📦 Model: {model_name}")
        progress_callback(f"🔗 Repository: {model_info['repo_id']}")
        if license_info:
            progress_callback(f"📄 License: {license_info.get('type', 'Unknown')}")
            if license_info.get('requires_agreement', False):
                progress_callback("🔑 HuggingFace token required - ensure HF_TOKEN is set")
            else:
                progress_callback("✅ No HuggingFace token required")

    def _register_model(self, model_name: str, model_info: Dict, model_path: Path) -> None:
        """Build a ModelConfig from the registry entry and add it to ``settings``."""
        model_config = ModelConfig(
            name=model_name,
            path=str(model_path),
            model_type=model_info["model_type"],
            variant=model_info.get("variant"),
            components=model_info.get("components"),
            parameters=model_info.get("parameters"),
        )
        settings.add_model(model_config)

    def _download_components(self, components: Dict, model_path: Path, force: bool,
                             progress_callback=None) -> None:
        """Download extra components into ``<model_path>/components/<name>``.

        A component whose info has a ``filename`` key is fetched as a single file.
        Otherwise its whole repository is snapshotted.
        """
        components_path = model_path / "components"
        components_path.mkdir(exist_ok=True)

        for comp_name, comp_info in components.items():
            comp_path = components_path / comp_name
            comp_path.mkdir(exist_ok=True)
            self._notify(progress_callback, f"📦 Downloading component: {comp_name}")

            if "filename" in comp_info:
                robust_file_download(
                    repo_id=comp_info["repo_id"],
                    filename=comp_info["filename"],
                    local_dir=str(comp_path),
                    cache_dir=str(settings.cache_dir),
                    max_retries=3,
                    progress_callback=progress_callback,
                )
            else:
                robust_snapshot_download(
                    repo_id=comp_info["repo_id"],
                    local_dir=str(comp_path),
                    cache_dir=str(settings.cache_dir),
                    max_retries=3,
                    initial_workers=2,  # Use fewer workers for components
                    force_download=force,
                    progress_callback=progress_callback,
                )

    def _cleanup_failed_download(self, model_name: str, model_path: Path, progress_callback=None) -> None:
        """Delete the directory of a failed forced download and drop any stale config entry."""
        if model_path.exists():
            try:
                shutil.rmtree(model_path)
                logger.info(f"Cleaned up failed download directory: {model_path}")
                self._notify(progress_callback, "🧹 Cleaned up incomplete download")
            except Exception as cleanup_error:
                logger.warning(f"Failed to clean up directory {model_path}: {cleanup_error}")
                return

        # A forced re-download of an already-installed model must not leave the config
        # entry pointing at files that no longer exist.
        if self.is_model_installed(model_name) and not Path(settings.models[model_name].path).exists():
            settings.remove_model(model_name)
            logger.info(f"Removed stale configuration entry for {model_name}")

    def pull_model(self, model_name: str, force: bool = False, progress_callback=None) -> bool:
        """Download ``model_name`` from HuggingFace and register it in the configuration.

        Args:
            model_name: Key in ``self.model_registry``.
            force: Re-download even if already installed. When a forced download fails,
                its directory is deleted and any stale config entry is removed. Failed
                non-forced downloads are left in place so a later pull can resume them.
            progress_callback: Optional callable that receives human-readable status strings.

        Returns:
            True if the model ends up installed. This covers three cases: it was
            already installed, an existing complete download was adopted, or a
            download succeeded. False for an unknown model or a failed download.
        """
        if not force and self.is_model_installed(model_name):
            logger.info(f"Model {model_name} already exists")
            self._notify(progress_callback, f"✅ Model {model_name} already installed")
            return True

        if model_name not in self.model_registry:
            logger.error(f"Unknown model: {model_name}")
            self._notify(progress_callback, f"❌ Error: Unknown model {model_name}")
            return False

        model_info = self.model_registry[model_name]
        model_path = settings.get_model_path(model_name)

        self._report_model_info(model_name, model_info, progress_callback)

        # Adopt an existing directory if it already holds a complete download;
        # otherwise the downloader below resumes into it.
        if not force and model_path.exists():
            self._notify(progress_callback, "🔍 Checking existing download...")

            from ..utils.download_utils import check_download_integrity
            if check_download_integrity(str(model_path), model_info["repo_id"]):
                self._notify(progress_callback, "✅ Found complete download, adding to configuration...")
                self._register_model(model_name, model_info, model_path)
                logger.info(f"Model {model_name} configuration updated")
                self._notify(progress_callback, f"✅ {model_name} ready to use!")
                return True
            self._notify(progress_callback, "⚠️ Incomplete download detected, will resume...")

        try:
            if settings.hf_token:
                login(token=settings.hf_token)
                self._notify(progress_callback, "🔑 Authenticated with HuggingFace")
            else:
                self._notify(progress_callback, "⚠️ No HuggingFace token found - some models may not be accessible")

            logger.info(f"Downloading model: {model_name}")
            self._notify(progress_callback, f"🚀 Starting download of {model_name}")

            robust_snapshot_download(
                repo_id=model_info["repo_id"],
                local_dir=str(model_path),
                cache_dir=str(settings.cache_dir),
                max_retries=5,  # Increased retries for large models
                initial_workers=4,  # More workers for faster download
                force_download=force,
                progress_callback=progress_callback,
            )

            if "components" in model_info:
                self._download_components(model_info["components"], model_path, force, progress_callback)

            self._notify(progress_callback, "🔍 Verifying download integrity...")
            from ..utils.download_utils import check_download_integrity
            if not check_download_integrity(str(model_path), model_info["repo_id"]):
                raise Exception("Download integrity check failed - some files may be missing or corrupted")

            self._register_model(model_name, model_info, model_path)
            logger.info(f"Model {model_name} download completed")
            self._notify(progress_callback, f"✅ {model_name} download completed successfully and verified!")
            return True

        except Exception as e:
            logger.error(f"Model download failed: {e}")
            self._notify(progress_callback, f"❌ Download failed: {str(e)}")

            # Only forced downloads are cleaned up; a failed non-forced download is kept
            # so the next pull can resume it.
            if force:
                self._cleanup_failed_download(model_name, model_path, progress_callback)
            return False

    def remove_model(self, model_name: str) -> bool:
        """Delete an installed model's files and its configuration entry.

        If the model is loaded in memory, it is unloaded first. If the persisted
        current-model setting names this model, that setting is cleared.

        Returns:
            True on success; False if the model is not installed or removal fails.
        """
        if not self.is_model_installed(model_name):
            logger.error(f"Model {model_name} is not installed")
            return False

        try:
            if self.current_model_name == model_name:
                self.unload_model()
            elif settings.current_model == model_name:
                # Selected earlier (e.g. by another process) but not loaded here: don't
                # leave the persisted pointer naming a model that is about to vanish.
                settings.current_model = None
                settings.save_config()

            model_path = Path(settings.models[model_name].path)
            if model_path.exists():
                shutil.rmtree(model_path)

            settings.remove_model(model_name)
            logger.info(f"Model {model_name} has been removed")
            return True

        except Exception as e:
            logger.error(f"Failed to remove model: {e}")
            return False

    def load_model(self, model_name: str) -> bool:
        """Load an installed model into memory via a new InferenceEngine.

        Any different model already loaded is unloaded first. On success, the model
        is also recorded as current in ``settings``.

        Returns:
            True if the model is (now) loaded; False if it is not installed or loading fails.
        """
        if not self.is_model_installed(model_name):
            logger.error(f"Model {model_name} is not installed")
            return False

        if self.current_model_name == model_name:
            logger.info(f"Model {model_name} is already loaded")
            return True

        if self.loaded_model is not None:
            self.unload_model()

        try:
            from ..inference.engine import InferenceEngine

            model_config = settings.models[model_name]
            engine = InferenceEngine()

            if engine.load_model(model_config):
                self.loaded_model = engine
                self.current_model_name = model_name
                settings.set_current_model(model_name)
                logger.info(f"Model {model_name} loaded successfully")
                return True

            logger.error(f"Model {model_name} failed to load")
            return False

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return False

    def unload_model(self):
        """Release the in-memory model (if any) and clear the persisted current model.

        The persisted setting is cleared even when nothing is loaded in this
        process, so a model selected earlier is forgotten as well.
        """
        if self.loaded_model is not None:
            try:
                self.loaded_model.unload()
                logger.info(f"Model {self.current_model_name} unloaded")
            except Exception as e:
                logger.error(f"Failed to unload model: {e}")
            finally:
                self.loaded_model = None
                self.current_model_name = None

        # Also clear the persisted state
        settings.current_model = None
        settings.save_config()

    def get_current_model(self) -> Optional[str]:
        """Return the in-memory model name, falling back to the persisted ``settings.current_model``."""
        if self.current_model_name:
            return self.current_model_name
        return settings.current_model

    def is_model_loaded(self) -> bool:
        """Return True only if a model is actually held in memory by this instance."""
        return self.loaded_model is not None

    def has_current_model(self) -> bool:
        """Return True if a current model is persisted in settings (it may not be loaded in memory)."""
        return settings.current_model is not None

    def is_server_running(self) -> bool:
        """Return True if the API server answers ``GET /api/health`` with HTTP 200 within 2 seconds."""
        try:
            import requests

            host = settings.server.host
            # A wildcard bind address is not a connectable destination on every OS
            # (notably Windows), so probe loopback instead.
            host = {"0.0.0.0": "127.0.0.1", "::": "[::1]"}.get(host, host)
            port = settings.server.port
            response = requests.get(f"http://{host}:{port}/api/health", timeout=2)
            return response.status_code == 200
        except Exception:
            # Any failure (requests missing, refused connection, timeout) means "not running".
            # Unlike a bare except, this lets KeyboardInterrupt/SystemExit propagate.
            return False
|
|
434
|
+
|
|
435
|
+
# Module-level shared ModelManager: every importer of ``model_manager`` gets this same object.
model_manager: ModelManager = ModelManager()
|