ollamadiffuser 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,390 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ LoRA (Low-Rank Adaptation) manager for downloading and managing LoRA weights
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import shutil
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Callable
11
+ import logging
12
+ from datetime import datetime
13
+ from huggingface_hub import hf_hub_download, login
14
+
15
+ from ..config.settings import settings
16
+ from .download_utils import robust_file_download
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
class LoRAManager:
    """Manager for LoRA weights.

    Keeps a JSON registry (``loras.json``) of downloaded LoRAs inside
    ``settings.config_dir / "loras"`` and tracks which LoRA, if any, was last
    loaded through this manager (``current_lora``).  Loading and unloading is
    delegated to the locally loaded inference engine when one exists, otherwise
    to the HTTP API server configured in ``settings.server``.
    """

    def __init__(self):
        """Create the LoRA directory if needed and read the registry."""
        self.lora_dir = settings.config_dir / "loras"
        # parents=True: the config directory itself may not exist yet.
        self.lora_dir.mkdir(parents=True, exist_ok=True)
        self.config_file = self.lora_dir / "loras.json"
        self.current_lora = None
        self._load_config()

    def _load_config(self):
        """Load the LoRA registry into ``self.config``.

        ``self.config`` is always left as a dict: a missing, unreadable or
        malformed registry file results in an empty registry.
        """
        self.config = {}
        if not self.config_file.exists():
            return

        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load LoRA config: {e}")
            return

        # Every other method indexes the registry by LoRA name, so anything
        # other than a JSON object would break them later on.
        if isinstance(loaded, dict):
            self.config = loaded
        else:
            logger.warning("Failed to load LoRA config: top-level JSON value is not an object")

    def _save_config(self):
        """Persist ``self.config`` to the registry file.

        The data is written to a temporary sibling file and then renamed over
        the real file, so a failed dump never truncates the existing registry.
        Errors are logged, not raised.
        """
        tmp_file = self.config_file.with_name(self.config_file.name + ".tmp")
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, indent=2)
            tmp_file.replace(self.config_file)
        except Exception as e:
            logger.error(f"Failed to save LoRA config: {e}")
            # Best-effort removal of the partial temporary file.
            try:
                tmp_file.unlink()
            except OSError:
                pass

    @staticmethod
    def _is_valid_lora_name(lora_name: str) -> bool:
        """Return True when ``lora_name`` is safe to use as a directory name.

        Rejects empty names, ``.``/``..`` and anything containing a path
        separator, all of which would resolve outside (or onto) the LoRA
        directory and could make the failed-download cleanup delete it.
        """
        if not isinstance(lora_name, str) or not lora_name:
            return False
        if lora_name in (".", ".."):
            return False
        return "/" not in lora_name and "\\" not in lora_name

    def _get_lora_path(self, lora_name: str) -> Path:
        """Get path for LoRA storage"""
        return self.lora_dir / lora_name

    def _format_size(self, size_bytes: int) -> str:
        """Format a byte count with one decimal and a B/KB/MB/GB/TB unit."""
        size = float(size_bytes)
        for unit in ('B', 'KB', 'MB', 'GB'):
            if size < 1024.0:
                return f"{size:.1f} {unit}"
            size /= 1024.0
        return f"{size:.1f} TB"

    def _get_directory_size(self, path: Path) -> int:
        """Return the summed size in bytes of all files below ``path``.

        On error a warning is logged and the size accumulated so far is
        returned.
        """
        total_size = 0
        try:
            for file_path in path.rglob('*'):
                if file_path.is_file():
                    total_size += file_path.stat().st_size
        except Exception as e:
            logger.warning(f"Failed to calculate directory size: {e}")
        return total_size

    def _api_url(self, endpoint: str) -> str:
        """Build the API server URL for ``endpoint`` (e.g. ``/api/health``)."""
        return f"http://{settings.server.host}:{settings.server.port}{endpoint}"

    def _is_server_running(self) -> bool:
        """Check if the API server answers its health endpoint with HTTP 200."""
        try:
            import requests
            response = requests.get(self._api_url("/api/health"), timeout=2)
            return response.status_code == 200
        except Exception:
            # Any failure (requests missing, connection refused, timeout)
            # means "not reachable".  KeyboardInterrupt/SystemExit propagate.
            return False

    def _try_load_lora_via_api(self, lora_name: str, scale: float = 1.0) -> bool:
        """Try to load LoRA via API server.

        Returns True only when the server is reachable, the LoRA is in the
        registry and the server answers the load request with HTTP 200.
        """
        try:
            if not self._is_server_running():
                return False

            # Check if LoRA exists
            if lora_name not in self.config:
                logger.error(f"LoRA {lora_name} not found")
                return False

            lora_info = self.config[lora_name]

            import requests

            # Prepare API request
            api_data = {
                "lora_name": lora_name,
                "repo_id": lora_info["repo_id"],
                "scale": scale,
            }
            if "weight_name" in lora_info:
                api_data["weight_name"] = lora_info["weight_name"]

            # Make API request to load LoRA
            response = requests.post(
                self._api_url("/api/lora/load"),
                json=api_data,
                timeout=30,
            )

            if response.status_code != 200:
                logger.error(f"API request failed: {response.status_code} - {response.text}")
                return False

            self.current_lora = lora_name
            logger.info(f"LoRA {lora_name} loaded successfully via API with scale {scale}")
            return True

        except Exception as e:
            logger.error(f"Failed to load LoRA via API: {e}")
            return False

    def _try_unload_lora_via_api(self) -> bool:
        """Try to unload LoRA via API server.

        Returns True only when the server is reachable and answers the unload
        request with HTTP 200.
        """
        try:
            if not self._is_server_running():
                return False

            import requests

            # Make API request to unload LoRA
            response = requests.post(self._api_url("/api/lora/unload"), timeout=30)

            if response.status_code != 200:
                logger.error(f"API request failed: {response.status_code} - {response.text}")
                return False

            self.current_lora = None
            logger.info("LoRA unloaded successfully via API")
            return True

        except Exception as e:
            logger.error(f"Failed to unload LoRA via API: {e}")
            return False

    def pull_lora(self, repo_id: str, weight_name: Optional[str] = None,
                  alias: Optional[str] = None, progress_callback: Optional[Callable] = None) -> bool:
        """Download LoRA weights from Hugging Face Hub.

        Args:
            repo_id: Hub repository to download from.
            weight_name: Single file to fetch; when omitted the whole
                repository snapshot is downloaded.
            alias: Local name for the LoRA; defaults to ``repo_id`` with ``/``
                replaced by ``_``.
            progress_callback: Optional callable receiving status strings.

        Returns:
            True if the LoRA is available locally afterwards (including the
            case where it was already registered), False on any failure.
        """
        # Bound before the try block so the cleanup in the except handler can
        # never hit an unbound local when the failure happens early.
        lora_path: Optional[Path] = None
        try:
            # Determine local name
            lora_name = alias if alias else repo_id.replace('/', '_')
            if not self._is_valid_lora_name(lora_name):
                raise ValueError(f"Invalid LoRA name: {lora_name!r}")
            lora_path = self._get_lora_path(lora_name)

            # Check if already exists
            if lora_name in self.config and lora_path.exists():
                if progress_callback:
                    progress_callback(f"✅ LoRA {lora_name} already exists")
                logger.info(f"LoRA {lora_name} already exists")
                return True

            # Create directory
            lora_path.mkdir(parents=True, exist_ok=True)

            # Ensure HuggingFace token is set
            if settings.hf_token:
                login(token=settings.hf_token)
                if progress_callback:
                    progress_callback("🔑 Authenticated with HuggingFace")

            if progress_callback:
                progress_callback(f"📥 Downloading LoRA from {repo_id}")

            # Metadata shared by both download modes; "weight_name" is added
            # only for single-file downloads because load_lora() branches on
            # the presence of that key.
            lora_info = {"repo_id": repo_id}

            if weight_name:
                # Download specific file
                robust_file_download(
                    repo_id=repo_id,
                    filename=weight_name,
                    local_dir=str(lora_path),
                    cache_dir=str(settings.cache_dir),
                    progress_callback=progress_callback
                )
                lora_info["weight_name"] = weight_name
            else:
                # Download all files (fallback)
                from .download_utils import robust_snapshot_download
                robust_snapshot_download(
                    repo_id=repo_id,
                    local_dir=str(lora_path),
                    cache_dir=str(settings.cache_dir),
                    progress_callback=progress_callback
                )

            lora_info["path"] = str(lora_path)
            lora_info["downloaded_at"] = datetime.now().isoformat()
            lora_info["size"] = self._format_size(self._get_directory_size(lora_path))

            # Update configuration
            self.config[lora_name] = lora_info
            self._save_config()

            logger.info(f"LoRA {lora_name} downloaded successfully")
            if progress_callback:
                progress_callback(f"✅ LoRA {lora_name} downloaded successfully")

            return True

        except Exception as e:
            logger.error(f"Failed to download LoRA: {e}")
            if progress_callback:
                progress_callback(f"❌ Failed to download LoRA: {e}")

            # Clean up failed download
            if lora_path is not None and lora_path.exists():
                try:
                    shutil.rmtree(lora_path)
                except Exception as cleanup_error:
                    logger.warning(f"Failed to clean up failed download: {cleanup_error}")

            return False

    def load_lora(self, lora_name: str, scale: float = 1.0) -> bool:
        """Load LoRA weights into the current model.

        When no model is loaded in this process the request is forwarded to
        the API server instead.  Returns True on success, False otherwise;
        errors are logged, not raised.
        """
        try:
            from ..models.manager import model_manager

            # Check if model is loaded locally
            if not model_manager.is_model_loaded():
                # Try to load via API if server is running
                if self._try_load_lora_via_api(lora_name, scale):
                    return True
                logger.error("No model is currently loaded")
                return False

            # Check if LoRA exists
            if lora_name not in self.config:
                logger.error(f"LoRA {lora_name} not found")
                return False

            lora_info = self.config[lora_name]
            lora_path = Path(lora_info["path"])

            if not lora_path.exists():
                logger.error(f"LoRA path does not exist: {lora_path}")
                return False

            # Get the inference engine
            engine = model_manager.loaded_model
            if not engine:
                logger.error("No inference engine available")
                return False

            # Load LoRA weights
            if "weight_name" in lora_info:
                # Load specific weight file.
                # NOTE(review): this passes the Hub repo id, not the local
                # download directory — confirm the engine resolves it as
                # intended.
                success = engine.load_lora_runtime(
                    repo_id=lora_info["repo_id"],
                    weight_name=lora_info["weight_name"],
                    scale=scale
                )
            else:
                # Load from local directory; prefer .safetensors over .bin.
                # Sorted so the chosen file does not depend on directory
                # listing order.
                weight_files = sorted(lora_path.glob("*.safetensors"))
                if not weight_files:
                    weight_files = sorted(lora_path.glob("*.bin"))

                if not weight_files:
                    logger.error(f"No weight files found in {lora_path}")
                    return False

                # Use the first weight file found
                success = engine.load_lora_runtime(
                    repo_id=str(lora_path),
                    weight_name=weight_files[0].name,
                    scale=scale
                )

            if not success:
                logger.error(f"Failed to load LoRA {lora_name}")
                return False

            self.current_lora = lora_name
            logger.info(f"LoRA {lora_name} loaded successfully with scale {scale}")
            return True

        except Exception as e:
            logger.error(f"Failed to load LoRA: {e}")
            return False

    def unload_lora(self) -> bool:
        """Unload current LoRA weights.

        Uses the local inference engine when a model is loaded, otherwise the
        API server.  Returns True on success, False otherwise.
        """
        try:
            from ..models.manager import model_manager

            # Check if model is loaded locally
            if not model_manager.is_model_loaded():
                # Try to unload via API if server is running
                if self._try_unload_lora_via_api():
                    return True
                logger.error("No model is currently loaded")
                return False

            # Get the inference engine
            engine = model_manager.loaded_model
            if not engine:
                logger.error("No inference engine available")
                return False

            # Unload LoRA weights
            if not engine.unload_lora():
                logger.error("Failed to unload LoRA weights")
                return False

            self.current_lora = None
            logger.info("LoRA weights unloaded successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to unload LoRA: {e}")
            return False

    def remove_lora(self, lora_name: str) -> bool:
        """Remove a LoRA's files and its registry entry.

        If the LoRA is the one currently loaded, an unload is attempted first;
        removal proceeds regardless of whether that unload succeeds.
        """
        try:
            # Check if LoRA exists
            if lora_name not in self.config:
                logger.error(f"LoRA {lora_name} not found")
                return False

            # Unload if currently loaded
            if self.current_lora == lora_name:
                self.unload_lora()

            # Remove files
            lora_path = Path(self.config[lora_name]["path"])
            if lora_path.exists():
                shutil.rmtree(lora_path)

            # Remove from configuration
            del self.config[lora_name]
            self._save_config()

            logger.info(f"LoRA {lora_name} removed successfully")
            return True

        except Exception as e:
            logger.error(f"Failed to remove LoRA: {e}")
            return False

    def list_installed_loras(self) -> Dict[str, Dict]:
        """List all installed LoRA weights.

        Returns a copy (including each per-LoRA dict) so callers cannot
        mutate the manager's registry by accident.
        """
        return {
            name: dict(info) if isinstance(info, dict) else info
            for name, info in self.config.items()
        }

    def get_current_lora(self) -> Optional[str]:
        """Get currently loaded LoRA name"""
        return self.current_lora

    def get_lora_info(self, lora_name: str) -> Optional[Dict]:
        """Return a copy of the registry entry for ``lora_name``, or None."""
        info = self.config.get(lora_name)
        return dict(info) if isinstance(info, dict) else info

    def is_lora_installed(self, lora_name: str) -> bool:
        """Check if LoRA is installed"""
        return lora_name in self.config
388
+
389
+ # Global LoRA manager instance, constructed once when this module is imported (runs LoRAManager.__init__, which creates the LoRA directory and loads loras.json)
390
+ lora_manager = LoRAManager()
File without changes