gen-worker 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FlashPack Loading Integration for DefaultModelManager
|
|
3
|
+
|
|
4
|
+
This module provides FlashPack loading capability to the model manager.
|
|
5
|
+
It checks if a FlashPack version of a model exists and loads from it
|
|
6
|
+
for faster loading times (2-4s vs 8-12s for safetensors).
|
|
7
|
+
|
|
8
|
+
Now with local cache support - copies models from NFS to local NVMe first.
|
|
9
|
+
|
|
10
|
+
Integration:
|
|
11
|
+
1. Add this import to manager.py:
|
|
12
|
+
from .utils.flashpack_loader import FlashPackLoader
|
|
13
|
+
|
|
14
|
+
2. Initialize in DefaultModelManager.__init__():
|
|
15
|
+
self.flashpack_loader = FlashPackLoader()
|
|
16
|
+
|
|
17
|
+
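3. Modify _load_model_by_source() to try FlashPack first: call get_flashpack_path() and, when it returns a path, load_from_flashpack(); fall back to the standard loader if either returns None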
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
import logging
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Optional, Tuple, Type, Union, Dict, List
|
|
24
|
+
import hashlib
|
|
25
|
+
import asyncio
|
|
26
|
+
|
|
27
|
+
import torch
|
|
28
|
+
from diffusers import DiffusionPipeline
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)

# Suffix appended to a model directory's name to form its FlashPack sibling,
# e.g. "models--org--name" -> "models--org--name.flashpack".
FLASHPACK_SUFFIX = ".flashpack"

# Pipeline attributes whose weights may be stored as "<component>.flashpack"
# files inside a FlashPack directory.
FLASHPACK_COMPONENTS = [
    "unet",
    "vae",
    "text_encoder",
    "text_encoder_2",
    "transformer",
]

# Default shared-storage (NFS) roots that are searched for models.
NFS_COZY_MODELS = "/workspace/.cozy-creator/models"
NFS_HF_CACHE = "/workspace/.cache/huggingface/hub"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class FlashPackLoader:
    """
    Handles loading models from FlashPack format with local cache support.

    A FlashPack directory sits next to the original model directory and is
    named after it plus ``FLASHPACK_SUFFIX``. ``load_from_flashpack`` expects it
    to contain a ``pipeline/`` directory (passed to the pipeline class's
    ``from_pretrained``) and optional ``<component>.flashpack`` weight files for
    the names listed in ``FLASHPACK_COMPONENTS``.

    When ``LocalModelCache`` can be imported, models are copied from NFS to the
    local cache before loading; otherwise they are read from NFS directly.
    """

    def __init__(
        self,
        cozy_models_dir: str = NFS_COZY_MODELS,
        hf_cache_dir: str = NFS_HF_CACHE,
        use_local_cache: bool = True,
    ):
        """
        Args:
            cozy_models_dir: Directory searched for Civitai/cozy model folders.
            hf_cache_dir: HuggingFace hub cache directory.
            use_local_cache: Enable the local cache if ``LocalModelCache`` can be
                imported; otherwise models are loaded from NFS directly.
        """
        self.cozy_models_dir = Path(cozy_models_dir)
        self.hf_cache_dir = Path(hf_cache_dir)
        self._flashpack_available = self._check_flashpack_installed()

        # None means "no local cache". Every use below tests `is not None`
        # rather than truthiness, so a cache object that defines __len__ or
        # __bool__ (e.g. one that is currently empty) is never mistaken for a
        # disabled cache.
        self.local_cache = None
        if use_local_cache:
            try:
                from .local_cache import LocalModelCache

                self.local_cache = LocalModelCache()
                logger.info("✓ Local NVMe cache enabled")
            except ImportError:
                logger.warning("LocalModelCache not available, using NFS directly")

    def _check_flashpack_installed(self) -> bool:
        """Return True if the ``flashpack`` library can be imported."""
        try:
            from flashpack import assign_from_file  # noqa: F401

            return True
        except ImportError:
            logger.warning("FlashPack not installed. Using standard loading.")
            return False

    def get_flashpack_path(self, model_id: str, source: str) -> Optional[Path]:
        """
        Get the FlashPack directory path for a model if it exists.
        Checks local cache first, then NFS.

        Args:
            model_id: Model identifier (e.g., "pony.realism")
            source: Source string from pipeline_defs ("hf:<repo_id>" for
                HuggingFace models; anything else is treated as Civitai)

        Returns:
            Path to FlashPack directory or None if not found (or if the
            flashpack library is not installed)
        """
        if not self._flashpack_available:
            return None

        # 1) A copy already present in the local cache wins.
        if self.local_cache is not None:
            cached = self.local_cache.get_local_path_if_cached(model_id, source)
            if cached:
                # Normalise in case the cache hands back a plain string.
                cached = Path(cached)
                if cached.exists() and FLASHPACK_SUFFIX in cached.name:
                    logger.info(f"⚡ FlashPack found in local cache for {model_id}")
                    return cached

        # 2) Otherwise look for the FlashPack sibling on NFS.
        if source.startswith("hf:"):
            base_path = self._get_hf_flashpack_path(source[3:])
        else:
            base_path = self._get_civitai_flashpack_path(model_id, source)

        # Only usable if it carries the pipeline config that
        # load_from_flashpack() hands to from_pretrained().
        if base_path is not None and (base_path / "pipeline").exists():
            logger.info(f"⚡ FlashPack found on NFS for {model_id}: {base_path}")
            return base_path

        return None

    def _get_hf_flashpack_path(self, repo_id: str) -> Optional[Path]:
        """
        Return the expected FlashPack path for HuggingFace repo *repo_id*.

        Follows the hub cache folder naming ("org/name" -> "models--org--name")
        plus ``FLASHPACK_SUFFIX``. The path is not checked for existence.
        """
        folder_name = f"models--{repo_id.replace('/', '--')}"
        return self.hf_cache_dir / (folder_name + FLASHPACK_SUFFIX)

    def _get_civitai_flashpack_path(self, model_id: str, source: str) -> Optional[Path]:
        """
        Return the expected FlashPack path for a Civitai model, or None.

        The original model directory is looked up as "<safe_name>--*" under
        ``cozy_models_dir``; the FlashPack path is that directory's name plus
        ``FLASHPACK_SUFFIX``. The returned path is not checked for existence.
        """
        safe_name = model_id.replace("/", "-")

        # FlashPack siblings ("<safe_name>--<x>.flashpack") also match the
        # glob; skip them so we never build "...flashpack.flashpack". Sorting
        # makes the choice independent of filesystem listing order.
        candidates = sorted(
            path
            for path in self.cozy_models_dir.glob(f"{safe_name}--*")
            if not path.name.endswith(FLASHPACK_SUFFIX)
        )
        if not candidates:
            return None

        # If several downloads share the prefix, prefer the one suffixed with
        # the first 8 hex chars of md5(source). NOTE(review): this assumes the
        # downloader names directories that way — confirm against it.
        url_hash = hashlib.md5(source.encode()).hexdigest()[:8]
        exact_name = f"{safe_name}--{url_hash}"
        original_dir = next(
            (path for path in candidates if path.name == exact_name),
            candidates[0],
        )
        return original_dir.parent / (original_dir.name + FLASHPACK_SUFFIX)

    async def load_from_flashpack(
        self,
        model_id: str,
        flashpack_path: Path,
        pipeline_class: Type[DiffusionPipeline],
    ) -> Optional[DiffusionPipeline]:
        """
        Load a model from FlashPack format.
        Copies to local cache first if enabled; if the cache copy fails or
        raises, the model is loaded from ``flashpack_path`` instead.

        Args:
            model_id: Model identifier ("flux" in it selects bfloat16,
                otherwise float16)
            flashpack_path: Path to FlashPack directory (on NFS)
            pipeline_class: Pipeline class to instantiate

        Returns:
            Loaded pipeline (moved to "cuda") or None if loading failed
        """
        try:
            from flashpack import assign_from_file

            load_path = Path(flashpack_path)
            if self.local_cache is not None:
                source = self._infer_source_from_path(flashpack_path)
                local_path = None
                try:
                    local_path = await self.local_cache.ensure_local(
                        model_id, source, priority=True
                    )
                except Exception:
                    # The cache is only an optimisation: report the error and
                    # keep going with the NFS path instead of failing the load.
                    logger.warning(
                        f"Local cache error for {model_id}", exc_info=True
                    )
                if local_path:
                    load_path = Path(local_path)
                    logger.info(f"⚡ Loading {model_id} from local cache")
                else:
                    logger.warning("Local cache failed, loading from NFS")

            logger.info(f"⚡ Loading {model_id} from FlashPack at {load_path}...")

            # Determine dtype based on model type
            torch_dtype = torch.bfloat16 if "flux" in model_id.lower() else torch.float16

            # Build the pipeline from its stored config (scheduler, tokenizer,
            # etc.) off the event loop; this creates the model structure.
            pipeline_config_dir = load_path / "pipeline"
            pipeline = await asyncio.to_thread(
                pipeline_class.from_pretrained,
                str(pipeline_config_dir),
            )

            # Assign FlashPack weights to each component that has a file and
            # exists (non-None) on this pipeline.
            for component_name in FLASHPACK_COMPONENTS:
                fp_file = load_path / f"{component_name}{FLASHPACK_SUFFIX}"
                if not fp_file.exists() or not hasattr(pipeline, component_name):
                    continue
                component = getattr(pipeline, component_name)
                if component is None:
                    continue
                logger.info(f"  Assigning {component_name} from FlashPack...")
                await asyncio.to_thread(assign_from_file, component, str(fp_file))

            # Device/dtype transfer is heavy too; keep it off the event loop
            # like the steps above.
            await asyncio.to_thread(pipeline.to, "cuda", dtype=torch_dtype)

            logger.info(f"✅ Successfully loaded {model_id} from FlashPack")
            return pipeline

        except Exception as e:
            # One record carrying both the message and the traceback.
            logger.exception(f"❌ FlashPack loading failed for {model_id}: {e}")
            return None

    def _infer_source_from_path(self, flashpack_path: Path) -> str:
        """
        Infer a source string from a FlashPack path for cache lookup.

        "models--org--name.flashpack" -> "hf:org/name"; any other path is
        returned unchanged (as a string) and treated as a Civitai source.
        """
        name = Path(flashpack_path).name
        # Decide on the directory name only: the repo id is parsed from the
        # name, and checking the full path mis-tagged Civitai folders such as
        # "anime-models--abcd.flashpack" as HuggingFace.
        if name.startswith("models--"):
            repo_id = (
                name.removesuffix(FLASHPACK_SUFFIX)
                .removeprefix("models--")
                .replace("--", "/")
            )
            return f"hf:{repo_id}"
        return str(flashpack_path)

    def has_flashpack(self, model_id: str, source: str) -> bool:
        """Check if FlashPack version exists for a model"""
        return self.get_flashpack_path(model_id, source) is not None

    async def prefetch_deployment_models(
        self,
        model_ids: List[str],
        sources: Dict[str, str],
        exclude_model_id: Optional[str] = None,
    ):
        """
        Background prefetch models for a deployment to local cache.
        No-op when the local cache is disabled.

        Args:
            model_ids: List of model IDs from deployment
            sources: Dict mapping model_id → source string
            exclude_model_id: Model to skip (already being loaded)
        """
        if self.local_cache is None:
            return

        models_to_prefetch = [mid for mid in model_ids if mid != exclude_model_id]
        if models_to_prefetch:
            logger.info(f"🔄 Starting background prefetch for {len(models_to_prefetch)} models")
            await self.local_cache.prefetch_models(models_to_prefetch, sources)

    def get_cache_stats(self) -> Optional[Dict]:
        """Get local cache statistics, or None when the local cache is disabled."""
        if self.local_cache is None:
            return None
        return self.local_cache.get_cache_stats()
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Type, Any
|
|
2
|
+
|
|
3
|
+
from .base_types.architecture import Architecture
|
|
4
|
+
from .base_types.common import TorchDevice
|
|
5
|
+
from .device import get_torch_device
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Torch device detected at import time; replaced via set_available_torch_device().
_available_torch_device: TorchDevice = get_torch_device()

# Lazily created singletons; see get_model_memory_manager() / get_model_downloader().
_MODEL_MEMORY_MANAGER = None
_MODEL_DOWNLOADER = None

# Global registry containing all architecture definitions, keyed by architecture id.
# Extended through update_architectures() and read through get_architectures().
_ARCHITECTURES: dict[str, type[Architecture[Any]]] = {}
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_model_downloader():
    """Return the process-wide ``ModelManager`` downloader, creating it on first use."""
    global _MODEL_DOWNLOADER
    if _MODEL_DOWNLOADER is not None:
        return _MODEL_DOWNLOADER

    # Imported lazily so the downloader module is only loaded when first needed.
    from .model_downloader import ModelManager

    _MODEL_DOWNLOADER = ModelManager()
    return _MODEL_DOWNLOADER
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_model_memory_manager():
    """Return the process-wide ``ModelMemoryManager``, creating it on first use."""
    global _MODEL_MEMORY_MANAGER
    if _MODEL_MEMORY_MANAGER is not None:
        return _MODEL_MEMORY_MANAGER

    # Imported lazily so the manager module is only loaded when first needed.
    from ..manager import ModelMemoryManager

    _MODEL_MEMORY_MANAGER = ModelMemoryManager()
    return _MODEL_MEMORY_MANAGER
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def update_architectures(architectures: dict[str, Type["Architecture"]]):
    """
    Merge *architectures* into the global registry in place.

    Entries whose id already exists are replaced. No ``global`` statement is
    needed because the registry dict is mutated, never rebound.
    """
    _ARCHITECTURES.update(architectures)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_architectures() -> dict[str, Type["Architecture"]]:
    """Return the live global architecture registry (the dict itself, not a copy)."""
    registry = _ARCHITECTURES
    return registry
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_available_torch_device():
    """Return the torch device currently selected for this process."""
    # Read-only access: no `global` declaration required.
    return _available_torch_device
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def set_available_torch_device(device: TorchDevice):
    """Replace the process-wide torch device returned by get_available_torch_device()."""
    global _available_torch_device
    print("Setting device", device)
    _available_torch_device = device
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import torch
|
|
5
|
+
import struct
|
|
6
|
+
import json
|
|
7
|
+
from typing import Type, Optional, Any
|
|
8
|
+
from safetensors.torch import load_file as safetensors_load_file
|
|
9
|
+
from spandrel import canonicalize_state_dict
|
|
10
|
+
from spandrel.__helpers.unpickler import (
|
|
11
|
+
RestrictedUnpickle,
|
|
12
|
+
) # probably shouldn't import from private modules...
|
|
13
|
+
|
|
14
|
+
from .base_types.architecture import (
|
|
15
|
+
Architecture,
|
|
16
|
+
StateDict,
|
|
17
|
+
TorchDevice,
|
|
18
|
+
ComponentMetadata,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Byte length of the little-endian unsigned 64-bit integer ("<Q") that prefixes
# a safetensors file and stores the size of its JSON header (8 bytes).
METADATA_HEADER_SIZE = struct.calcsize("<Q")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# TO DO: make this more efficient; we don't want to have to evaluate EVERY architecture
# for EVERY file. ALSO we need to stop multiple architectures from claiming the same
# keys; i.e., if there are 5 architecture definitions for stable-diffusion-1 installed,
# then only the first one should get to claim those keys, otherwise it gets confusing
# on which model it should use
def from_file(
    path: str | Path,
    device: Optional[TorchDevice] = None,
    registry: Optional[dict[str, Type[Architecture]]] = None,
) -> dict[str, Architecture]:
    """
    Loads a model from a file path. It detects the architecture, instantiates the
    architecture, and loads the state dict into the PyTorch class.

    Args:
        path: A .pt, .pth, .ckpt or .safetensors file.
        device: Device the state dict is loaded onto; also passed to each
            detected architecture's ``load``.
        registry: Architectures to try. When None, the global registry is used.

    Returns:
        Mapping of architecture id to architecture instance. The mapping is
        empty if no supported model architecture is found.

    Raises:
        ValueError: if the file extension is not supported.
    """
    state_dict = load_state_dict_from_file(path, device=device)

    # Only safetensors files carry a metadata header. Calling
    # read_safetensors_metadata on other formats printed a misleading
    # "Error: ... is not a '.safetensors' file." for perfectly valid inputs.
    extension = os.path.splitext(path)[1].lower()
    metadata = read_safetensors_metadata(path) if extension == ".safetensors" else {}

    return from_state_dict(state_dict, metadata, device, registry)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def from_state_dict(
    state_dict: StateDict,
    metadata: Optional[dict[str, Any]] = None,
    device: Optional[TorchDevice] = None,
    registry: Optional[dict[str, Type[Architecture]]] = None,
) -> dict[str, Architecture]:
    """
    Load a model from the given state dict.

    Detects every matching architecture (see components_from_state_dict) and
    calls ``load(state_dict, device)`` on each one.

    Args:
        state_dict: Tensors to detect and load.
        metadata: Checkpoint metadata forwarded to detection and construction.
            None (the default) means an empty dict.
        device: Passed to each architecture's ``load``.
        registry: Architectures to try; None selects the global registry.

    Returns:
        Mapping of architecture id to architecture instance. The mapping is
        empty if no supported model architecture is found.
    """
    if metadata is None:
        # Fresh dict per call: the former `= {}` default was a single object
        # shared across calls and handed to every architecture constructor,
        # where any mutation would have leaked into later calls.
        metadata = {}

    # Fetch class instances
    components = components_from_state_dict(state_dict, metadata, registry)

    # Load the state dict into the class instance, and move to device
    for arch_id, architecture in components.items():
        try:
            architecture.load(state_dict, device)
        except Exception as e:
            # Best-effort: report and continue with the other components.
            # NOTE(review): the failing component is still returned, without
            # its weights loaded.
            print(f"Failed to load state dict into {arch_id}: {e}")

    return components
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def components_from_state_dict(
    state_dict: StateDict,
    metadata: dict,
    registry: Optional[dict[str, Type[Architecture]]] = None,
) -> dict[str, Architecture]:
    """
    Instantiate every registered architecture that recognises *state_dict*.

    Each architecture's ``detect(state_dict=..., metadata=...)`` is consulted.
    A non-None result counts as a match, while a None result or an exception
    counts as no match. Matching classes are constructed with
    ``metadata=metadata``. The state dict is NOT loaded into them here.

    Args:
        state_dict: Tensors to inspect.
        metadata: Checkpoint metadata handed to ``detect`` and the constructor.
        registry: Architectures to try; None selects the global registry.

    Returns:
        Mapping of architecture id to newly constructed instance, in registry order.
    """
    if registry is None:
        from .globals import _ARCHITECTURES

        registry = _ARCHITECTURES

    def _recognises(arch_cls) -> bool:
        try:
            return arch_cls.detect(state_dict=state_dict, metadata=metadata) is not None
        except Exception:
            # A detector that chokes on a foreign checkpoint simply doesn't match.
            return False

    return {
        arch_id: arch_cls(metadata=metadata)
        for arch_id, arch_cls in registry.items()
        if _recognises(arch_cls)
    }
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def load_state_dict_from_file(
    path: str | Path, device: Optional[TorchDevice] = None
) -> StateDict:
    """
    Read a model file into a canonicalized state dict.

    The loader is chosen by file extension (case-insensitive):
    ``.pth``/``.ckpt`` are unpickled, ``.safetensors`` is read with
    safetensors, and ``.pt`` is tried as TorchScript first and then as a
    pickled state dict. State dicts are typically only useful to pass them
    into the `load` function of a specific architecture.

    Raises:
        ValueError: if the file extension is not supported.
    """
    if isinstance(device, str):
        device = torch.device(device)  # make pyright type-checker happy

    extension = os.path.splitext(path)[1].lower()

    state_dict: StateDict
    if extension in (".pth", ".ckpt"):
        state_dict = _load_pth(path, device)
    elif extension == ".safetensors":
        state_dict = _load_safetensors(path, device)
    elif extension == ".pt":
        # ".pt" is ambiguous: TorchScript archive or plain pickled state dict.
        try:
            state_dict = _load_torchscript(path, device)
        except RuntimeError:
            try:
                fallback = _load_pth(path, device)
            except Exception:
                fallback = None
            if fallback is None:
                # Neither loader worked: re-raise the TorchScript error so the
                # caller sees why the file was rejected.
                raise
            state_dict = fallback
    else:
        raise ValueError(
            f"Unsupported model file extension {extension}. Please try a supported model type."
        )

    return canonicalize_state_dict(state_dict)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _load_pth(path: str | Path, device: Optional[torch.device] = None) -> StateDict:
    """
    Unpickle a ``.pth``/``.ckpt`` checkpoint with spandrel's ``RestrictedUnpickle``.

    *device* is forwarded to ``torch.load`` as ``map_location``.
    """
    load_options = {"map_location": device, "pickle_module": RestrictedUnpickle}
    return torch.load(path, **load_options)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _load_torchscript(
    path: str | Path, device: Optional[torch.device] = None
) -> StateDict:
    """Load a TorchScript archive onto *device* and return its ``state_dict()``."""
    scripted_module = torch.jit.load(path, map_location=device)
    return scripted_module.state_dict()
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _load_safetensors(
    path: str | Path, device: Optional[TorchDevice] = None
) -> StateDict:
    """
    Read a ``.safetensors`` file.

    When *device* is None, safetensors' own default device is used. A
    ``torch.device`` is converted to its string form before being passed on.
    """
    if device is None:
        return safetensors_load_file(path)

    target = str(device) if isinstance(device, torch.device) else device
    return safetensors_load_file(path, device=target)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def read_safetensors_metadata(file_path: str | Path) -> dict[str, Any]:
    """
    Return the ``__metadata__`` section of a safetensors file's JSON header.

    A safetensors file starts with an unsigned 64-bit little-endian header
    length followed by that many bytes of JSON.

    Returns:
        The ``__metadata__`` dict, or {} when:
        - the path does not end in ".safetensors" (any case), with an error printed;
        - the file does not exist, with an error printed;
        - the file is shorter than the length prefix;
        - the header length is 0;
        - the header has no ``__metadata__`` key or is not a JSON object.

    Raises:
        json.JSONDecodeError: if the header bytes are not valid JSON.
    """
    # Case-insensitive, matching load_state_dict_from_file's extension handling.
    if not str(file_path).lower().endswith(".safetensors"):
        print(f"Error: File '{file_path}' is not a '.safetensors' file.")
        return {}
    if not os.path.isfile(file_path):
        print(f"Error: File '{file_path}' not found.")
        return {}

    # "<Q" = little-endian u64; calcsize gives 8 (the module's METADATA_HEADER_SIZE).
    size_format = "<Q"
    size_len = struct.calcsize(size_format)
    with open(file_path, "rb") as file:
        size_bytes = file.read(size_len)
        if len(size_bytes) < size_len:
            # Too short to even hold the length prefix (struct.unpack would raise).
            return {}
        (header_size,) = struct.unpack(size_format, size_bytes)
        if header_size == 0:
            return {}
        header = json.loads(file.read(header_size))

    if not isinstance(header, dict):
        return {}
    return header.get("__metadata__", {})
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def find_component_models(
    state_dict: StateDict,
    metadata: Optional[dict] = None,
    registry: Optional[dict[str, Type[Architecture]]] = None,
) -> dict[str, ComponentMetadata]:
    """
    Detect all models present inside of a state dict, and return a dict. The keys of
    the dict are the architecture's unique identifier that can be instantiated using
    this state-dict, and the value is the metadata of the corresponding architecture
    if it were instantiated using this same state-dict + metadata.

    Args:
        state_dict: Tensors to inspect.
        metadata: Checkpoint metadata passed to each ``detect`` as-is (may be None).
        registry: Architectures to try; None selects the global registry.

    Returns:
        Mapping of architecture id to the non-None result of its ``detect``.
        Architectures whose ``detect`` raises are reported via print and skipped.
    """
    if registry is None:
        from .globals import _ARCHITECTURES

        registry = _ARCHITECTURES

    components: dict[str, ComponentMetadata] = {}
    for arch_id, architecture in registry.items():  # Iterate through all architectures
        try:
            checkpoint_metadata = architecture.detect(
                state_dict=state_dict, metadata=metadata
            )
        except Exception as e:
            print(f"Encountered error running architecture.detect for {arch_id}: {e}")
            continue

        if checkpoint_metadata is not None:
            # Registry keys are unique, so each arch_id is recorded at most once.
            components[arch_id] = checkpoint_metadata

    return components
|