gen_worker-0.1.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
gen_worker/torch_manager/utils/model_downloader.py
@@ -0,0 +1,763 @@
import os
import shutil
import asyncio
import aiohttp
from typing import Optional, List
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from .paths import get_models_dir
from huggingface_hub import HfApi, hf_hub_download, scan_cache_dir, snapshot_download
from huggingface_hub.file_download import repo_folder_name
import torch
import logging
from huggingface_hub.constants import HF_HUB_CACHE
import json
from tqdm import tqdm
from .config import get_config
from .utils import serialize_config
import hashlib
import time
import re
from urllib.parse import urlparse, parse_qs, unquote
import backoff


logger = logging.getLogger(__name__)

class ModelSource:
    """Represents a model source with its type and details"""

    def __init__(self, source_str: str):
        self.original_string = source_str
        if source_str.startswith("hf:"):
            self.type = "huggingface"
            self.location = source_str[3:]
        elif "civitai.com" in source_str:
            self.type = "civitai"
            self.location = source_str
        elif source_str.startswith("file:"):
            self.type = "file"
            self.location = source_str[5:]
        elif source_str.startswith(("http://", "https://")):
            self.type = "direct"
            self.location = source_str
        else:
            raise ValueError(f"Unsupported model source: {source_str}")

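For reference, a minimal sketch (not part of the wheel) of how ModelSource classifies source strings; the repo, URL, and file names below are hypothetical:

# Illustrative only; the example source strings are hypothetical.
ModelSource("hf:some-org/some-repo")                 # type="huggingface", location="some-org/some-repo"
ModelSource("https://civitai.com/models/12345")      # type="civitai", location unchanged
ModelSource("file:/models/custom.safetensors")       # type="file", location="/models/custom.safetensors"
ModelSource("https://example.com/m.safetensors")     # type="direct", location unchanged
# ModelSource("ftp://host/model.bin")                # would raise ValueError("Unsupported model source: ...")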
class ModelManager:
    def __init__(self, cache_dir: Optional[str] = None):
        self.hf_api = HfApi()
        self.cache_dir = cache_dir or HF_HUB_CACHE
        self.base_cache_dir = cache_dir or os.path.expanduser("~/.cache")
        self.cozy_cache_dir = get_models_dir()
        self.session: Optional[aiohttp.ClientSession] = None

        # config = serialize_config(get_config())
        # self.civitai_api_key = config["civitai_api_key"]

        # check env for civitai api key
        self.civitai_api_key = os.getenv("CIVITAI_API_KEY")

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.session:
            await self.session.close()
            self.session = None

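ModelManager is an async context manager: entering it opens the shared aiohttp session that download_model and the Civitai helpers require. A minimal usage sketch (the model id and source string are hypothetical):

# Illustrative only; "my-model" would have to exist in the worker's pipeline_defs config.
async def fetch(model_id: str = "my-model") -> None:
    async with ModelManager() as manager:
        downloaded, variant = await manager.is_downloaded(model_id)
        if not downloaded:
            source = ModelSource("hf:some-org/some-repo")  # hypothetical source string
            await manager.download_model(model_id, source)

# asyncio.run(fetch())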
    def parse_hf_string(
        self, hf_string: str
    ) -> tuple[str, Optional[str], Optional[str]]:
        """
        Parses a HuggingFace string into its components.
        Returns a tuple of (repo_id, subfolder, filename)
        """
        # Remove 'hf:' prefix if present
        if hf_string.startswith("hf:"):
            hf_string = hf_string[3:]

        parts = hf_string.split("/")
        if len(parts) < 2:
            raise ValueError("Invalid HuggingFace string: repo_id is required")

        repo_id = "/".join(parts[:2])
        subfolder = None
        filename = None

        if len(parts) > 2:
            if not parts[-1].endswith("/"):
                filename = parts[-1]
                subfolder = "/".join(parts[2:-1]) if len(parts) > 3 else None
            else:
                subfolder = "/".join(parts[2:])

        return repo_id, subfolder, filename

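A quick sketch of what parse_hf_string returns for a few hypothetical inputs (repo ids, subfolders, and file names are made up):

# Illustrative only; all identifiers are hypothetical.
manager = ModelManager()
manager.parse_hf_string("hf:org/repo")                         # ("org/repo", None, None)
manager.parse_hf_string("hf:org/repo/model.safetensors")       # ("org/repo", None, "model.safetensors")
manager.parse_hf_string("hf:org/repo/unet/model.safetensors")  # ("org/repo", "unet", "model.safetensors")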
    async def _get_civitai_filename(self, url: str) -> Optional[str]:
        """Extract original filename from Civitai redirect response"""
        try:
            headers = {}
            if self.civitai_api_key:
                headers["Authorization"] = f"Bearer {self.civitai_api_key}"

            need_cleanup = False
            if not self.session:
                self.session = aiohttp.ClientSession()
                need_cleanup = True

            try:
                async with self.session.get(
                    url, headers=headers, allow_redirects=False
                ) as response:
                    if response.status in (301, 302, 307):
                        location = response.headers.get("location")
                        if location:
                            # Parse the query parameters from the redirect URL
                            parsed = urlparse(location)
                            query_params = parse_qs(parsed.query)

                            # Look for response-content-disposition parameter
                            content_disp = query_params.get(
                                "response-content-disposition", [None]
                            )[0]
                            if content_disp:
                                # Extract filename from content disposition
                                match = re.search(r'filename="([^"]+)"', content_disp)
                                if match:
                                    return unquote(match.group(1))

                            # Fallback to path if no content disposition
                            path = parsed.path
                            if path:
                                return os.path.basename(path)

                return None
            finally:
                # Clean up the session if we created it
                if need_cleanup and self.session:
                    await self.session.close()
                    self.session = None

        except Exception as e:
            logger.error(f"Error getting Civitai filename: {e}")
            # Make sure to clean up session on error if we created it
            if "need_cleanup" in locals() and need_cleanup and self.session:
                await self.session.close()
                self.session = None
            return None

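The helper above relies on Civitai redirecting to a signed URL whose query string carries a response-content-disposition parameter. A minimal, standalone sketch of that extraction on a hypothetical redirect location (the URL and filename are made up):

# Illustrative only; the redirect URL and filename are hypothetical.
from urllib.parse import urlparse, parse_qs, unquote
import re

location = (
    "https://cdn.example.com/1234?response-content-disposition="
    "attachment%3B%20filename%3D%22my_lora_v1.safetensors%22"
)
params = parse_qs(urlparse(location).query)
disp = params.get("response-content-disposition", [None])[0]
match = re.search(r'filename="([^"]+)"', disp)
print(unquote(match.group(1)))  # my_lora_v1.safetensors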
    async def is_downloaded(self, model_id: str, model_config: Optional[dict] = None) -> tuple[bool, Optional[str]]:
        """Check if a model is downloaded, handling all source types including Civitai filename variants"""
        try:
            config = serialize_config(get_config())
            model_info = config["pipeline_defs"].get(model_id)
            if not model_info:
                logger.error(f"Model {model_id} not found in configuration.")
                return False, None

            source = ModelSource(model_info["source"])

            # Get components from model_config
            component_names = None
            if model_config:
                components = model_config.get("components", [])
                # add the component names to an array
                if isinstance(components, list):
                    component_names = [component for component in components]
                    print(f"Component names: {component_names}")

            # Handle local files - just check if they exist
            if source.type == "file":
                exists = os.path.exists(source.location)
                if not exists:
                    logger.error(f"Local file not found: {source.location}")
                return exists, None

            # Handle HuggingFace models as before
            if source.type == "huggingface":
                is_downloaded, variant = self._check_repo_downloaded(source.location, component_names)
                print(
                    f"Repo {source.location} is downloaded: {is_downloaded}, variant: {variant}"
                )
                return is_downloaded, variant

            # Special handling for Civitai models
            elif source.type == "civitai":
                # First check the default numeric ID path
                default_path = await self._get_cache_path(model_id, source)
                if self._check_file_downloaded(default_path):
                    return True, None

                # If not found, try to get the original filename
                if not self.session:
                    self.session = aiohttp.ClientSession()
                    need_cleanup = True
                else:
                    need_cleanup = False

                try:
                    original_filename = await self._get_civitai_filename(
                        source.location
                    )
                    if original_filename:
                        dir_path = os.path.dirname(default_path)
                        alternate_path = os.path.join(dir_path, original_filename)
                        if self._check_file_downloaded(alternate_path):
                            return True, None
                finally:
                    if need_cleanup and self.session:
                        await self.session.close()
                        self.session = None

                return False, None

            # Handle direct downloads
            else:
                cache_path = await self._get_cache_path(model_id, source)
                return self._check_file_downloaded(cache_path), None

        except Exception as e:
            logger.error(f"Error checking download status for {model_id}: {e}")
            return False, None

    def _get_model_directory(self, model_id: str, url_hash: str) -> str:
        """Get the directory path for a model"""
        safe_name = model_id.replace("/", "-")
        return os.path.join(self.cozy_cache_dir, f"{safe_name}--{url_hash}")

    async def download_model(self, model_id: str, source: ModelSource):
        """Download a model from any source"""
        if not self.session:
            raise RuntimeError("Session not initialized. Use async with context.")

        if source.type == "huggingface":
            repo_id, subfolder, filename = self.parse_hf_string(source.location)
            return await self.download(repo_id, file_name=filename, sub_folder=subfolder)
        elif source.type == "civitai":
            return await self._download_civitai(model_id, source.location)
        else:
            return await self._download_direct(model_id, source.location)

    async def _download_civitai(self, model_id: str, url: str):
        """Handle Civitai-specific download logic with proper filename handling"""
        if not self.session:
            raise RuntimeError("Session not initialized. Use async with context.")

        # Convert to API URL if needed
        if "/api/download/" not in url:
            model_path = urlparse(url).path
            model_number = model_path.split("/models/")[1].split("/")[0]
            api_url = f"https://civitai.com/api/v1/models/{model_number}"

            headers = {}
            if self.civitai_api_key:
                headers["Authorization"] = f"Bearer {self.civitai_api_key}"

            async with self.session.get(api_url, headers=headers) as response:
                if response.status != 200:
                    raise Exception(
                        f"Failed to get Civitai model info: {response.status}"
                    )
                data = await response.json()
                # Extract download URL from the first version
                if "modelVersions" in data and len(data["modelVersions"]) > 0:
                    download_url = data["modelVersions"][0]["downloadUrl"]
                else:
                    raise Exception("No model versions found in Civitai response")
        else:
            download_url = url

        # Get original filename from redirect
        dest_path = await self._get_cache_path(model_id, ModelSource(download_url))  # Use download_url for consistent hashing if filename changes

        original_filename = await self._get_civitai_filename(download_url)  # download_url is the one that might redirect
        if original_filename:
            # If we got an original filename, update the dest_path to use it.
            # This ensures the filename in the cache matches what Civitai intends.
            dir_path = os.path.dirname(dest_path)  # Keep the hashed directory structure
            dest_path = os.path.join(dir_path, original_filename)
            logger.info(f"Using original filename from Civitai for destination: {dest_path}")
        else:
            logger.warning(f"Could not determine original filename from Civitai for {download_url}. Using default: {dest_path}")

        # Download with the correct filename
        await self._download_direct(model_id, download_url, dest_path)

    @backoff.on_exception(
        backoff.expo, (aiohttp.ClientError, asyncio.TimeoutError), max_tries=3
    )
    async def _download_direct(
        self, model_id: str, url: str, dest_path: Optional[str] = None
    ):
        """Download from direct URL with progress bar, retry logic, and resume capability"""
        if dest_path is None:
            dest_path = await self._get_cache_path(model_id, ModelSource(url))

        temp_path = dest_path + ".tmp"

        os.makedirs(os.path.dirname(dest_path), exist_ok=True)

        headers = {}
        if self.civitai_api_key:
            headers["Authorization"] = f"Bearer {self.civitai_api_key}"

        # Check if we have a partial download
        initial_size = 0
        if os.path.exists(temp_path):
            initial_size = os.path.getsize(temp_path)
            if initial_size > 0:
                headers["Range"] = f"bytes={initial_size}-"
                logger.info(f"Resuming download from byte {initial_size}")

        timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)

        try:
            async with self.session.get(
                url, headers=headers, timeout=timeout
            ) as response:
                # Handle resume responses
                if initial_size > 0:
                    if response.status == 206:  # Partial Content, resume successful
                        total_size = initial_size + int(
                            response.headers.get("content-length", 0)
                        )
                    elif response.status == 200:  # Server doesn't support resume
                        logger.warning(
                            "Server doesn't support resume, starting from beginning"
                        )
                        total_size = int(response.headers.get("content-length", 0))
                        initial_size = 0
                    else:
                        raise Exception(f"Resume failed with status {response.status}")
                else:
                    if response.status != 200:
                        raise Exception(
                            f"Download failed with status {response.status}"
                        )
                    total_size = int(response.headers.get("content-length", 0))

                # Open file in append mode if resuming, write mode if starting fresh
                mode = "ab" if initial_size > 0 else "wb"
                downloaded_size = initial_size
                last_progress_update = time.time()
                stall_timer = 0

                with tqdm(
                    total=total_size, initial=initial_size, unit="iB", unit_scale=True
                ) as pbar:
                    try:
                        with open(temp_path, mode) as f:
                            async for chunk in response.content.iter_chunked(8192):
                                if chunk:  # filter out keep-alive chunks
                                    f.write(chunk)
                                    downloaded_size += len(chunk)
                                    pbar.update(len(chunk))

                                    # Check for download stalls
                                    current_time = time.time()
                                    if (
                                        current_time - last_progress_update > 30
                                    ):  # 30 seconds without progress
                                        stall_timer += (
                                            current_time - last_progress_update
                                        )
                                        if (
                                            stall_timer > 120
                                        ):  # 2 minutes total stall time
                                            raise Exception(
                                                "Download stalled for too long"
                                            )
                                    else:
                                        stall_timer = 0
                                    last_progress_update = current_time

                        # Verify downloaded size
                        if total_size > 0 and downloaded_size != total_size:
                            raise Exception(
                                f"Download incomplete. Expected {total_size} bytes, got {downloaded_size} bytes"
                            )

                        # Verify file integrity
                        if await self._verify_file(temp_path):
                            os.rename(temp_path, dest_path)
                            logger.info(
                                f"Downloaded and saved as: {os.path.basename(dest_path)}"
                            )
                        else:
                            raise Exception(
                                "File verification failed - will attempt resume on next try"
                            )

                    except Exception as e:
                        logger.error(
                            f"Download error (temporary file kept for resume): {str(e)}"
                        )
                        raise

        except Exception as e:
            logger.error(f"Error downloading {url}: {str(e)}")
            raise

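The resume path hinges on standard HTTP semantics: a Range request against a partially written .tmp file, a 206 reply meaning only the remainder is served, and a 200 reply meaning the server ignored the range. A distilled sketch of that bookkeeping, with hypothetical byte counts (not the package's code):

# Illustrative only; sizes are hypothetical.
initial_size = 150_000_000  # bytes already present in model.safetensors.tmp
headers = {"Range": f"bytes={initial_size}-"}  # ask the server to skip what we already have

def expected_total(status: int, content_length: int) -> int:
    if status == 206:  # server honoured the range: body is the remainder only
        return initial_size + content_length
    if status == 200:  # server ignored the range: full body, restart from byte 0
        return content_length
    raise RuntimeError(f"Resume failed with status {status}")

print(expected_total(206, 50_000_000))  # 200000000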
    async def _verify_file(self, path: str) -> bool:
        """Verify downloaded file integrity with more thorough checks"""
        try:
            if not os.path.exists(path):
                logger.error(f"File {path} does not exist")
                return False

            # Size check
            file_size = os.path.getsize(path)
            if file_size < 1024 * 1024:  # 1MB minimum
                logger.error(f"File {path} is too small: {file_size} bytes")
                return False

            # Extension check - we check the final intended path, not the .tmp path
            check_path = path[:-4] if path.endswith(".tmp") else path
            valid_extensions = {".safetensors", ".ckpt", ".pt", ".bin"}
            if not any(check_path.endswith(ext) for ext in valid_extensions):
                logger.error(f"File {check_path} has invalid extension")
                return False

            # Try to open the file to ensure it's not corrupted
            with open(path, "rb") as f:
                # Read first and last 1MB to check file accessibility
                f.read(1024 * 1024)
                f.seek(-1024 * 1024, 2)
                f.read(1024 * 1024)

            logger.info(f"File {path} passed all verification checks")
            return True

        except Exception as e:
            logger.error(f"File verification failed: {str(e)}")
            return False

    async def _get_cache_path(self, model_id: str, source: ModelSource) -> str:
        """Get the cache path for a model"""
        if source.type == "huggingface":
            return os.path.join(
                HF_HUB_CACHE, repo_folder_name(source.location, "model")
            )

        # For non-HF models
        safe_name = model_id.replace("/", "-")
        url_hash = hashlib.sha256(source.location.encode()).hexdigest()[:8]

        # Create model directory with hash
        model_dir = os.path.join(self.cozy_cache_dir, f"{safe_name}--{url_hash}")
        os.makedirs(model_dir, exist_ok=True)

        if source.type == "civitai":
            # Try to get original filename from Civitai
            print(f"Getting Civitai filename for {source.location}")
            original_filename = await self._get_civitai_filename(source.location)
            if original_filename:
                return os.path.join(model_dir, original_filename)

        # Fallback for direct downloads or if couldn't get Civitai filename
        url_path = urlparse(source.location).path
        filename = (
            os.path.basename(url_path) if url_path else f"{safe_name}.safetensors"
        )

        # Use hash for the final filename to avoid duplicates
        base, ext = os.path.splitext(filename)
        final_filename = f"{base}_{url_hash}{ext}"

        return os.path.join(model_dir, final_filename)

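For non-HuggingFace sources, the method above builds paths under the cozy models directory from the sanitized model id plus the first 8 hex characters of a SHA-256 of the URL. A runnable sketch of the resulting layout (the model id and URL are hypothetical):

# Illustrative only; the model id and URL are hypothetical.
import hashlib
import os

model_id = "my/lora"
location = "https://example.com/files/lora.safetensors"

safe_name = model_id.replace("/", "-")
url_hash = hashlib.sha256(location.encode()).hexdigest()[:8]
base, ext = os.path.splitext(os.path.basename(urlparse(location).path)) if False else os.path.splitext("lora.safetensors")
print(os.path.join(f"{safe_name}--{url_hash}", f"{base}_{url_hash}{ext}"))
# e.g. my-lora--<8 hex chars>/lora_<8 hex chars>.safetensors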
    def _check_repo_downloaded(self, repo_id: str, component_names: Optional[List[str]] = None) -> tuple[bool, Optional[str]]:
        storage_folder = os.path.join(
            self.cache_dir, repo_folder_name(repo_id=repo_id, repo_type="model")
        )

        if not os.path.exists(storage_folder):
            return False, None

        # Get the latest commit hash
        refs_path = os.path.join(storage_folder, "refs", "main")
        if not os.path.exists(refs_path):
            return False, None

        with open(refs_path, "r") as f:
            commit_hash = f.read().strip()

        snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
        if not os.path.exists(snapshot_folder):
            return False, None

        # Check model_index.json for required folders
        model_index_path = os.path.join(snapshot_folder, "model_index.json")

        if os.path.exists(model_index_path):
            with open(model_index_path, "r") as f:
                model_index = json.load(f)
                required_folders = {
                    k
                    for k, v in model_index.items()
                    if isinstance(v, list)
                    and len(v) == 2
                    and v[0] is not None
                    and v[1] is not None
                }

                # Remove known non-folder keys and ignored folders
                ignored_folders = {
                    "_class_name",
                    "_diffusers_version",
                    "scheduler",
                    "feature_extractor",
                    "tokenizer",
                    "tokenizer_2",
                    "tokenizer_3",
                    "safety_checker",
                }

                required_folders -= ignored_folders
                if component_names:
                    required_folders -= set(component_names)

                print(f"Required folders: {required_folders}")

                # Define variant hierarchy
                variants = [
                    "bf16",
                    "fp8",
                    "fp16",
                    "",
                ]  # empty string for normal variant

                def check_folder_completeness(folder_path: str, variant: str) -> bool:
                    if not os.path.exists(folder_path):
                        return False

                    for _, _, files in os.walk(folder_path):
                        for file in files:
                            if file.endswith(".incomplete"):
                                print(f"Incomplete File: {file}")
                                return False

                            if (
                                file.endswith(f"{variant}.safetensors")
                                or file.endswith(f"{variant}.bin")
                                or (
                                    variant == ""
                                    and (
                                        file.endswith(".safetensors")
                                        or file.endswith(".bin")
                                        or file.endswith(".ckpt")
                                    )
                                )
                            ):
                                return True

                    return False

                def check_variant_completeness(variant: str) -> bool:
                    for folder in required_folders:
                        folder_path = os.path.join(snapshot_folder, folder)

                        if not check_folder_completeness(folder_path, variant):
                            return False

                    return True

                # Check variants in hierarchy
                for variant in variants:
                    print(f"Checking variant: {variant}")
                    if check_variant_completeness(variant):
                        print(f"Variant {variant} is complete")
                        return True, variant

        else:
            # For repos without model_index.json, check the blob folder
            blob_folder = os.path.join(storage_folder, "blobs")
            if os.path.exists(blob_folder):
                for _root, _, files in os.walk(blob_folder):
                    if any(file.endswith(".incomplete") for file in files):
                        return False, None

                return True, None

        return False, None

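The variant check above keys off weight-file suffixes inside each required component folder, preferring bf16, then fp8, then fp16, then unsuffixed weights. A small sketch of the suffix matching it performs (file names are hypothetical):

# Illustrative only; file names are hypothetical.
def matches_variant(file: str, variant: str) -> bool:
    if variant:
        return file.endswith(f"{variant}.safetensors") or file.endswith(f"{variant}.bin")
    return file.endswith((".safetensors", ".bin", ".ckpt"))

print(matches_variant("diffusion_pytorch_model.fp16.safetensors", "fp16"))  # True
print(matches_variant("diffusion_pytorch_model.safetensors", "bf16"))       # False
print(matches_variant("diffusion_pytorch_model.safetensors", ""))           # True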
    def _check_component_downloaded(self, repo_id: str, component_name: str) -> bool:
        storage_folder = os.path.join(
            self.cache_dir, repo_folder_name(repo_id=repo_id, repo_type="model")
        )

        if not os.path.exists(storage_folder):
            return False

        refs_path = os.path.join(storage_folder, "refs", "main")
        if not os.path.exists(refs_path):
            return False

        with open(refs_path, "r") as f:
            commit_hash = f.read().strip()

        component_folder = os.path.join(
            storage_folder, "snapshots", commit_hash, component_name
        )

        if not os.path.exists(component_folder):
            return False

        # Check for any .bin, .safetensors, or .ckpt file in the component folder
        for _, _, files in os.walk(component_folder):
            for file in files:
                if file.endswith(
                    (".bin", ".safetensors", ".ckpt")
                ) and not file.endswith(".incomplete"):
                    return True

        return False

    def _check_file_downloaded(self, path: str) -> bool:
        """Check if a file exists and is complete in the cache"""
        # First check if the exact path exists
        if os.path.exists(path):
            # Check for temporary or incomplete markers
            if os.path.exists(f"{path}.tmp") or os.path.exists(f"{path}.incomplete"):
                print(f"Found incomplete markers for {path}")
                return False
            print(f"Found complete file at {path}")
            return True

        # If path doesn't exist, check the directory for any valid model files
        dir_path = os.path.dirname(path)
        if os.path.exists(dir_path):
            for file in os.listdir(dir_path):
                file_path = os.path.join(dir_path, file)
                if file.endswith((".safetensors", ".ckpt", ".pt", ".bin")):
                    if not os.path.exists(f"{file_path}.tmp") and not os.path.exists(
                        f"{file_path}.incomplete"
                    ):
                        print(f"Found alternative model file at {file_path}")
                        return True

        print(f"No valid model files found in {dir_path}")
        return False

    async def list(self) -> List[str]:
        cache_info = scan_cache_dir()
        return [
            repo.repo_id
            for repo in cache_info.repos
            if (await self.is_downloaded(repo.repo_id))[0]
        ]

    async def download(
        self,
        repo_id: str,
        file_name: Optional[str] = None,
        sub_folder: Optional[str] = None,
    ) -> bool:
        if file_name or sub_folder:
            try:
                if sub_folder and not file_name:
                    await asyncio.to_thread(
                        snapshot_download,
                        repo_id,
                        allow_patterns=f"{sub_folder}/*",
                    )
                    logger.info(
                        f"{sub_folder} subfolder from {repo_id} downloaded successfully."
                    )
                else:
                    await asyncio.to_thread(
                        hf_hub_download,
                        repo_id,
                        file_name,
                        cache_dir=self.cache_dir,
                        subfolder=sub_folder,
                    )
                    logger.info(
                        f"File {file_name} from {repo_id} downloaded successfully."
                    )
                # self.list()  # Refresh the cached list
                return True
            except Exception as e:
                logger.error(f"Failed to download file {file_name} from {repo_id}: {e}")
                return False

        variants = ["bf16", "fp8", "fp16", None]  # None represents no variant
        for var in variants:
            try:
                if var:
                    logger.info(
                        f"Attempting to download {repo_id} with {var} variant..."
                    )
                else:
                    logger.info(f"Attempting to download {repo_id} without variant...")

                await asyncio.to_thread(
                    DiffusionPipeline.download,
                    repo_id,
                    variant=var,
                    cache_dir=self.cache_dir,
                    torch_dtype=torch.float16,
                )

                logger.info(
                    f"Model {repo_id} downloaded successfully with variant: {var if var else 'default'}"
                )
                # self.list()  # Refresh the cached list
                return True

            except Exception as e:
                if var:
                    logger.error(
                        f"Failed to download {var} variant for {repo_id}. Trying next variant..."
                    )
                else:
                    logger.error(
                        f"Failed to download default variant for {repo_id}: {e}"
                    )

        logger.error(f"Failed to download model {repo_id} with any variant.")
        return False

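download() is the HuggingFace entry point: given a file_name or sub_folder it fetches just that piece via hf_hub_download or snapshot_download, otherwise it tries DiffusionPipeline.download with the bf16, fp8, fp16, and default variants in turn. A usage sketch (repo ids and file names are hypothetical):

# Illustrative only; repo ids and file names are hypothetical.
async def examples(manager: ModelManager) -> None:
    await manager.download("some-org/some-pipeline")                              # full pipeline, tries variants
    await manager.download("some-org/some-pipeline", sub_folder="vae")            # one subfolder via snapshot_download
    await manager.download("some-org/some-repo", file_name="model.safetensors")   # single file via hf_hub_download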
    async def delete(self, repo_id: str) -> None:
        model_path = os.path.join(
            self.cache_dir, "models--" + repo_id.replace("/", "--")
        )
        if os.path.exists(model_path):
            await asyncio.to_thread(shutil.rmtree, model_path)
            logger.info(f"Model {repo_id} deleted successfully.")
        else:
            logger.warning(f"Model {repo_id} not found in cache.")

    async def get_diffusers_multifolder_components(
        self, repo_id: str
    ) -> dict[str, str | tuple[str, str]] | None:
        """
        This is only meaningful if the repo is in diffusers-multifolder layout.
        It retrieves and parses the model_index.json file, returning None otherwise.
        """
        try:
            model_index_path = await asyncio.to_thread(
                hf_hub_download,
                repo_id=repo_id,
                filename="model_index.json",
                cache_dir=self.cache_dir,
            )

            if model_index_path:
                with open(model_index_path, "r") as f:
                    data = json.load(f)
                    return {
                        k: tuple(v) if isinstance(v, list) else v
                        for k, v in data.items()
                    }
            else:
                return None

        except Exception as e:
            logger.error(f"Error retrieving model_index.json for {repo_id}: {e}")
            return None
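For a repo in diffusers-multifolder layout, the returned mapping mirrors model_index.json, with two-element lists collapsed to (library, class) tuples. A sketch of the shape (the repo id and entries are hypothetical):

# Illustrative only; the repo id and returned entries are hypothetical.
async def show_components(manager: ModelManager) -> None:
    components = await manager.get_diffusers_multifolder_components("some-org/some-pipeline")
    # e.g. {"_class_name": "SomePipeline",
    #       "unet": ("diffusers", "UNet2DConditionModel"),
    #       "vae": ("diffusers", "AutoencoderKL"), ...}
    if components:
        print(sorted(k for k in components if not k.startswith("_")))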