gen-worker 0.1.4 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
gen_worker/torch_manager/utils/model_downloader.py
@@ -0,0 +1,763 @@
+ import asyncio
+ import hashlib
+ import json
+ import logging
+ import os
+ import re
+ import shutil
+ import time
+ from typing import List, Optional
+ from urllib.parse import parse_qs, unquote, urlparse
+
+ import aiohttp
+ import backoff
+ import torch
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+ from huggingface_hub import HfApi, hf_hub_download, scan_cache_dir, snapshot_download
+ from huggingface_hub.constants import HF_HUB_CACHE
+ from huggingface_hub.file_download import repo_folder_name
+ from tqdm import tqdm
+
+ from .config import get_config
+ from .paths import get_models_dir
+ from .utils import serialize_config
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelSource:
+     """Represents a model source with its type and details"""
+
+     def __init__(self, source_str: str):
+         self.original_string = source_str
+         if source_str.startswith("hf:"):
+             self.type = "huggingface"
+             self.location = source_str[3:]
+         elif "civitai.com" in source_str:
+             self.type = "civitai"
+             self.location = source_str
+         elif source_str.startswith("file:"):
+             self.type = "file"
+             self.location = source_str[5:]
+         elif source_str.startswith(("http://", "https://")):
+             self.type = "direct"
+             self.location = source_str
+         else:
+             raise ValueError(f"Unsupported model source: {source_str}")
+
+
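+ # Illustrative source strings and how ModelSource classifies them (examples
+ # ours, not from the package):
+ #   ModelSource("hf:org/repo")                       -> type "huggingface"
+ #   ModelSource("https://civitai.com/models/4201")   -> type "civitai"
+ #   ModelSource("file:/srv/models/foo.safetensors")  -> type "file"
+ #   ModelSource("https://example.com/a.safetensors") -> type "direct"
+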
+ class ModelManager:
+     def __init__(self, cache_dir: Optional[str] = None):
+         self.hf_api = HfApi()
+         self.cache_dir = cache_dir or HF_HUB_CACHE
+         self.base_cache_dir = cache_dir or os.path.expanduser("~/.cache")
+         self.cozy_cache_dir = get_models_dir()
+         self.session: Optional[aiohttp.ClientSession] = None
+
+         # The Civitai API key is read from the environment rather than from
+         # the serialized config.
+         self.civitai_api_key = os.getenv("CIVITAI_API_KEY")
+
+     async def __aenter__(self):
+         self.session = aiohttp.ClientSession()
+         return self
+
+     async def __aexit__(self, exc_type, exc, tb):
+         if self.session:
+             await self.session.close()
+             self.session = None
+
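+     # Typical usage (illustrative; names are ours): the async context manager
+     # owns the HTTP session used by the download methods.
+     #
+     #   async with ModelManager() as manager:
+     #       source = ModelSource("hf:org/repo")
+     #       await manager.download_model("org/repo", source)
+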
+     def parse_hf_string(
+         self, hf_string: str
+     ) -> tuple[str, Optional[str], Optional[str]]:
+         """
+         Parses a HuggingFace source string into its components.
+         Returns a tuple of (repo_id, subfolder, filename).
+         """
+         # Remove the 'hf:' prefix if present
+         if hf_string.startswith("hf:"):
+             hf_string = hf_string[3:]
+
+         parts = hf_string.split("/")
+         if len(parts) < 2:
+             raise ValueError("Invalid HuggingFace string: repo_id is required")
+
+         repo_id = "/".join(parts[:2])
+         subfolder = None
+         filename = None
+
+         if len(parts) > 2:
+             if parts[-1]:
+                 # A non-empty final segment is treated as a filename.
+                 filename = parts[-1]
+                 subfolder = "/".join(parts[2:-1]) or None
+             else:
+                 # A trailing slash means everything after repo_id is a subfolder.
+                 subfolder = "/".join(parts[2:-1]) or None
+
+         return repo_id, subfolder, filename
+
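+     # Illustrative parses (hypothetical inputs):
+     #   "hf:org/repo"                        -> ("org/repo", None, None)
+     #   "hf:org/repo/unet/"                  -> ("org/repo", "unet", None)
+     #   "hf:org/repo/unet/model.safetensors" -> ("org/repo", "unet", "model.safetensors")
+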
+     async def _get_civitai_filename(self, url: str) -> Optional[str]:
+         """Extract the original filename from a Civitai redirect response"""
+         need_cleanup = False
+         try:
+             headers = {}
+             if self.civitai_api_key:
+                 headers["Authorization"] = f"Bearer {self.civitai_api_key}"
+
+             if not self.session:
+                 self.session = aiohttp.ClientSession()
+                 need_cleanup = True
+
+             try:
+                 async with self.session.get(
+                     url, headers=headers, allow_redirects=False
+                 ) as response:
+                     if response.status in (301, 302, 307):
+                         location = response.headers.get("location")
+                         if location:
+                             # Parse the query parameters from the redirect URL
+                             parsed = urlparse(location)
+                             query_params = parse_qs(parsed.query)
+
+                             # Look for the response-content-disposition parameter
+                             content_disp = query_params.get(
+                                 "response-content-disposition", [None]
+                             )[0]
+                             if content_disp:
+                                 # Extract the filename from the content disposition
+                                 match = re.search(r'filename="([^"]+)"', content_disp)
+                                 if match:
+                                     return unquote(match.group(1))
+
+                             # Fall back to the URL path if there is no
+                             # content disposition
+                             path = parsed.path
+                             if path:
+                                 return os.path.basename(path)
+
+                     return None
+             finally:
+                 # Clean up the session if we created it here
+                 if need_cleanup and self.session:
+                     await self.session.close()
+                     self.session = None
+
+         except Exception as e:
+             logger.error(f"Error getting Civitai filename: {e}")
+             # The inner finally covers normal cleanup; this handles failures
+             # that occur before the request is made.
+             if need_cleanup and self.session:
+                 await self.session.close()
+                 self.session = None
+             return None
+
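+     # Hypothetical example: a redirect Location of
+     #   https://cdn.example.com/x?response-content-disposition=attachment%3B%20filename%3D%22model.safetensors%22
+     # yields "model.safetensors" after parse_qs and unquote.
+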
+     async def is_downloaded(
+         self, model_id: str, model_config: Optional[dict] = None
+     ) -> tuple[bool, Optional[str]]:
+         """Check whether a model is downloaded, handling all source types,
+         including Civitai filename variants"""
+         try:
+             config = serialize_config(get_config())
+             model_info = config["pipeline_defs"].get(model_id)
+             if not model_info:
+                 logger.error(f"Model {model_id} not found in configuration.")
+                 return False, None
+
+             source = ModelSource(model_info["source"])
+
+             # Collect component names from model_config, if provided.
+             # Initialized up front so it is always bound below.
+             component_names = None
+             if model_config:
+                 components = model_config.get("components", [])
+                 if isinstance(components, list):
+                     component_names = list(components)
+                     logger.debug(f"Component names: {component_names}")
+
+             # Handle local files - just check that they exist
+             if source.type == "file":
+                 exists = os.path.exists(source.location)
+                 if not exists:
+                     logger.error(f"Local file not found: {source.location}")
+                 return exists, None
+
+             # Handle HuggingFace models
+             if source.type == "huggingface":
+                 is_downloaded, variant = self._check_repo_downloaded(
+                     source.location, component_names
+                 )
+                 logger.debug(
+                     f"Repo {source.location} is downloaded: {is_downloaded}, "
+                     f"variant: {variant}"
+                 )
+                 return is_downloaded, variant
+
+             # Special handling for Civitai models
+             elif source.type == "civitai":
+                 # First check the default numeric-ID path
+                 default_path = await self._get_cache_path(model_id, source)
+                 if self._check_file_downloaded(default_path):
+                     return True, None
+
+                 # If not found, try the original filename from Civitai
+                 need_cleanup = False
+                 if not self.session:
+                     self.session = aiohttp.ClientSession()
+                     need_cleanup = True
+
+                 try:
+                     original_filename = await self._get_civitai_filename(
+                         source.location
+                     )
+                     if original_filename:
+                         dir_path = os.path.dirname(default_path)
+                         alternate_path = os.path.join(dir_path, original_filename)
+                         if self._check_file_downloaded(alternate_path):
+                             return True, None
+                 finally:
+                     if need_cleanup and self.session:
+                         await self.session.close()
+                         self.session = None
+
+                 return False, None
+
+             # Handle direct downloads
+             else:
+                 cache_path = await self._get_cache_path(model_id, source)
+                 return self._check_file_downloaded(cache_path), None
+
+         except Exception as e:
+             logger.error(f"Error checking download status for {model_id}: {e}")
+             return False, None
+
+     def _get_model_directory(self, model_id: str, url_hash: str) -> str:
+         """Get the directory path for a model"""
+         safe_name = model_id.replace("/", "-")
+         return os.path.join(self.cozy_cache_dir, f"{safe_name}--{url_hash}")
+
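+     # e.g. _get_model_directory("my/model", "1a2b3c4d")
+     #   -> "<models_dir>/my-model--1a2b3c4d"  (hash value illustrative)
+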
+     async def download_model(self, model_id: str, source: ModelSource):
+         """Download a model from any source"""
+         if not self.session:
+             raise RuntimeError("Session not initialized. Use async with context.")
+
+         if source.type == "huggingface":
+             repo_id, subfolder, filename = self.parse_hf_string(source.location)
+             # download() takes (repo_id, file_name, sub_folder); pass the parsed
+             # pieces by keyword so they cannot be swapped.
+             return await self.download(
+                 repo_id, file_name=filename, sub_folder=subfolder
+             )
+         elif source.type == "civitai":
+             return await self._download_civitai(model_id, source.location)
+         else:
+             return await self._download_direct(model_id, source.location)
+
+     async def _download_civitai(self, model_id: str, url: str):
+         """Handle Civitai-specific download logic with proper filename handling"""
+         if not self.session:
+             raise RuntimeError("Session not initialized. Use async with context.")
+
+         # Convert a model-page URL to an API URL if needed
+         if "/api/download/" not in url:
+             model_path = urlparse(url).path
+             model_number = model_path.split("/models/")[1].split("/")[0]
+             api_url = f"https://civitai.com/api/v1/models/{model_number}"
+
+             headers = {}
+             if self.civitai_api_key:
+                 headers["Authorization"] = f"Bearer {self.civitai_api_key}"
+
+             async with self.session.get(api_url, headers=headers) as response:
+                 if response.status != 200:
+                     raise Exception(
+                         f"Failed to get Civitai model info: {response.status}"
+                     )
+                 data = await response.json()
+                 # Extract the download URL from the first version
+                 if "modelVersions" in data and len(data["modelVersions"]) > 0:
+                     download_url = data["modelVersions"][0]["downloadUrl"]
+                 else:
+                     raise Exception("No model versions found in Civitai response")
+         else:
+             download_url = url
+
+         # Hash on download_url so the cache path stays consistent even if the
+         # filename changes.
+         dest_path = await self._get_cache_path(model_id, ModelSource(download_url))
+
+         # download_url is the URL that may redirect to the real file
+         original_filename = await self._get_civitai_filename(download_url)
+         if original_filename:
+             # Keep the hashed directory structure but use the original filename,
+             # so the name in the cache matches what Civitai intends.
+             dir_path = os.path.dirname(dest_path)
+             dest_path = os.path.join(dir_path, original_filename)
+             logger.info(
+                 f"Using original filename from Civitai for destination: {dest_path}"
+             )
+         else:
+             logger.warning(
+                 f"Could not determine original filename from Civitai for "
+                 f"{download_url}. Using default: {dest_path}"
+             )
+
+         # Download with the resolved filename
+         await self._download_direct(model_id, download_url, dest_path)
+
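+     # Illustrative flow: a page URL such as https://civitai.com/models/4201/some-model
+     # is resolved via https://civitai.com/api/v1/models/4201 to the first
+     # version's downloadUrl before the file download starts.
+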
+     @backoff.on_exception(
+         backoff.expo, (aiohttp.ClientError, asyncio.TimeoutError), max_tries=3
+     )
+     async def _download_direct(
+         self, model_id: str, url: str, dest_path: Optional[str] = None
+     ):
+         """Download from a direct URL with a progress bar, retry logic, and
+         resume capability"""
+         if not self.session:
+             raise RuntimeError("Session not initialized. Use async with context.")
+
+         if dest_path is None:
+             dest_path = await self._get_cache_path(model_id, ModelSource(url))
+
+         temp_path = dest_path + ".tmp"
+
+         os.makedirs(os.path.dirname(dest_path), exist_ok=True)
+
+         headers = {}
+         if self.civitai_api_key:
+             headers["Authorization"] = f"Bearer {self.civitai_api_key}"
+
+         # Check whether we have a partial download to resume
+         initial_size = 0
+         if os.path.exists(temp_path):
+             initial_size = os.path.getsize(temp_path)
+             if initial_size > 0:
+                 headers["Range"] = f"bytes={initial_size}-"
+                 logger.info(f"Resuming download from byte {initial_size}")
+
+         timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
+
+         try:
+             async with self.session.get(
+                 url, headers=headers, timeout=timeout
+             ) as response:
+                 # Handle resume responses
+                 if initial_size > 0:
+                     if response.status == 206:  # Partial Content: resume succeeded
+                         total_size = initial_size + int(
+                             response.headers.get("content-length", 0)
+                         )
+                     elif response.status == 200:  # Server doesn't support resume
+                         logger.warning(
+                             "Server doesn't support resume, starting from beginning"
+                         )
+                         total_size = int(response.headers.get("content-length", 0))
+                         initial_size = 0
+                     else:
+                         raise Exception(
+                             f"Resume failed with status {response.status}"
+                         )
+                 else:
+                     if response.status != 200:
+                         raise Exception(
+                             f"Download failed with status {response.status}"
+                         )
+                     total_size = int(response.headers.get("content-length", 0))
+
+                 # Append if resuming, write from scratch otherwise
+                 mode = "ab" if initial_size > 0 else "wb"
+                 downloaded_size = initial_size
+                 last_progress_update = time.time()
+                 stall_timer = 0
+
+                 with tqdm(
+                     total=total_size, initial=initial_size, unit="iB", unit_scale=True
+                 ) as pbar:
+                     try:
+                         with open(temp_path, mode) as f:
+                             async for chunk in response.content.iter_chunked(8192):
+                                 if chunk:  # filter out keep-alive chunks
+                                     f.write(chunk)
+                                     downloaded_size += len(chunk)
+                                     pbar.update(len(chunk))
+
+                                 # Check for download stalls
+                                 current_time = time.time()
+                                 if current_time - last_progress_update > 30:
+                                     # 30 seconds without progress
+                                     stall_timer += current_time - last_progress_update
+                                     if stall_timer > 120:  # 2 minutes total stall
+                                         raise Exception(
+                                             "Download stalled for too long"
+                                         )
+                                 else:
+                                     stall_timer = 0
+                                 last_progress_update = current_time
+
+                         # Verify the downloaded size
+                         if total_size > 0 and downloaded_size != total_size:
+                             raise Exception(
+                                 f"Download incomplete. Expected {total_size} bytes, "
+                                 f"got {downloaded_size} bytes"
+                             )
+
+                         # Verify file integrity
+                         if await self._verify_file(temp_path):
+                             os.rename(temp_path, dest_path)
+                             logger.info(
+                                 f"Downloaded and saved as: {os.path.basename(dest_path)}"
+                             )
+                         else:
+                             raise Exception(
+                                 "File verification failed - will attempt resume on next try"
+                             )
+
+                     except Exception as e:
+                         logger.error(
+                             f"Download error (temporary file kept for resume): {str(e)}"
+                         )
+                         raise
+
+         except Exception as e:
+             logger.error(f"Error downloading {url}: {str(e)}")
+             raise
+
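+     # Resume sketch (illustrative): with a 1_000_000-byte partial .tmp file the
+     # request carries "Range: bytes=1000000-"; a 206 response appends to the
+     # file, while a 200 response means the server ignored the Range header and
+     # the download restarts from byte 0.
+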
+     async def _verify_file(self, path: str) -> bool:
+         """Verify downloaded file integrity with more thorough checks"""
+         try:
+             if not os.path.exists(path):
+                 logger.error(f"File {path} does not exist")
+                 return False
+
+             # Size check
+             file_size = os.path.getsize(path)
+             if file_size < 1024 * 1024:  # 1 MB minimum
+                 logger.error(f"File {path} is too small: {file_size} bytes")
+                 return False
+
+             # Extension check: check the final intended path, not the .tmp path
+             check_path = path[:-4] if path.endswith(".tmp") else path
+             valid_extensions = {".safetensors", ".ckpt", ".pt", ".bin"}
+             if not any(check_path.endswith(ext) for ext in valid_extensions):
+                 logger.error(f"File {check_path} has invalid extension")
+                 return False
+
+             # Try to read the file to ensure it is accessible and not corrupted
+             with open(path, "rb") as f:
+                 # Read the first and last 1 MB to check file accessibility
+                 f.read(1024 * 1024)
+                 f.seek(-1024 * 1024, 2)
+                 f.read(1024 * 1024)
+
+             logger.info(f"File {path} passed all verification checks")
+             return True
+
+         except Exception as e:
+             logger.error(f"File verification failed: {str(e)}")
+             return False
+
+     async def _get_cache_path(self, model_id: str, source: ModelSource) -> str:
+         """Get the cache path for a model"""
+         if source.type == "huggingface":
+             # repo_folder_name takes keyword-only arguments
+             return os.path.join(
+                 HF_HUB_CACHE,
+                 repo_folder_name(repo_id=source.location, repo_type="model"),
+             )
+
+         # For non-HF models
+         safe_name = model_id.replace("/", "-")
+         url_hash = hashlib.sha256(source.location.encode()).hexdigest()[:8]
+
+         # Create the model directory with the hash suffix
+         model_dir = os.path.join(self.cozy_cache_dir, f"{safe_name}--{url_hash}")
+         os.makedirs(model_dir, exist_ok=True)
+
+         if source.type == "civitai":
+             # Try to get the original filename from Civitai
+             logger.debug(f"Getting Civitai filename for {source.location}")
+             original_filename = await self._get_civitai_filename(source.location)
+             if original_filename:
+                 return os.path.join(model_dir, original_filename)
+
+         # Fallback for direct downloads, or if the Civitai filename was unavailable
+         url_path = urlparse(source.location).path
+         filename = (
+             os.path.basename(url_path) if url_path else f"{safe_name}.safetensors"
+         )
+
+         # Suffix the filename with the hash to avoid collisions
+         base, ext = os.path.splitext(filename)
+         final_filename = f"{base}_{url_hash}{ext}"
+
+         return os.path.join(model_dir, final_filename)
+
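+     # e.g. a direct URL https://example.com/weights/foo.safetensors hashed to
+     # "1a2b3c4d" maps to "<models_dir>/<safe_name>--1a2b3c4d/foo_1a2b3c4d.safetensors"
+     # (values illustrative).
+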
+     def _check_repo_downloaded(
+         self, repo_id: str, component_names: Optional[List[str]] = None
+     ) -> tuple[bool, Optional[str]]:
+         storage_folder = os.path.join(
+             self.cache_dir, repo_folder_name(repo_id=repo_id, repo_type="model")
+         )
+
+         if not os.path.exists(storage_folder):
+             return False, None
+
+         # Resolve the latest commit hash
+         refs_path = os.path.join(storage_folder, "refs", "main")
+         if not os.path.exists(refs_path):
+             return False, None
+
+         with open(refs_path, "r") as f:
+             commit_hash = f.read().strip()
+
+         snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
+         if not os.path.exists(snapshot_folder):
+             return False, None
+
+         # Check model_index.json for required component folders
+         model_index_path = os.path.join(snapshot_folder, "model_index.json")
+
+         if os.path.exists(model_index_path):
+             with open(model_index_path, "r") as f:
+                 model_index = json.load(f)
+             required_folders = {
+                 k
+                 for k, v in model_index.items()
+                 if isinstance(v, list)
+                 and len(v) == 2
+                 and v[0] is not None
+                 and v[1] is not None
+             }
+
+             # Remove known non-folder keys and ignored folders
+             ignored_folders = {
+                 "_class_name",
+                 "_diffusers_version",
+                 "scheduler",
+                 "feature_extractor",
+                 "tokenizer",
+                 "tokenizer_2",
+                 "tokenizer_3",
+                 "safety_checker",
+             }
+
+             required_folders -= ignored_folders
+             if component_names:
+                 required_folders -= set(component_names)
+
+             logger.debug(f"Required folders: {required_folders}")
+
+             # Variant hierarchy; the empty string is the default (no variant)
+             variants = ["bf16", "fp8", "fp16", ""]
+
+             def check_folder_completeness(folder_path: str, variant: str) -> bool:
+                 if not os.path.exists(folder_path):
+                     return False
+
+                 for _, _, files in os.walk(folder_path):
+                     for file in files:
+                         if file.endswith(".incomplete"):
+                             logger.debug(f"Incomplete file: {file}")
+                             return False
+
+                         if (
+                             file.endswith(f"{variant}.safetensors")
+                             or file.endswith(f"{variant}.bin")
+                             or (
+                                 variant == ""
+                                 and (
+                                     file.endswith(".safetensors")
+                                     or file.endswith(".bin")
+                                     or file.endswith(".ckpt")
+                                 )
+                             )
+                         ):
+                             return True
+
+                 return False
+
+             def check_variant_completeness(variant: str) -> bool:
+                 for folder in required_folders:
+                     folder_path = os.path.join(snapshot_folder, folder)
+
+                     if not check_folder_completeness(folder_path, variant):
+                         return False
+
+                 return True
+
+             # Walk the variant hierarchy
+             for variant in variants:
+                 logger.debug(f"Checking variant: {variant}")
+                 if check_variant_completeness(variant):
+                     logger.debug(f"Variant {variant} is complete")
+                     return True, variant
+
+         else:
+             # For repos without model_index.json, check the blobs folder
+             blob_folder = os.path.join(storage_folder, "blobs")
+             if os.path.exists(blob_folder):
+                 for _root, _, files in os.walk(blob_folder):
+                     if any(file.endswith(".incomplete") for file in files):
+                         return False, None
+
+                 return True, None
+
+         return False, None
+
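+     # Variant resolution example (illustrative): if every required folder holds
+     # a file like "diffusion_pytorch_model.fp16.safetensors", the repo is
+     # reported as downloaded with variant "fp16"; plain "*.safetensors" files
+     # satisfy the default "" variant.
+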
+     def _check_component_downloaded(self, repo_id: str, component_name: str) -> bool:
+         storage_folder = os.path.join(
+             self.cache_dir, repo_folder_name(repo_id=repo_id, repo_type="model")
+         )
+
+         if not os.path.exists(storage_folder):
+             return False
+
+         refs_path = os.path.join(storage_folder, "refs", "main")
+         if not os.path.exists(refs_path):
+             return False
+
+         with open(refs_path, "r") as f:
+             commit_hash = f.read().strip()
+
+         component_folder = os.path.join(
+             storage_folder, "snapshots", commit_hash, component_name
+         )
+
+         if not os.path.exists(component_folder):
+             return False
+
+         # Look for any complete .bin, .safetensors, or .ckpt file in the folder
+         for _, _, files in os.walk(component_folder):
+             for file in files:
+                 if file.endswith(
+                     (".bin", ".safetensors", ".ckpt")
+                 ) and not file.endswith(".incomplete"):
+                     return True
+
+         return False
+
+     def _check_file_downloaded(self, path: str) -> bool:
+         """Check whether a file exists and is complete in the cache"""
+         # First check whether the exact path exists
+         if os.path.exists(path):
+             # Check for temporary or incomplete markers
+             if os.path.exists(f"{path}.tmp") or os.path.exists(f"{path}.incomplete"):
+                 logger.debug(f"Found incomplete markers for {path}")
+                 return False
+             logger.debug(f"Found complete file at {path}")
+             return True
+
+         # Otherwise, check the directory for any valid model file
+         dir_path = os.path.dirname(path)
+         if os.path.exists(dir_path):
+             for file in os.listdir(dir_path):
+                 file_path = os.path.join(dir_path, file)
+                 if file.endswith((".safetensors", ".ckpt", ".pt", ".bin")):
+                     if not os.path.exists(f"{file_path}.tmp") and not os.path.exists(
+                         f"{file_path}.incomplete"
+                     ):
+                         logger.debug(f"Found alternative model file at {file_path}")
+                         return True
+
+         logger.debug(f"No valid model files found in {dir_path}")
+         return False
+
+     async def list(self) -> List[str]:
+         """List cached repos that are fully downloaded. Declared async because
+         is_downloaded() is a coroutine and must be awaited."""
+         cache_info = scan_cache_dir()
+         repo_ids = []
+         for repo in cache_info.repos:
+             downloaded, _ = await self.is_downloaded(repo.repo_id)
+             if downloaded:
+                 repo_ids.append(repo.repo_id)
+         return repo_ids
+
+     async def download(
+         self,
+         repo_id: str,
+         file_name: Optional[str] = None,
+         sub_folder: Optional[str] = None,
+     ) -> bool:
+         if file_name or sub_folder:
+             try:
+                 if sub_folder and not file_name:
+                     await asyncio.to_thread(
+                         snapshot_download,
+                         repo_id,
+                         allow_patterns=f"{sub_folder}/*",
+                         cache_dir=self.cache_dir,  # keep snapshots in the manager's cache
+                     )
+                     logger.info(
+                         f"{sub_folder} subfolder from {repo_id} downloaded successfully."
+                     )
+                 else:
+                     await asyncio.to_thread(
+                         hf_hub_download,
+                         repo_id,
+                         file_name,
+                         cache_dir=self.cache_dir,
+                         subfolder=sub_folder,
+                     )
+                     logger.info(
+                         f"File {file_name} from {repo_id} downloaded successfully."
+                     )
+                 return True
+             except Exception as e:
+                 logger.error(f"Failed to download file {file_name} from {repo_id}: {e}")
+                 return False
+
+         variants = ["bf16", "fp8", "fp16", None]  # None represents no variant
+         for var in variants:
+             try:
+                 if var:
+                     logger.info(
+                         f"Attempting to download {repo_id} with {var} variant..."
+                     )
+                 else:
+                     logger.info(f"Attempting to download {repo_id} without variant...")
+
+                 await asyncio.to_thread(
+                     DiffusionPipeline.download,
+                     repo_id,
+                     variant=var,
+                     cache_dir=self.cache_dir,
+                     torch_dtype=torch.float16,
+                 )
+
+                 logger.info(
+                     f"Model {repo_id} downloaded successfully with variant: "
+                     f"{var if var else 'default'}"
+                 )
+                 return True
+
+             except Exception as e:
+                 if var:
+                     logger.error(
+                         f"Failed to download {var} variant for {repo_id}: {e}. "
+                         f"Trying next variant..."
+                     )
+                 else:
+                     logger.error(
+                         f"Failed to download default variant for {repo_id}: {e}"
+                     )
+
+         logger.error(f"Failed to download model {repo_id} with any variant.")
+         return False
+
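+     # Fallback order used above: bf16 -> fp8 -> fp16 -> default; the first
+     # variant that DiffusionPipeline.download completes wins.
+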
+     async def delete(self, repo_id: str) -> None:
+         model_path = os.path.join(
+             self.cache_dir, "models--" + repo_id.replace("/", "--")
+         )
+         if os.path.exists(model_path):
+             await asyncio.to_thread(shutil.rmtree, model_path)
+             logger.info(f"Model {repo_id} deleted successfully.")
+         else:
+             logger.warning(f"Model {repo_id} not found in cache.")
+
+     async def get_diffusers_multifolder_components(
+         self, repo_id: str
+     ) -> dict[str, str | tuple[str, str]] | None:
+         """
+         Retrieve and parse model_index.json for a repo in diffusers-multifolder
+         layout. Returns the parsed mapping, or None if the file is unavailable.
+         """
+         try:
+             model_index_path = await asyncio.to_thread(
+                 hf_hub_download,
+                 repo_id=repo_id,
+                 filename="model_index.json",
+                 cache_dir=self.cache_dir,
+             )
+
+             if model_index_path:
+                 with open(model_index_path, "r") as f:
+                     data = json.load(f)
+                 return {
+                     k: tuple(v) if isinstance(v, list) else v
+                     for k, v in data.items()
+                 }
+             else:
+                 return None
+         except Exception as e:
+             logger.error(f"Error retrieving model_index.json for {repo_id}: {e}")
+             return None
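+
+     # Example return value (illustrative), mirroring a model_index.json:
+     #   {"_class_name": "StableDiffusionPipeline",
+     #    "unet": ("diffusers", "UNet2DConditionModel"),
+     #    "vae": ("diffusers", "AutoencoderKL")}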