ollamadiffuser 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +0 -0
- ollamadiffuser/__main__.py +50 -0
- ollamadiffuser/api/__init__.py +0 -0
- ollamadiffuser/api/server.py +297 -0
- ollamadiffuser/cli/__init__.py +0 -0
- ollamadiffuser/cli/main.py +597 -0
- ollamadiffuser/core/__init__.py +0 -0
- ollamadiffuser/core/config/__init__.py +0 -0
- ollamadiffuser/core/config/settings.py +137 -0
- ollamadiffuser/core/inference/__init__.py +0 -0
- ollamadiffuser/core/inference/engine.py +926 -0
- ollamadiffuser/core/models/__init__.py +0 -0
- ollamadiffuser/core/models/manager.py +436 -0
- ollamadiffuser/core/utils/__init__.py +3 -0
- ollamadiffuser/core/utils/download_utils.py +356 -0
- ollamadiffuser/core/utils/lora_manager.py +390 -0
- ollamadiffuser/ui/__init__.py +0 -0
- ollamadiffuser/ui/templates/index.html +496 -0
- ollamadiffuser/ui/web.py +278 -0
- ollamadiffuser/utils/__init__.py +0 -0
- ollamadiffuser-1.0.0.dist-info/METADATA +493 -0
- ollamadiffuser-1.0.0.dist-info/RECORD +26 -0
- ollamadiffuser-1.0.0.dist-info/WHEEL +5 -0
- ollamadiffuser-1.0.0.dist-info/entry_points.txt +2 -0
- ollamadiffuser-1.0.0.dist-info/licenses/LICENSE +21 -0
- ollamadiffuser-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
LoRA (Low-Rank Adaptation) manager for downloading and managing LoRA weights
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import json
|
|
8
|
+
import shutil
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Dict, List, Optional, Callable
|
|
11
|
+
import logging
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from huggingface_hub import hf_hub_download, login
|
|
14
|
+
|
|
15
|
+
from ..config.settings import settings
|
|
16
|
+
from .download_utils import robust_file_download
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
class LoRAManager:
|
|
21
|
+
"""Manager for LoRA weights"""
|
|
22
|
+
|
|
23
|
+
def __init__(self):
    """Initialize LoRA storage paths and load the persisted registry.

    Sets up:
      - self.lora_dir: directory under the app config dir holding LoRA weights
      - self.config_file: JSON registry (loras.json) describing installed LoRAs
      - self.current_lora: name of the LoRA currently applied (None at start)
      - self.config: in-memory registry, populated by _load_config()
    """
    self.lora_dir = settings.config_dir / "loras"
    # parents=True: don't assume the enclosing config directory already
    # exists on a fresh install — plain exist_ok=True would raise then.
    self.lora_dir.mkdir(parents=True, exist_ok=True)
    self.config_file = self.lora_dir / "loras.json"
    # No LoRA is applied until load_lora()/_try_load_lora_via_api() succeeds.
    self.current_lora = None
    self._load_config()
|
|
29
|
+
|
|
30
|
+
def _load_config(self):
|
|
31
|
+
"""Load LoRA configuration"""
|
|
32
|
+
if self.config_file.exists():
|
|
33
|
+
try:
|
|
34
|
+
with open(self.config_file, 'r') as f:
|
|
35
|
+
self.config = json.load(f)
|
|
36
|
+
except Exception as e:
|
|
37
|
+
logger.warning(f"Failed to load LoRA config: {e}")
|
|
38
|
+
self.config = {}
|
|
39
|
+
else:
|
|
40
|
+
self.config = {}
|
|
41
|
+
|
|
42
|
+
def _save_config(self):
|
|
43
|
+
"""Save LoRA configuration"""
|
|
44
|
+
try:
|
|
45
|
+
with open(self.config_file, 'w') as f:
|
|
46
|
+
json.dump(self.config, f, indent=2)
|
|
47
|
+
except Exception as e:
|
|
48
|
+
logger.error(f"Failed to save LoRA config: {e}")
|
|
49
|
+
|
|
50
|
+
def _get_lora_path(self, lora_name: str) -> Path:
|
|
51
|
+
"""Get path for LoRA storage"""
|
|
52
|
+
return self.lora_dir / lora_name
|
|
53
|
+
|
|
54
|
+
def _format_size(self, size_bytes: int) -> str:
|
|
55
|
+
"""Format size in human readable format"""
|
|
56
|
+
for unit in ['B', 'KB', 'MB', 'GB']:
|
|
57
|
+
if size_bytes < 1024.0:
|
|
58
|
+
return f"{size_bytes:.1f} {unit}"
|
|
59
|
+
size_bytes /= 1024.0
|
|
60
|
+
return f"{size_bytes:.1f} TB"
|
|
61
|
+
|
|
62
|
+
def _get_directory_size(self, path: Path) -> int:
|
|
63
|
+
"""Get total size of directory"""
|
|
64
|
+
total_size = 0
|
|
65
|
+
try:
|
|
66
|
+
for file_path in path.rglob('*'):
|
|
67
|
+
if file_path.is_file():
|
|
68
|
+
total_size += file_path.stat().st_size
|
|
69
|
+
except Exception as e:
|
|
70
|
+
logger.warning(f"Failed to calculate directory size: {e}")
|
|
71
|
+
return total_size
|
|
72
|
+
|
|
73
|
+
def _is_server_running(self) -> bool:
|
|
74
|
+
"""Check if the API server is running"""
|
|
75
|
+
try:
|
|
76
|
+
import requests
|
|
77
|
+
response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/health", timeout=2)
|
|
78
|
+
return response.status_code == 200
|
|
79
|
+
except:
|
|
80
|
+
return False
|
|
81
|
+
|
|
82
|
+
def _try_load_lora_via_api(self, lora_name: str, scale: float = 1.0) -> bool:
    """Try to load a registered LoRA through the running API server.

    Used as a fallback by load_lora() when no model is loaded in this
    process. Returns False (rather than raising) when the server is down,
    the LoRA is unknown, or the request fails, so the caller can report
    a single consolidated error.

    Args:
        lora_name: Key into self.config (the local registry).
        scale: LoRA strength forwarded to the server.

    Returns:
        True when the server acknowledged the load; False otherwise.
    """
    try:
        # Cheap health probe first — avoids a long POST timeout when
        # there is no server at all.
        if not self._is_server_running():
            return False

        # Check if LoRA exists
        if lora_name not in self.config:
            logger.error(f"LoRA {lora_name} not found")
            return False

        lora_info = self.config[lora_name]

        # Imported lazily so the module works without `requests` when
        # the API path is never taken.
        import requests

        # Prepare API request
        api_data = {
            "lora_name": lora_name,
            "repo_id": lora_info["repo_id"],
            "scale": scale
        }

        # weight_name is only present for LoRAs pulled as a single file
        # (see pull_lora); the server falls back to its own discovery
        # otherwise — presumably, confirm against the /api/lora/load handler.
        if "weight_name" in lora_info:
            api_data["weight_name"] = lora_info["weight_name"]

        # Make API request to load LoRA
        response = requests.post(
            f"http://{settings.server.host}:{settings.server.port}/api/lora/load",
            json=api_data,
            timeout=30
        )

        if response.status_code == 200:
            # Mirror the server-side state locally so get_current_lora()
            # stays accurate in this process too.
            self.current_lora = lora_name
            logger.info(f"LoRA {lora_name} loaded successfully via API with scale {scale}")
            return True
        else:
            logger.error(f"API request failed: {response.status_code} - {response.text}")
            return False

    except Exception as e:
        logger.error(f"Failed to load LoRA via API: {e}")
        return False
|
|
125
|
+
|
|
126
|
+
def _try_unload_lora_via_api(self) -> bool:
|
|
127
|
+
"""Try to unload LoRA via API server"""
|
|
128
|
+
try:
|
|
129
|
+
if not self._is_server_running():
|
|
130
|
+
return False
|
|
131
|
+
|
|
132
|
+
import requests
|
|
133
|
+
|
|
134
|
+
# Make API request to unload LoRA
|
|
135
|
+
response = requests.post(
|
|
136
|
+
f"http://{settings.server.host}:{settings.server.port}/api/lora/unload",
|
|
137
|
+
timeout=30
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
if response.status_code == 200:
|
|
141
|
+
self.current_lora = None
|
|
142
|
+
logger.info("LoRA unloaded successfully via API")
|
|
143
|
+
return True
|
|
144
|
+
else:
|
|
145
|
+
logger.error(f"API request failed: {response.status_code} - {response.text}")
|
|
146
|
+
return False
|
|
147
|
+
|
|
148
|
+
except Exception as e:
|
|
149
|
+
logger.error(f"Failed to unload LoRA via API: {e}")
|
|
150
|
+
return False
|
|
151
|
+
|
|
152
|
+
def pull_lora(self, repo_id: str, weight_name: Optional[str] = None,
              alias: Optional[str] = None, progress_callback: Optional[Callable] = None) -> bool:
    """Download LoRA weights from Hugging Face Hub and register them locally.

    Args:
        repo_id: Hugging Face repository id, e.g. "user/my-lora".
        weight_name: Specific weight file to download; when None the whole
            repository snapshot is fetched.
        alias: Local name for the LoRA; defaults to repo_id with '/' -> '_'.
        progress_callback: Optional callable receiving human-readable
            status strings.

    Returns:
        True on success (including when already installed), False on failure.
        A failed download's partial directory is cleaned up.
    """
    # Bind before the try block so the cleanup code in the except handler
    # can never hit an unbound name (previously a NameError was possible
    # if the failure happened before lora_path was assigned).
    lora_path: Optional[Path] = None
    try:
        # Determine local name
        lora_name = alias if alias else repo_id.replace('/', '_')
        lora_path = self._get_lora_path(lora_name)

        # Short-circuit when already downloaded and registered.
        if lora_name in self.config and lora_path.exists():
            if progress_callback:
                progress_callback(f"✅ LoRA {lora_name} already exists")
            logger.info(f"LoRA {lora_name} already exists")
            return True

        # parents=True: don't require the parent directory to pre-exist.
        lora_path.mkdir(parents=True, exist_ok=True)

        # Ensure HuggingFace token is set
        if settings.hf_token:
            login(token=settings.hf_token)
            if progress_callback:
                progress_callback(f"🔑 Authenticated with HuggingFace")

        if progress_callback:
            progress_callback(f"📥 Downloading LoRA from {repo_id}")

        if weight_name:
            # Download only the requested weight file.
            robust_file_download(
                repo_id=repo_id,
                filename=weight_name,
                local_dir=str(lora_path),
                cache_dir=str(settings.cache_dir),
                progress_callback=progress_callback
            )
        else:
            # Download all files (fallback)
            from .download_utils import robust_snapshot_download
            robust_snapshot_download(
                repo_id=repo_id,
                local_dir=str(lora_path),
                cache_dir=str(settings.cache_dir),
                progress_callback=progress_callback
            )

        # Record metadata once for both download modes (previously this
        # dict was duplicated in each branch).
        lora_info = {
            "repo_id": repo_id,
            "path": str(lora_path),
            "downloaded_at": datetime.now().isoformat(),
            "size": self._format_size(self._get_directory_size(lora_path))
        }
        if weight_name:
            lora_info["weight_name"] = weight_name

        # Update configuration
        self.config[lora_name] = lora_info
        self._save_config()

        logger.info(f"LoRA {lora_name} downloaded successfully")
        if progress_callback:
            progress_callback(f"✅ LoRA {lora_name} downloaded successfully")

        return True

    except Exception as e:
        logger.error(f"Failed to download LoRA: {e}")
        if progress_callback:
            progress_callback(f"❌ Failed to download LoRA: {e}")

        # Clean up failed download
        if lora_path is not None and lora_path.exists():
            try:
                shutil.rmtree(lora_path)
            except Exception as cleanup_error:
                logger.warning(f"Failed to clean up failed download: {cleanup_error}")

        return False
|
|
239
|
+
|
|
240
|
+
def load_lora(self, lora_name: str, scale: float = 1.0) -> bool:
    """Load a registered LoRA into the currently loaded model.

    Resolution order: if no model is loaded in this process, delegate to
    a running API server (_try_load_lora_via_api); otherwise apply the
    weights directly through the local inference engine.

    Args:
        lora_name: Key into self.config (the local registry).
        scale: LoRA strength passed to the engine.

    Returns:
        True on success, False on any failure (never raises).
    """
    try:
        # Local import avoids a circular dependency at module load time —
        # presumably manager.py imports this module too; confirm.
        from ..models.manager import model_manager

        # Check if model is loaded locally
        if not model_manager.is_model_loaded():
            # Try to load via API if server is running
            if self._try_load_lora_via_api(lora_name, scale):
                return True
            logger.error("No model is currently loaded")
            return False

        # Check if LoRA exists
        if lora_name not in self.config:
            logger.error(f"LoRA {lora_name} not found")
            return False

        lora_info = self.config[lora_name]
        lora_path = Path(lora_info["path"])

        # Registry entries can outlive their files (e.g. manual deletion).
        if not lora_path.exists():
            logger.error(f"LoRA path does not exist: {lora_path}")
            return False

        # Get the inference engine
        engine = model_manager.loaded_model
        if not engine:
            logger.error("No inference engine available")
            return False

        # Load LoRA weights
        if "weight_name" in lora_info:
            # Load specific weight file (recorded by pull_lora when a
            # single file was requested).
            success = engine.load_lora_runtime(
                repo_id=lora_info["repo_id"],
                weight_name=lora_info["weight_name"],
                scale=scale
            )
        else:
            # Load from local directory: prefer .safetensors, fall back
            # to .bin.
            weight_files = list(lora_path.glob("*.safetensors"))
            if not weight_files:
                weight_files = list(lora_path.glob("*.bin"))

            if not weight_files:
                logger.error(f"No weight files found in {lora_path}")
                return False

            # Use the first weight file found — NOTE(review): glob order
            # is filesystem-dependent; multi-file repos may pick an
            # arbitrary file. Confirm intended behavior.
            weight_file = weight_files[0]
            success = engine.load_lora_runtime(
                repo_id=str(lora_path),
                weight_name=weight_file.name,
                scale=scale
            )

        if success:
            self.current_lora = lora_name
            logger.info(f"LoRA {lora_name} loaded successfully with scale {scale}")
            return True
        else:
            logger.error(f"Failed to load LoRA {lora_name}")
            return False

    except Exception as e:
        logger.error(f"Failed to load LoRA: {e}")
        return False
|
|
308
|
+
|
|
309
|
+
def unload_lora(self) -> bool:
    """Detach the currently applied LoRA weights from the loaded model.

    When no model is loaded in this process, the unload is delegated to
    a running API server. Returns True on success, False otherwise;
    never raises.
    """
    try:
        # Local import to avoid a circular dependency at module load time.
        from ..models.manager import model_manager

        if not model_manager.is_model_loaded():
            # No local model: delegate to the API server if one is up.
            if self._try_unload_lora_via_api():
                return True
            logger.error("No model is currently loaded")
            return False

        engine = model_manager.loaded_model
        if not engine:
            logger.error("No inference engine available")
            return False

        if not engine.unload_lora():
            logger.error("Failed to unload LoRA weights")
            return False

        self.current_lora = None
        logger.info("LoRA weights unloaded successfully")
        return True

    except Exception as e:
        logger.error(f"Failed to unload LoRA: {e}")
        return False
|
|
342
|
+
|
|
343
|
+
def remove_lora(self, lora_name: str) -> bool:
    """Delete an installed LoRA's files and drop it from the registry.

    The LoRA is unloaded first when it is the one currently applied.
    Returns True on success, False when the LoRA is unknown or deletion
    failed; never raises.
    """
    try:
        if lora_name not in self.config:
            logger.error(f"LoRA {lora_name} not found")
            return False

        # Make sure the weights are not in use before deleting them.
        if self.current_lora == lora_name:
            self.unload_lora()

        entry = self.config[lora_name]
        weights_dir = Path(entry["path"])
        if weights_dir.exists():
            shutil.rmtree(weights_dir)

        # Drop the registry entry and persist the change.
        self.config.pop(lora_name)
        self._save_config()

        logger.info(f"LoRA {lora_name} removed successfully")
        return True

    except Exception as e:
        logger.error(f"Failed to remove LoRA: {e}")
        return False
|
|
372
|
+
|
|
373
|
+
def list_installed_loras(self) -> Dict[str, Dict]:
|
|
374
|
+
"""List all installed LoRA weights"""
|
|
375
|
+
return self.config.copy()
|
|
376
|
+
|
|
377
|
+
def get_current_lora(self) -> Optional[str]:
    """Return the name of the currently applied LoRA, or None if none.

    Reflects this process's tracking only: set by load_lora /
    _try_load_lora_via_api, cleared by the unload paths.
    """
    return self.current_lora
|
|
380
|
+
|
|
381
|
+
def get_lora_info(self, lora_name: str) -> Optional[Dict]:
|
|
382
|
+
"""Get information about a specific LoRA"""
|
|
383
|
+
return self.config.get(lora_name)
|
|
384
|
+
|
|
385
|
+
def is_lora_installed(self, lora_name: str) -> bool:
    """Return True if *lora_name* is present in the local registry.

    Presence in the registry does not guarantee the weight files still
    exist on disk (load_lora re-checks the stored path before use).
    """
    return lora_name in self.config
|
|
388
|
+
|
|
389
|
+
# Module-level singleton: all importers share one registry view and one
# notion of the currently applied LoRA. Instantiation touches the
# filesystem (creates the loras dir, reads loras.json) at import time.
lora_manager = LoRAManager()
|
|
File without changes
|