ollamadiffuser 1.1.6__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff compares the contents of publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/cli/main.py +366 -28
- ollamadiffuser/core/config/model_registry.py +757 -0
- ollamadiffuser/core/inference/engine.py +334 -4
- ollamadiffuser/core/models/gguf_loader.py +437 -0
- ollamadiffuser/core/models/manager.py +139 -312
- ollamadiffuser/core/models/registry.py +384 -0
- ollamadiffuser/core/utils/download_utils.py +35 -2
- {ollamadiffuser-1.1.6.dist-info → ollamadiffuser-1.2.0.dist-info}/METADATA +89 -10
- {ollamadiffuser-1.1.6.dist-info → ollamadiffuser-1.2.0.dist-info}/RECORD +14 -11
- {ollamadiffuser-1.1.6.dist-info → ollamadiffuser-1.2.0.dist-info}/WHEEL +0 -0
- {ollamadiffuser-1.1.6.dist-info → ollamadiffuser-1.2.0.dist-info}/entry_points.txt +0 -0
- {ollamadiffuser-1.1.6.dist-info → ollamadiffuser-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.1.6.dist-info → ollamadiffuser-1.2.0.dist-info}/top_level.txt +0 -0
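
The headline change in 1.2.0 is structural: `ModelManager` no longer hard-codes its model catalog. Model metadata moves into the new `core/config/model_registry.py` (+757 lines), and GGUF-quantized models are supported through the new `core/models/gguf_loader.py`. Only the `core/models/manager.py` hunks are reproduced below. The registry module itself is not shown in this section; judging purely from the calls the manager makes (`get_model`, `get_model_names`, `get_all_models`), its surface plausibly looks like the following sketch, and everything beyond those three methods is an assumption:

```python
# Hypothetical sketch of the registry surface implied by the manager.py diff.
# Only get_model(), get_model_names() and get_all_models() appear in the
# hunks below; the storage format is assumed.
from typing import Dict, List, Optional


class ModelRegistry:
    """Dynamic model catalog replacing the old hard-coded dict."""

    def __init__(self) -> None:
        self._models: Dict[str, Dict] = {}  # model name -> metadata dict

    def get_all_models(self) -> Dict[str, Dict]:
        return self._models

    def get_model_names(self) -> List[str]:
        return list(self._models.keys())

    def get_model(self, name: str) -> Optional[Dict]:
        return self._models.get(name)


# Module-level singleton, matching `from ..config.model_registry import model_registry`.
model_registry = ModelRegistry()
```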
ollamadiffuser/core/models/manager.py

```diff
@@ -6,262 +6,28 @@ import logging
 import hashlib
 from huggingface_hub import login
 from ..config.settings import settings, ModelConfig
+from ..config.model_registry import model_registry
 from ..utils.download_utils import robust_snapshot_download, robust_file_download
+from .gguf_loader import gguf_loader, GGUF_AVAILABLE
 
 logger = logging.getLogger(__name__)
 
 class ModelManager:
-    """Model manager"""
+    """Model manager with dynamic registry support and GGUF compatibility"""
 
     def __init__(self):
         self.loaded_model: Optional[object] = None
         self.current_model_name: Optional[str] = None
-
-
-
-
-
-
-                "variant": "bf16",
-                "parameters": {
-                    "num_inference_steps": 16,
-                    "guidance_scale": 2.0,
-                    "max_sequence_length": 512
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 12,
-                    "recommended_vram_gb": 16,
-                    "min_ram_gb": 24,
-                    "recommended_ram_gb": 32,
-                    "disk_space_gb": 15,
-                    "supported_devices": ["CUDA", "MPS"],
-                    "performance_notes": "Requires NVIDIA RTX 4070+ or Apple M2 Pro+. Needs HuggingFace token. Use 'lora pull' to add LoRA styles."
-                },
-                "license_info": {
-                    "type": "FLUX.1-dev Non-Commercial License",
-                    "requires_agreement": True,
-                    "commercial_use": False
-                }
-            },
-
-            "flux.1-schnell": {
-                "repo_id": "black-forest-labs/FLUX.1-schnell",
-                "model_type": "flux",
-                "variant": "bf16",
-                "parameters": {
-                    "num_inference_steps": 4,
-                    "guidance_scale": 0.0,
-                    "max_sequence_length": 256
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 12,
-                    "recommended_vram_gb": 16,
-                    "min_ram_gb": 24,
-                    "recommended_ram_gb": 32,
-                    "disk_space_gb": 15,
-                    "supported_devices": ["CUDA", "MPS"],
-                    "performance_notes": "Fast distilled version of FLUX.1-dev. Generates images in ~4 steps. Requires NVIDIA RTX 4070+ or Apple M2 Pro+."
-                },
-                "license_info": {
-                    "type": "Apache 2.0",
-                    "requires_agreement": False,
-                    "commercial_use": True
-                }
-            },
-
-            "stable-diffusion-3.5-medium": {
-                "repo_id": "stabilityai/stable-diffusion-3.5-medium",
-                "model_type": "sd3",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 28,
-                    "guidance_scale": 3.5
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 8,
-                    "recommended_vram_gb": 12,
-                    "min_ram_gb": 16,
-                    "recommended_ram_gb": 32,
-                    "disk_space_gb": 10,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Best on NVIDIA RTX 3080+ or Apple M2 Pro+"
-                }
-            },
-            "stable-diffusion-xl-base": {
-                "repo_id": "stabilityai/stable-diffusion-xl-base-1.0",
-                "model_type": "sdxl",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 6,
-                    "recommended_vram_gb": 10,
-                    "min_ram_gb": 12,
-                    "recommended_ram_gb": 24,
-                    "disk_space_gb": 7,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Good on NVIDIA RTX 3070+ or Apple M1 Pro+"
-                }
-            },
-            "stable-diffusion-1.5": {
-                "repo_id": "runwayml/stable-diffusion-v1-5",
-                "model_type": "sd15",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 4,
-                    "recommended_vram_gb": 6,
-                    "min_ram_gb": 8,
-                    "recommended_ram_gb": 16,
-                    "disk_space_gb": 5,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Runs well on most modern GPUs, including GTX 1060+"
-                }
-            },
-
-            # ControlNet models for SD 1.5
-            "controlnet-canny-sd15": {
-                "repo_id": "lllyasviel/sd-controlnet-canny",
-                "model_type": "controlnet_sd15",
-                "base_model": "stable-diffusion-1.5",
-                "controlnet_type": "canny",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5,
-                    "controlnet_conditioning_scale": 1.0
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 6,
-                    "recommended_vram_gb": 8,
-                    "min_ram_gb": 12,
-                    "recommended_ram_gb": 20,
-                    "disk_space_gb": 7,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for edge detection."
-                }
-            },
-
-            "controlnet-depth-sd15": {
-                "repo_id": "lllyasviel/sd-controlnet-depth",
-                "model_type": "controlnet_sd15",
-                "base_model": "stable-diffusion-1.5",
-                "controlnet_type": "depth",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5,
-                    "controlnet_conditioning_scale": 1.0
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 6,
-                    "recommended_vram_gb": 8,
-                    "min_ram_gb": 12,
-                    "recommended_ram_gb": 20,
-                    "disk_space_gb": 7,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for depth-based control."
-                }
-            },
-
-            "controlnet-openpose-sd15": {
-                "repo_id": "lllyasviel/sd-controlnet-openpose",
-                "model_type": "controlnet_sd15",
-                "base_model": "stable-diffusion-1.5",
-                "controlnet_type": "openpose",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5,
-                    "controlnet_conditioning_scale": 1.0
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 6,
-                    "recommended_vram_gb": 8,
-                    "min_ram_gb": 12,
-                    "recommended_ram_gb": 20,
-                    "disk_space_gb": 7,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for pose control."
-                }
-            },
-
-            "controlnet-scribble-sd15": {
-                "repo_id": "lllyasviel/sd-controlnet-scribble",
-                "model_type": "controlnet_sd15",
-                "base_model": "stable-diffusion-1.5",
-                "controlnet_type": "scribble",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5,
-                    "controlnet_conditioning_scale": 1.0
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 6,
-                    "recommended_vram_gb": 8,
-                    "min_ram_gb": 12,
-                    "recommended_ram_gb": 20,
-                    "disk_space_gb": 7,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Requires base SD 1.5 model + ControlNet model. Good for sketch-based control."
-                }
-            },
-
-            # ControlNet models for SDXL
-            "controlnet-canny-sdxl": {
-                "repo_id": "diffusers/controlnet-canny-sdxl-1.0",
-                "model_type": "controlnet_sdxl",
-                "base_model": "stable-diffusion-xl-base",
-                "controlnet_type": "canny",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5,
-                    "controlnet_conditioning_scale": 1.0
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 8,
-                    "recommended_vram_gb": 12,
-                    "min_ram_gb": 16,
-                    "recommended_ram_gb": 28,
-                    "disk_space_gb": 10,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Requires base SDXL model + ControlNet model. Good for edge detection with SDXL quality."
-                }
-            },
-
-            "controlnet-depth-sdxl": {
-                "repo_id": "diffusers/controlnet-depth-sdxl-1.0",
-                "model_type": "controlnet_sdxl",
-                "base_model": "stable-diffusion-xl-base",
-                "controlnet_type": "depth",
-                "variant": "fp16",
-                "parameters": {
-                    "num_inference_steps": 50,
-                    "guidance_scale": 7.5,
-                    "controlnet_conditioning_scale": 1.0
-                },
-                "hardware_requirements": {
-                    "min_vram_gb": 8,
-                    "recommended_vram_gb": 12,
-                    "min_ram_gb": 16,
-                    "recommended_ram_gb": 28,
-                    "disk_space_gb": 10,
-                    "supported_devices": ["CUDA", "MPS", "CPU"],
-                    "performance_notes": "Requires base SDXL model + ControlNet model. Good for depth-based control with SDXL quality."
-                }
-            }
-        }
+        self.current_model_type: Optional[str] = None  # Track model type
+
+    @property
+    def model_registry(self):
+        """Get the current model registry (for backward compatibility)"""
+        return model_registry.get_all_models()
 
     def list_available_models(self) -> List[str]:
         """List all available models"""
-        return
+        return model_registry.get_model_names()
 
     def list_installed_models(self) -> List[str]:
         """List installed models"""
```
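
Call sites that indexed the old hard-coded dict keep working: the new read-only `model_registry` property proxies to the module-level registry. A minimal usage sketch (the model names come from the table removed above; their presence in 1.2.0's dynamic registry is assumed):

```python
from ollamadiffuser.core.models.manager import model_manager

# Legacy style: read the catalog as if the manager still owned the dict;
# the property now forwards to model_registry.get_all_models().
catalog = model_manager.model_registry
print("stable-diffusion-1.5" in catalog)

# New style: ask the manager directly (model_registry.get_model_names()).
print(model_manager.list_available_models())
```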
```diff
@@ -271,11 +37,24 @@ class ModelManager:
         """Check if model is installed"""
         return model_name in settings.models
 
+    def is_gguf_model(self, model_name: str) -> bool:
+        """Check if a model is a GGUF model"""
+        if not model_name:
+            return False
+        model_info = model_registry.get_model(model_name)
+        if model_info:
+            return gguf_loader.is_gguf_model(model_name, model_info)
+        return False
+
     def get_model_info(self, model_name: str) -> Optional[Dict]:
         """Get model information"""
-
-
+        info = model_registry.get_model(model_name)
+        if info:
+            # Create a copy to avoid modifying the original
+            info = info.copy()
             info['installed'] = self.is_model_installed(model_name)
+            info['is_gguf'] = self.is_gguf_model(model_name)
+            info['gguf_supported'] = GGUF_AVAILABLE
             if info['installed']:
                 config = settings.models[model_name]
                 info['local_path'] = config.path
```
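
Note the defensive `info.copy()`: `get_model_info` now annotates a copy of the registry entry with `installed`, `is_gguf`, `gguf_supported`, and `local_path` instead of mutating shared registry state. A sketch of how the new flags might be consumed:

```python
info = model_manager.get_model_info("flux.1-schnell")
if info is None:
    print("unknown model")
elif info["is_gguf"] and not info["gguf_supported"]:
    # A GGUF model whose optional dependencies are missing; the install
    # hint mirrors the error message used by load_model() below.
    print("install GGUF support: pip install llama-cpp-python gguf")
else:
    print(info["installed"], info.get("local_path", "not downloaded"))
```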
```diff
@@ -309,13 +88,13 @@ class ModelManager:
                 progress_callback(f"✅ Model {model_name} already installed")
             return True
 
-
+        model_info = model_registry.get_model(model_name)
+        if not model_info:
             logger.error(f"Unknown model: {model_name}")
             if progress_callback:
                 progress_callback(f"❌ Error: Unknown model {model_name}")
             return False
 
-        model_info = self.model_registry[model_name]
         model_path = settings.get_model_path(model_name)
 
         # Show model information before download
@@ -373,17 +152,32 @@ class ModelManager:
             if progress_callback:
                 progress_callback(f"🚀 Starting download of {model_name}")
 
+            # Determine download patterns for GGUF models
+            download_kwargs = {
+                "repo_id": model_info["repo_id"],
+                "local_dir": str(model_path),
+                "cache_dir": str(settings.cache_dir),
+                "max_retries": 5,  # Increased retries for large models
+                "initial_workers": 4,  # More workers for faster download
+                "force_download": force,
+                "progress_callback": progress_callback
+            }
+
+            # Add GGUF-specific file filtering
+            if self.is_gguf_model(model_name):
+                variant = model_info.get("variant", "gguf")
+                patterns = gguf_loader.get_gguf_download_patterns(variant)
+                download_kwargs["allow_patterns"] = patterns["allow_patterns"]
+                download_kwargs["ignore_patterns"] = patterns["ignore_patterns"]
+
+                if progress_callback:
+                    progress_callback(f"🔍 GGUF model detected - downloading only required files for {variant}")
+                    progress_callback(f"📦 Required files: {len(patterns['allow_patterns'])} files")
+                    progress_callback(f"🚫 Ignoring: {len(patterns['ignore_patterns'])} other GGUF variants")
+
             # Download main model using robust downloader with enhanced progress
             from ..utils.download_utils import robust_snapshot_download
-            robust_snapshot_download(
-                repo_id=model_info["repo_id"],
-                local_dir=str(model_path),
-                cache_dir=str(settings.cache_dir),
-                max_retries=5,  # Increased retries for large models
-                initial_workers=4,  # More workers for faster download
-                force_download=force,
-                progress_callback=progress_callback
-            )
+            robust_snapshot_download(**download_kwargs)
 
             # Download components (such as LoRA)
             if "components" in model_info:
```
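
The download call is now assembled as a kwargs dict and expanded with `**download_kwargs`, letting the GGUF branch graft on `allow_patterns`/`ignore_patterns` without duplicating the call. The diff confirms only the return shape of `gguf_loader.get_gguf_download_patterns(variant)`, a dict with those two keys; the sketch below is a plausible implementation, with the quantization names and globs invented for illustration:

```python
# Hypothetical reconstruction: only the "allow_patterns"/"ignore_patterns"
# return shape is confirmed by the hunk above.
GGUF_QUANT_VARIANTS = ["q2_k", "q3_k_s", "q4_0", "q4_k_m", "q5_k_m", "q6_k", "q8_0"]

def get_gguf_download_patterns(variant: str) -> dict:
    others = [v for v in GGUF_QUANT_VARIANTS if v != variant]
    return {
        # Fetch only the requested quantization plus small config files.
        "allow_patterns": [f"*{variant}*.gguf", "*.json"],
        # Skip every other multi-gigabyte quantization variant.
        "ignore_patterns": [f"*{v}*.gguf" for v in others],
    }
```

This shape matches the progress messages above, which report the number of required files and the count of "other GGUF variants" being ignored.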
```diff
@@ -397,38 +191,18 @@ class ModelManager:
                     if progress_callback:
                         progress_callback(f"📦 Downloading component: {comp_name}")
 
-
-
-
-
-
-
-
-
-
-
-                        )
-                    else:
-                        # Download entire repository using robust downloader
-                        robust_snapshot_download(
-                            repo_id=comp_info["repo_id"],
-                            local_dir=str(comp_path),
-                            cache_dir=str(settings.cache_dir),
-                            max_retries=3,
-                            initial_workers=2,  # Use fewer workers for components
-                            force_download=force,
-                            progress_callback=progress_callback
-                        )
-
-            # Verify download integrity
-            if progress_callback:
-                progress_callback(f"🔍 Verifying download integrity...")
-
-            from ..utils.download_utils import check_download_integrity
-            if not check_download_integrity(str(model_path), model_info["repo_id"]):
-                raise Exception("Download integrity check failed - some files may be missing or corrupted")
+                    robust_snapshot_download(
+                        repo_id=comp_info["repo_id"],
+                        local_dir=str(comp_path),
+                        cache_dir=str(settings.cache_dir),
+                        allow_patterns=comp_info.get("allow_patterns"),
+                        ignore_patterns=comp_info.get("ignore_patterns", ["*.git*", "README.md", "*.txt"]),
+                        max_retries=3,
+                        initial_workers=2,
+                        progress_callback=progress_callback
+                    )
 
-            #
+            # Create model configuration
             model_config = ModelConfig(
                 name=model_name,
                 path=str(model_path),
```
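
Two things changed for component downloads (e.g. bundled LoRAs): they gain the same `allow_patterns`/`ignore_patterns` plumbing, and the old post-download `check_download_integrity` verification step was dropped from this path. Since the fallback ignore list lives in `comp_info.get("ignore_patterns", [...])`, a registry entry that wants the default should omit the key entirely rather than set it to `None`; `dict.get` only falls back when the key is absent. An invented component entry illustrating the fields this code reads:

```python
# Illustrative only: the field names come from the comp_info.get(...) calls
# above, but the repo id and globs are made up.
components = {
    "detail-lora": {
        "repo_id": "some-org/some-lora-repo",
        "allow_patterns": ["*.safetensors"],
        # "ignore_patterns" omitted -> defaults to ["*.git*", "README.md", "*.txt"]
    }
}
```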
```diff
@@ -438,26 +212,27 @@ class ModelManager:
                 parameters=model_info.get("parameters")
             )
 
+            # Add to settings
             settings.add_model(model_config)
-
+
+            logger.info(f"Model {model_name} downloaded successfully")
             if progress_callback:
-                progress_callback(f"✅ {model_name}
+                progress_callback(f"✅ {model_name} downloaded and configured successfully!")
+
             return True
 
         except Exception as e:
-            logger.error(f"
+            logger.error(f"Download failed: {str(e)}")
             if progress_callback:
                 progress_callback(f"❌ Download failed: {str(e)}")
 
-            # Clean up
-            if
+            # Clean up partial download
+            if model_path.exists():
                 try:
                     shutil.rmtree(model_path)
-                    logger.info(f"Cleaned up failed download directory: {model_path}")
-                    if progress_callback:
-                        progress_callback(f"🧹 Cleaned up incomplete download")
                 except Exception as cleanup_error:
-                    logger.warning(f"Failed to clean up
+                    logger.warning(f"Failed to clean up partial download: {cleanup_error}")
+
             return False
 
     def remove_model(self, model_name: str) -> bool:
@@ -488,7 +263,7 @@ class ModelManager:
             return False
 
     def load_model(self, model_name: str) -> bool:
-        """Load model into memory"""
+        """Load model into memory (supports both regular and GGUF models)"""
         if not self.is_model_installed(model_name):
             logger.error(f"Model {model_name} is not installed")
             return False
@@ -503,36 +278,72 @@ class ModelManager:
         self.unload_model()
 
         try:
-            from ..inference.engine import InferenceEngine
-
             model_config = settings.models[model_name]
-            engine = InferenceEngine()
 
-            if
-
-
-
-
-
+            # Check if this is a GGUF model
+            if self.is_gguf_model(model_name):
+                if not GGUF_AVAILABLE:
+                    logger.error("GGUF support not available. Install with: pip install llama-cpp-python gguf")
+                    return False
+
+                # Load GGUF model
+                model_config_dict = {
+                    'name': model_name,
+                    'path': model_config.path,
+                    'variant': model_config.variant,
+                    'model_type': model_config.model_type,
+                    'parameters': model_config.parameters
+                }
+
+                if gguf_loader.load_model(model_config_dict):
+                    self.loaded_model = gguf_loader
+                    self.current_model_name = model_name
+                    self.current_model_type = 'gguf'
+                    settings.set_current_model(model_name)
+                    logger.info(f"GGUF model {model_name} loaded successfully")
+                    return True
+                else:
+                    logger.error(f"GGUF model {model_name} failed to load")
+                    return False
             else:
-
-
+                # Load regular diffusion model
+                from ..inference.engine import InferenceEngine
+
+                engine = InferenceEngine()
+
+                if engine.load_model(model_config):
+                    self.loaded_model = engine
+                    self.current_model_name = model_name
+                    self.current_model_type = 'diffusion'
+                    settings.set_current_model(model_name)
+                    logger.info(f"Model {model_name} loaded successfully")
+                    return True
+                else:
+                    logger.error(f"Model {model_name} failed to load")
+                    return False
 
         except Exception as e:
             logger.error(f"Failed to load model: {e}")
             return False
 
     def unload_model(self):
-        """Unload current model"""
+        """Unload current model (supports both regular and GGUF models)"""
         if self.loaded_model is not None:
             try:
-                self.
-
+                if self.current_model_type == 'gguf':
+                    # Unload GGUF model
+                    gguf_loader.unload_model()
+                    logger.info(f"GGUF model {self.current_model_name} unloaded")
+                else:
+                    # Unload regular model
+                    self.loaded_model.unload()
+                    logger.info(f"Model {self.current_model_name} unloaded")
             except Exception as e:
                 logger.error(f"Failed to unload model: {e}")
             finally:
                 self.loaded_model = None
                 self.current_model_name = None
+                self.current_model_type = None
 
         # Also clear the persisted state
         settings.current_model = None
```
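
`load_model` now branches on `is_gguf_model` and records the route in `current_model_type`; `unload_model` reads that marker to pick the matching teardown (`gguf_loader.unload_model()` for GGUF, `self.loaded_model.unload()` otherwise). An end-to-end sketch; the GGUF model name is hypothetical, and GGUF loading requires the optional `llama-cpp-python` and `gguf` packages, per the error message above:

```python
from ollamadiffuser.core.models.manager import model_manager

# Regular diffusers-backed model.
if model_manager.load_model("stable-diffusion-1.5"):
    print(model_manager.current_model_type)  # 'diffusion'
    model_manager.unload_model()             # calls loaded_model.unload()

# GGUF-quantized model (hypothetical name).
if model_manager.load_model("flux.1-dev-q4_k_m-gguf"):
    print(model_manager.current_model_type)  # 'gguf'
    model_manager.unload_model()             # calls gguf_loader.unload_model()
```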
```diff
@@ -565,6 +376,22 @@ class ModelManager:
             return response.status_code == 200
         except:
             return False
+
+    def get_current_model_info(self) -> Optional[Dict]:
+        """Get information about the currently loaded model"""
+        if not self.loaded_model or not self.current_model_name:
+            return None
+
+        model_info = self.get_model_info(self.current_model_name)
+        if model_info:
+            model_info['loaded'] = True
+            model_info['type'] = self.current_model_type
+
+            # Add GGUF-specific info if applicable
+            if self.current_model_type == 'gguf':
+                model_info.update(gguf_loader.get_model_info())
+
+        return model_info
 
 # Global model manager instance
 model_manager = ModelManager()
```
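
The new `get_current_model_info` merges registry metadata with live loader state, and for GGUF models folds in `gguf_loader.get_model_info()`. Expected usage, based on the fields set above:

```python
info = model_manager.get_current_model_info()
if info is None:
    print("no model loaded")
else:
    # 'loaded' and 'type' are set by this method; 'installed', 'is_gguf'
    # and 'gguf_supported' come from get_model_info().
    print(info["loaded"], info["type"])  # e.g. True 'diffusion'
```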