ollamadiffuser-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,436 @@
+ import shutil
+ import logging
+ from pathlib import Path
+ from typing import Dict, List, Optional
+
+ from huggingface_hub import login
+
+ from ..config.settings import settings, ModelConfig
+ from ..utils.download_utils import (
+     robust_snapshot_download,
+     robust_file_download,
+     check_download_integrity,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelManager:
+     """Manages diffusion models: download, load, unload, and removal."""
+
+     def __init__(self):
+         self.loaded_model: Optional[object] = None
+         self.current_model_name: Optional[str] = None
+
+         # Predefined model registry
+         self.model_registry = {
+             "flux.1-dev": {
+                 "repo_id": "black-forest-labs/FLUX.1-dev",
+                 "model_type": "flux",
+                 "variant": "bf16",
+                 "parameters": {
+                     "num_inference_steps": 16,
+                     "guidance_scale": 2.0,
+                     "max_sequence_length": 512
+                 },
+                 "hardware_requirements": {
+                     "min_vram_gb": 12,
+                     "recommended_vram_gb": 16,
+                     "min_ram_gb": 24,
+                     "recommended_ram_gb": 32,
+                     "disk_space_gb": 15,
+                     "supported_devices": ["CUDA", "MPS"],
+                     "performance_notes": "Requires NVIDIA RTX 4070+ or Apple M2 Pro+. Needs HuggingFace token. Use 'lora pull' to add LoRA styles."
+                 },
+                 "license_info": {
+                     "type": "FLUX.1-dev Non-Commercial License",
+                     "requires_agreement": True,
+                     "commercial_use": False
+                 }
+             },
+
+             "flux.1-schnell": {
+                 "repo_id": "black-forest-labs/FLUX.1-schnell",
+                 "model_type": "flux",
+                 "variant": "bf16",
+                 "parameters": {
+                     "num_inference_steps": 4,
+                     "guidance_scale": 0.0,
+                     "max_sequence_length": 256
+                 },
+                 "hardware_requirements": {
+                     "min_vram_gb": 12,
+                     "recommended_vram_gb": 16,
+                     "min_ram_gb": 24,
+                     "recommended_ram_gb": 32,
+                     "disk_space_gb": 15,
+                     "supported_devices": ["CUDA", "MPS"],
+                     "performance_notes": "Fast distilled version of FLUX.1-dev. Generates images in ~4 steps. Requires NVIDIA RTX 4070+ or Apple M2 Pro+."
+                 },
+                 "license_info": {
+                     "type": "Apache 2.0",
+                     "requires_agreement": False,
+                     "commercial_use": True
+                 }
+             },
+
+             "stable-diffusion-3.5-medium": {
+                 "repo_id": "stabilityai/stable-diffusion-3.5-medium",
+                 "model_type": "sd3",
+                 "variant": "fp16",
+                 "parameters": {
+                     "num_inference_steps": 28,
+                     "guidance_scale": 3.5
+                 },
+                 "hardware_requirements": {
+                     "min_vram_gb": 8,
+                     "recommended_vram_gb": 12,
+                     "min_ram_gb": 16,
+                     "recommended_ram_gb": 32,
+                     "disk_space_gb": 10,
+                     "supported_devices": ["CUDA", "MPS", "CPU"],
+                     "performance_notes": "Best on NVIDIA RTX 3080+ or Apple M2 Pro+"
+                 }
+             },
+
+             "stable-diffusion-xl-base": {
+                 "repo_id": "stabilityai/stable-diffusion-xl-base-1.0",
+                 "model_type": "sdxl",
+                 "variant": "fp16",
+                 "parameters": {
+                     "num_inference_steps": 50,
+                     "guidance_scale": 7.5
+                 },
+                 "hardware_requirements": {
+                     "min_vram_gb": 6,
+                     "recommended_vram_gb": 10,
+                     "min_ram_gb": 12,
+                     "recommended_ram_gb": 24,
+                     "disk_space_gb": 7,
+                     "supported_devices": ["CUDA", "MPS", "CPU"],
+                     "performance_notes": "Good on NVIDIA RTX 3070+ or Apple M1 Pro+"
+                 }
+             },
+
+             "stable-diffusion-1.5": {
+                 "repo_id": "runwayml/stable-diffusion-v1-5",
+                 "model_type": "sd15",
+                 "variant": "fp16",
+                 "parameters": {
+                     "num_inference_steps": 20,
+                     "guidance_scale": 7.0
+                 },
+                 "hardware_requirements": {
+                     "min_vram_gb": 4,
+                     "recommended_vram_gb": 6,
+                     "min_ram_gb": 8,
+                     "recommended_ram_gb": 16,
+                     "disk_space_gb": 5,
+                     "supported_devices": ["CUDA", "MPS", "CPU"],
+                     "performance_notes": "Runs well on most modern GPUs, including GTX 1060+"
+                 }
+             }
+         }
+
+     def list_available_models(self) -> List[str]:
+         """List all models available in the registry."""
+         return list(self.model_registry.keys())
+
+     def list_installed_models(self) -> List[str]:
+         """List installed models."""
+         return list(settings.models.keys())
+
+     def is_model_installed(self, model_name: str) -> bool:
+         """Check whether a model is installed."""
+         return model_name in settings.models
+
+     def get_model_info(self, model_name: str) -> Optional[Dict]:
+         """Return registry information for a model, including install status."""
+         if model_name in self.model_registry:
+             info = self.model_registry[model_name].copy()
+             info['installed'] = self.is_model_installed(model_name)
+             if info['installed']:
+                 config = settings.models[model_name]
+                 info['local_path'] = config.path
+                 info['size'] = self._get_model_size(config.path)
+             return info
+         return None
+
+     def _get_model_size(self, model_path: str) -> str:
+         """Return the on-disk size of a model as a human-readable string."""
+         try:
+             path = Path(model_path)
+             if path.is_file():
+                 size = path.stat().st_size
+             else:
+                 size = sum(f.stat().st_size for f in path.rglob('*') if f.is_file())
+
+             # Convert to human-readable format
+             for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
+                 if size < 1024.0:
+                     return f"{size:.1f} {unit}"
+                 size /= 1024.0
+             return f"{size:.1f} PB"
+         except Exception:
+             return "Unknown"
+
+     def pull_model(self, model_name: str, force: bool = False, progress_callback=None) -> bool:
+         """Download a model using the robust download utilities, with detailed progress tracking."""
+         if not force and self.is_model_installed(model_name):
+             logger.info(f"Model {model_name} already exists")
+             if progress_callback:
+                 progress_callback(f"✅ Model {model_name} already installed")
+             return True
+
+         if model_name not in self.model_registry:
+             logger.error(f"Unknown model: {model_name}")
+             if progress_callback:
+                 progress_callback(f"❌ Error: Unknown model {model_name}")
+             return False
+
+         model_info = self.model_registry[model_name]
+         model_path = settings.get_model_path(model_name)
+
+         # Show model information before download
+         if progress_callback:
+             license_info = model_info.get("license_info", {})
+             progress_callback(f"📦 Model: {model_name}")
+             progress_callback(f"🔗 Repository: {model_info['repo_id']}")
+             if license_info:
+                 progress_callback(f"📄 License: {license_info.get('type', 'Unknown')}")
+                 if license_info.get('requires_agreement', False):
+                     progress_callback("🔑 HuggingFace token required - ensure HF_TOKEN is set")
+                 else:
+                     progress_callback("✅ No HuggingFace token required")
+
+         # Check whether a partial download exists and is valid
+         if not force and model_path.exists():
+             if progress_callback:
+                 progress_callback("🔍 Checking existing download...")
+
+             if check_download_integrity(str(model_path), model_info["repo_id"]):
+                 if progress_callback:
+                     progress_callback("✅ Found complete download, adding to configuration...")
+
+                 # Add to configuration
+                 model_config = ModelConfig(
+                     name=model_name,
+                     path=str(model_path),
+                     model_type=model_info["model_type"],
+                     variant=model_info.get("variant"),
+                     components=model_info.get("components"),
+                     parameters=model_info.get("parameters")
+                 )
+
+                 settings.add_model(model_config)
+                 logger.info(f"Model {model_name} configuration updated")
+                 if progress_callback:
+                     progress_callback(f"✅ {model_name} ready to use!")
+                 return True
+             else:
+                 if progress_callback:
+                     progress_callback("⚠️ Incomplete download detected, will resume...")
+
+         try:
+             # Authenticate with HuggingFace if a token is configured
+             if settings.hf_token:
+                 login(token=settings.hf_token)
+                 if progress_callback:
+                     progress_callback("🔑 Authenticated with HuggingFace")
+             else:
+                 if progress_callback:
+                     progress_callback("⚠️ No HuggingFace token found - some models may not be accessible")
+
+             logger.info(f"Downloading model: {model_name}")
+             if progress_callback:
+                 progress_callback(f"🚀 Starting download of {model_name}")
+
+             # Download the main model using the robust downloader with enhanced progress
+             robust_snapshot_download(
+                 repo_id=model_info["repo_id"],
+                 local_dir=str(model_path),
+                 cache_dir=str(settings.cache_dir),
+                 max_retries=5,      # increased retries for large models
+                 initial_workers=4,  # more workers for faster download
+                 force_download=force,
+                 progress_callback=progress_callback
+             )
+
+             # Download components (such as LoRA)
+             if "components" in model_info:
+                 components_path = model_path / "components"
+                 components_path.mkdir(exist_ok=True)
+
+                 for comp_name, comp_info in model_info["components"].items():
+                     comp_path = components_path / comp_name
+                     comp_path.mkdir(exist_ok=True)
+
+                     if progress_callback:
+                         progress_callback(f"📦 Downloading component: {comp_name}")
+
+                     if "filename" in comp_info:
+                         # Download a single file
+                         robust_file_download(
+                             repo_id=comp_info["repo_id"],
+                             filename=comp_info["filename"],
+                             local_dir=str(comp_path),
+                             cache_dir=str(settings.cache_dir),
+                             max_retries=3,
+                             progress_callback=progress_callback
+                         )
+                     else:
+                         # Download the entire repository
+                         robust_snapshot_download(
+                             repo_id=comp_info["repo_id"],
+                             local_dir=str(comp_path),
+                             cache_dir=str(settings.cache_dir),
+                             max_retries=3,
+                             initial_workers=2,  # fewer workers for components
+                             force_download=force,
+                             progress_callback=progress_callback
+                         )
+
+             # Verify download integrity
+             if progress_callback:
+                 progress_callback("🔍 Verifying download integrity...")
+
+             if not check_download_integrity(str(model_path), model_info["repo_id"]):
+                 raise RuntimeError("Download integrity check failed - some files may be missing or corrupted")
+
+             # Add to configuration
+             model_config = ModelConfig(
+                 name=model_name,
+                 path=str(model_path),
+                 model_type=model_info["model_type"],
+                 variant=model_info.get("variant"),
+                 components=model_info.get("components"),
+                 parameters=model_info.get("parameters")
+             )
+
+             settings.add_model(model_config)
+             logger.info(f"Model {model_name} download completed")
+             if progress_callback:
+                 progress_callback(f"✅ {model_name} downloaded and verified successfully!")
+             return True
+
+         except Exception as e:
+             logger.error(f"Model download failed: {e}")
+             if progress_callback:
+                 progress_callback(f"❌ Download failed: {e}")
+
+             # Clean up only forced (fresh) downloads; keep partial downloads so they can resume
+             if force and model_path.exists():
+                 try:
+                     shutil.rmtree(model_path)
+                     logger.info(f"Cleaned up failed download directory: {model_path}")
+                     if progress_callback:
+                         progress_callback("🧹 Cleaned up incomplete download")
+                 except Exception as cleanup_error:
+                     logger.warning(f"Failed to clean up directory {model_path}: {cleanup_error}")
+             return False
+
+     def remove_model(self, model_name: str) -> bool:
+         """Remove an installed model and its files."""
+         if not self.is_model_installed(model_name):
+             logger.error(f"Model {model_name} is not installed")
+             return False
+
+         try:
+             # If this model is currently loaded, unload it first
+             if self.current_model_name == model_name:
+                 self.unload_model()
+
+             # Delete model files
+             model_config = settings.models[model_name]
+             model_path = Path(model_config.path)
+             if model_path.exists():
+                 shutil.rmtree(model_path)
+
+             # Remove from configuration
+             settings.remove_model(model_name)
+
+             logger.info(f"Model {model_name} has been removed")
+             return True
+
+         except Exception as e:
+             logger.error(f"Failed to remove model: {e}")
+             return False
+
+     def load_model(self, model_name: str) -> bool:
+         """Load a model into memory."""
+         if not self.is_model_installed(model_name):
+             logger.error(f"Model {model_name} is not installed")
+             return False
+
+         # If the same model is already loaded, return immediately
+         if self.current_model_name == model_name:
+             logger.info(f"Model {model_name} is already loaded")
+             return True
+
+         # Unload the current model first
+         if self.loaded_model is not None:
+             self.unload_model()
+
+         try:
+             from ..inference.engine import InferenceEngine
+
+             model_config = settings.models[model_name]
+             engine = InferenceEngine()
+
+             if engine.load_model(model_config):
+                 self.loaded_model = engine
+                 self.current_model_name = model_name
+                 settings.set_current_model(model_name)
+                 logger.info(f"Model {model_name} loaded successfully")
+                 return True
+             else:
+                 logger.error(f"Model {model_name} failed to load")
+                 return False
+
+         except Exception as e:
+             logger.error(f"Failed to load model: {e}")
+             return False
+
+     def unload_model(self):
+         """Unload the currently loaded model."""
+         if self.loaded_model is not None:
+             try:
+                 self.loaded_model.unload()
+                 logger.info(f"Model {self.current_model_name} unloaded")
+             except Exception as e:
+                 logger.error(f"Failed to unload model: {e}")
+             finally:
+                 self.loaded_model = None
+                 self.current_model_name = None
+
+                 # Also clear the persisted state
+                 settings.current_model = None
+                 settings.save_config()
+
+     def get_current_model(self) -> Optional[str]:
+         """Return the current model name: in-memory state first, then persisted state."""
+         if self.current_model_name:
+             return self.current_model_name
+         return settings.current_model
+
+     def is_model_loaded(self) -> bool:
+         """Check whether a model is loaded in memory (the only state that counts as loaded)."""
+         return self.loaded_model is not None
+
+     def has_current_model(self) -> bool:
+         """Check whether a current model is set (it may not be loaded in memory)."""
+         return settings.current_model is not None
+
+     def is_server_running(self) -> bool:
+         """Check whether the API server is actually reachable."""
+         try:
+             import requests
+             host = settings.server.host
+             port = settings.server.port
+             response = requests.get(f"http://{host}:{port}/api/health", timeout=2)
+             return response.status_code == 200
+         except Exception:
+             return False
+
+
+ # Global model manager instance
+ model_manager = ModelManager()
@@ -0,0 +1,3 @@
+ """
+ Core utilities package
+ """