gen_worker-0.1.4-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
gen_worker/torch_manager/utils/flashpack_loader.py
@@ -0,0 +1,262 @@
+ """
+ FlashPack Loading Integration for DefaultModelManager
+
+ This module provides FlashPack loading capability to the model manager.
+ It checks if a FlashPack version of a model exists and loads from it
+ for faster loading times (2-4s vs 8-12s for safetensors).
+
+ Now with local cache support - copies models from NFS to local NVMe first.
+
+ Integration:
+ 1. Add this import to manager.py:
+    from .utils.flashpack_loader import FlashPackLoader
+
+ 2. Initialize in DefaultModelManager.__init__():
+    self.flashpack_loader = FlashPackLoader()
+
+ 3. Modify _load_model_by_source() to try FlashPack first (see integration code below)
+ """
+
+ import os
+ import logging
+ from pathlib import Path
+ from typing import Optional, Tuple, Type, Union, Dict, List
+ import hashlib
+ import asyncio
+
+ import torch
+ from diffusers import DiffusionPipeline
+
+ logger = logging.getLogger(__name__)
+
+ # FlashPack suffix for directories
+ FLASHPACK_SUFFIX = ".flashpack"
+
+ # Components that can be loaded from FlashPack
+ FLASHPACK_COMPONENTS = ["unet", "vae", "text_encoder", "text_encoder_2", "transformer"]
+
+ # NFS paths
+ NFS_COZY_MODELS = "/workspace/.cozy-creator/models"
+ NFS_HF_CACHE = "/workspace/.cache/huggingface/hub"
+
+
+ class FlashPackLoader:
+     """
+     Handles loading models from FlashPack format with local cache support.
+
+     FlashPack provides 2-4x faster loading compared to safetensors.
+     Local cache copies models from NFS to local NVMe for additional speedup.
+     """
+
+     def __init__(
+         self,
+         cozy_models_dir: str = NFS_COZY_MODELS,
+         hf_cache_dir: str = NFS_HF_CACHE,
+         use_local_cache: bool = True,
+     ):
+         self.cozy_models_dir = Path(cozy_models_dir)
+         self.hf_cache_dir = Path(hf_cache_dir)
+         self._flashpack_available = self._check_flashpack_installed()
+
+         # Initialize local cache if enabled
+         self.local_cache = None
+         if use_local_cache:
+             try:
+                 from .local_cache import LocalModelCache
+                 self.local_cache = LocalModelCache()
+                 logger.info("✓ Local NVMe cache enabled")
+             except ImportError:
+                 logger.warning("LocalModelCache not available, using NFS directly")
+
+     def _check_flashpack_installed(self) -> bool:
+         """Check if flashpack library is available"""
+         try:
+             from flashpack import assign_from_file
+             return True
+         except ImportError:
+             logger.warning("FlashPack not installed. Using standard loading.")
+             return False
+
+     def get_flashpack_path(self, model_id: str, source: str) -> Optional[Path]:
+         """
+         Get the FlashPack directory path for a model if it exists.
+         Checks local cache first, then NFS.
+
+         Args:
+             model_id: Model identifier (e.g., "pony.realism")
+             source: Source string from pipeline_defs
+
+         Returns:
+             Path to FlashPack directory or None if not found
+         """
+         if not self._flashpack_available:
+             return None
+
+         # Check local cache first
+         if self.local_cache:
+             local_path = self.local_cache.get_local_path_if_cached(model_id, source)
+             if local_path and local_path.exists() and FLASHPACK_SUFFIX in local_path.name:
+                 logger.info(f"⚡ FlashPack found in local cache for {model_id}")
+                 return local_path
+
+         # Check NFS
+         if source.startswith("hf:"):
+             base_path = self._get_hf_flashpack_path(source[3:])
+         else:
+             base_path = self._get_civitai_flashpack_path(model_id, source)
+
+         if base_path and base_path.exists():
+             if (base_path / "pipeline").exists():
+                 logger.info(f"⚡ FlashPack found on NFS for {model_id}: {base_path}")
+                 return base_path
+
+         return None
+
+     def _get_hf_flashpack_path(self, repo_id: str) -> Optional[Path]:
+         """Get FlashPack path for HuggingFace model"""
+         folder_name = f"models--{repo_id.replace('/', '--')}"
+         flashpack_path = self.hf_cache_dir / (folder_name + FLASHPACK_SUFFIX)
+         return flashpack_path
+
+     def _get_civitai_flashpack_path(self, model_id: str, source: str) -> Optional[Path]:
+         """Get FlashPack path for Civitai model"""
+         safe_name = model_id.replace("/", "-")
+
+         # Find the original model directory
+         matching_dirs = list(self.cozy_models_dir.glob(f"{safe_name}--*"))
+         if not matching_dirs:
+             # Try finding by URL hash
+             url_hash = hashlib.md5(source.encode()).hexdigest()[:8]
+             matching_dirs = list(self.cozy_models_dir.glob(f"{safe_name}--{url_hash}"))
+
+         if not matching_dirs:
+             return None
+
+         # Get the FlashPack sibling directory
+         original_dir = matching_dirs[0]
+         flashpack_path = original_dir.parent / (original_dir.name + FLASHPACK_SUFFIX)
+         return flashpack_path
+
+     async def load_from_flashpack(
+         self,
+         model_id: str,
+         flashpack_path: Path,
+         pipeline_class: Type[DiffusionPipeline],
+     ) -> Optional[DiffusionPipeline]:
+         """
+         Load a model from FlashPack format.
+         Copies to local cache first if enabled.
+
+         Args:
+             model_id: Model identifier
+             flashpack_path: Path to FlashPack directory (on NFS)
+             pipeline_class: Pipeline class to instantiate
+
+         Returns:
+             Loaded pipeline or None if loading failed
+         """
+         try:
+             from flashpack import assign_from_file
+
+             # Copy to local cache first if enabled
+             load_path = flashpack_path
+             if self.local_cache:
+                 # Get source for cache lookup
+                 source = self._infer_source_from_path(flashpack_path)
+                 local_path = await self.local_cache.ensure_local(
+                     model_id, source, priority=True
+                 )
+                 if local_path:
+                     load_path = local_path
+                     logger.info(f"⚡ Loading {model_id} from local cache")
+                 else:
+                     logger.warning(f"Local cache failed, loading from NFS")
+
+             logger.info(f"⚡ Loading {model_id} from FlashPack at {load_path}...")
+
+             # Determine dtype based on model type
+             torch_dtype = torch.bfloat16 if "flux" in model_id.lower() else torch.float16
+
+             # Load pipeline config (scheduler, tokenizer, etc.)
+             pipeline_config_dir = load_path / "pipeline"
+
+             # Load base pipeline from config (this creates the model structure)
+             pipeline = await asyncio.to_thread(
+                 pipeline_class.from_pretrained,
+                 str(pipeline_config_dir),
+             )
+
+             # Assign FlashPack weights to each component
+             for component_name in FLASHPACK_COMPONENTS:
+                 fp_file = load_path / f"{component_name}.flashpack"
+                 if fp_file.exists() and hasattr(pipeline, component_name):
+                     component = getattr(pipeline, component_name)
+                     if component is not None:
+                         logger.info(f" Assigning {component_name} from FlashPack...")
+                         await asyncio.to_thread(
+                             assign_from_file,
+                             component,
+                             str(fp_file)
+                         )
+
+             # Move to cuda with correct dtype
+             pipeline.to("cuda", dtype=torch_dtype)
+
+             logger.info(f"✅ Successfully loaded {model_id} from FlashPack")
+             return pipeline
+
+         except Exception as e:
+             logger.error(f"❌ FlashPack loading failed for {model_id}: {e}")
+             logger.exception("Full traceback:")
+             return None
+
+     def _infer_source_from_path(self, flashpack_path: Path) -> str:
+         """Infer source string from FlashPack path for cache lookup"""
+         path_str = str(flashpack_path)
+
+         if "models--" in path_str:
+             # HuggingFace model
+             # Extract repo_id from models--org--name.flashpack
+             name = flashpack_path.name.replace(FLASHPACK_SUFFIX, "")
+             repo_id = name.replace("models--", "").replace("--", "/")
+             return f"hf:{repo_id}"
+         else:
+             # Civitai model - return path as source
+             return path_str
+
+     def has_flashpack(self, model_id: str, source: str) -> bool:
+         """Check if FlashPack version exists for a model"""
+         return self.get_flashpack_path(model_id, source) is not None
+
+     async def prefetch_deployment_models(
+         self,
+         model_ids: List[str],
+         sources: Dict[str, str],
+         exclude_model_id: Optional[str] = None
+     ):
+         """
+         Background prefetch models for a deployment to local cache.
+
+         Args:
+             model_ids: List of model IDs from deployment
+             sources: Dict mapping model_id → source string
+             exclude_model_id: Model to skip (already being loaded)
+         """
+         if not self.local_cache:
+             return
+
+         # Filter out the model already being loaded
+         models_to_prefetch = [
+             mid for mid in model_ids
+             if mid != exclude_model_id
+         ]
+
+         if models_to_prefetch:
+             logger.info(f"🔄 Starting background prefetch for {len(models_to_prefetch)} models")
+             await self.local_cache.prefetch_models(models_to_prefetch, sources)
+
+     def get_cache_stats(self) -> Optional[Dict]:
+         """Get local cache statistics"""
+         if self.local_cache:
+             return self.local_cache.get_cache_stats()
+         return None
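For orientation, the integration described in the module docstring above could look roughly like this. This is a minimal sketch, not code shipped in the package: `DefaultModelManager` here is a stand-in for the class in `manager.py`, and `_pipeline_class_for` / `_load_standard` are hypothetical helpers representing whatever the manager actually uses to pick a pipeline class and perform its normal loading path.

```python
# Hypothetical sketch of docstring steps 1-3; only FlashPackLoader and its
# methods come from flashpack_loader.py above. The manager class and the
# _pipeline_class_for / _load_standard helpers are assumed placeholders.
from diffusers import DiffusionPipeline

from .utils.flashpack_loader import FlashPackLoader  # step 1: import in manager.py


class DefaultModelManager:
    def __init__(self):
        self.flashpack_loader = FlashPackLoader()  # step 2: create the loader once

    async def _load_model_by_source(self, model_id: str, source: str) -> DiffusionPipeline:
        # Step 3: try FlashPack first, then fall back to the standard path.
        fp_path = self.flashpack_loader.get_flashpack_path(model_id, source)
        if fp_path is not None:
            pipeline = await self.flashpack_loader.load_from_flashpack(
                model_id, fp_path, self._pipeline_class_for(model_id)
            )
            if pipeline is not None:
                return pipeline  # loaded from .flashpack files (local NVMe cache if enabled)
        return await self._load_standard(model_id, source)  # safetensors fallback
```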
gen_worker/torch_manager/utils/globals.py
@@ -0,0 +1,59 @@
+ from typing import Type, Any
+
+ from .base_types.architecture import Architecture
+ from .base_types.common import TorchDevice
+ from .device import get_torch_device
+
+
+ _available_torch_device: TorchDevice = get_torch_device()
+
+ # Model Memory Manager
+ _MODEL_MEMORY_MANAGER = None
+
+ # Model Downloader
+ _MODEL_DOWNLOADER = None
+
+
+ _ARCHITECTURES: dict[str, type[Architecture[Any]]] = {}
+ """
+ Global class containing all architecture definitions
+ """
+
+
+ def get_model_downloader():
+     """Get or create the global ModelManager instance"""
+     global _MODEL_DOWNLOADER
+     if _MODEL_DOWNLOADER is None:
+         from .model_downloader import ModelManager
+
+         _MODEL_DOWNLOADER = ModelManager()
+     return _MODEL_DOWNLOADER
+
+
+ def get_model_memory_manager():
+     global _MODEL_MEMORY_MANAGER
+     if _MODEL_MEMORY_MANAGER is None:
+         from ..manager import ModelMemoryManager
+
+         _MODEL_MEMORY_MANAGER = ModelMemoryManager()
+     return _MODEL_MEMORY_MANAGER
+
+
+ def update_architectures(architectures: dict[str, Type["Architecture"]]):
+     global _ARCHITECTURES
+     _ARCHITECTURES.update(architectures)
+
+
+ def get_architectures() -> dict[str, Type["Architecture"]]:
+     return _ARCHITECTURES
+
+
+ def get_available_torch_device():
+     global _available_torch_device
+     return _available_torch_device
+
+
+ def set_available_torch_device(device: TorchDevice):
+     print("Setting device", device)
+     global _available_torch_device
+     _available_torch_device = device
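A quick usage sketch for the lazy singletons above, assuming the package is installed and importable under `gen_worker.torch_manager.utils.globals`; the getters build their objects on first use and return the same instance afterwards, while the device is resolved once at import time.

```python
# Hedged sketch of how the module-level singletons are typically consumed.
from gen_worker.torch_manager.utils.globals import (
    get_available_torch_device,
    get_model_downloader,
    get_architectures,
)

device = get_available_torch_device()        # TorchDevice chosen by get_torch_device()
downloader = get_model_downloader()          # lazily constructs ModelManager, then caches it
assert downloader is get_model_downloader()  # repeat calls return the same object
print(device, list(get_architectures()))     # registry is empty until update_architectures() is called
```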
gen_worker/torch_manager/utils/load_models.py
@@ -0,0 +1,238 @@
+ from __future__ import annotations
+ import os
+ from pathlib import Path
+ import torch
+ import struct
+ import json
+ from typing import Type, Optional, Any
+ from safetensors.torch import load_file as safetensors_load_file
+ from spandrel import canonicalize_state_dict
+ from spandrel.__helpers.unpickler import (
+     RestrictedUnpickle,
+ )  # probably shouldn't import from private modules...
+
+ from .base_types.architecture import (
+     Architecture,
+     StateDict,
+     TorchDevice,
+     ComponentMetadata,
+ )
+
+ METADATA_HEADER_SIZE = 8
+
+
+ # TO DO: make this more efficient; we don't want to have to evaluate EVERY architecture
+ # for EVERY file. ALSO we need stop multiple architectures from claiming the same
+ # keys; i.e., if there are 5 architecture definitions for stable-diffusion-1 installed,
+ # then only the first one should get to claim those keys, otherwise it gets confusing
+ # on which model it should use
+ def from_file(
+     path: str | Path,
+     device: Optional[TorchDevice] = None,
+     registry: dict[str, Type[Architecture]] = None,
+ ) -> dict[str, Architecture]:
+     """
+     Loads a model from a file path. It detects the architecture, instantiates the
+     architecture, and loads the state dict into the PyTorch class.
+
+     Throws a `ValueError` if the file extension is not supported.
+     Returns an empty dictionary if no supported model architecture is found.
+     """
+     state_dict = load_state_dict_from_file(path, device=device)
+
+     metadata = read_safetensors_metadata(path)
+
+     return from_state_dict(state_dict, metadata, device, registry)
+
+
+ def from_state_dict(
+     state_dict: StateDict,
+     metadata: dict[str, Any] = {},
+     device: Optional[TorchDevice] = None,
+     registry: dict[str, Type[Architecture]] = None,
+ ) -> dict[str, Architecture]:
+     """
+     Load a model from the given state dict.
+
+     Returns an empty dictionary if no supported model architecture is found.
+     """
+     # Fetch class instances
+     components = components_from_state_dict(state_dict, metadata, registry)
+
+     # Load the state dict into the class instance, and move to device
+     for _arch_id, architecture in components.items():
+         try:
+             architecture.load(state_dict, device)
+         except Exception as e:
+             print(e)
+
+     return components
+
+
+ def components_from_state_dict(
+     state_dict: StateDict,
+     metadata: dict,
+     registry: Optional[dict[str, Type[Architecture]]] = None,
+ ) -> dict[str, Architecture]:
+     """
+     Detect all models present inside of a state dict; does not load the state-dict into
+     memory however; it only calls the Architecture's constructor to return a class instance.
+     """
+     components: dict[str, Architecture] = {}
+
+     if registry is None:
+         from .globals import _ARCHITECTURES
+
+         registry = _ARCHITECTURES
+
+     for arch_id, architecture in registry.items():  # Iterate through all architectures
+         try:
+             # print("Now in load model")
+             # print(metadata)
+             # print(architecture)
+             # print("Done above")
+
+             checkpoint_metadata = architecture.detect(
+                 state_dict=state_dict, metadata=metadata
+             )
+             # print(checkpoint_metadata)
+             # print("Done in load model")
+             # detect_signature = inspect.signature(architecture.detect)
+             # if 'state_dict' in detect_signature.parameters and 'metadata' in detect_signature.parameters:
+             #     checkpoint_metadata = architecture.detect(state_dict=state_dict, metadata=metadata)
+             # elif 'state_dict' in detect_signature.parameters:
+             #     checkpoint_metadata = architecture.detect(state_dict=state_dict)
+             # elif 'metadata' in detect_signature.parameters:
+             #     checkpoint_metadata = architecture.detect(metadata=metadata)
+             # else:
+             #     continue
+         except Exception:
+             checkpoint_metadata = None
+
+         if checkpoint_metadata is not None:
+             model = architecture(metadata=metadata)
+             components.update({arch_id: model})
+
+     return components
+
+
+ def load_state_dict_from_file(
+     path: str | Path, device: Optional[TorchDevice] = None
+ ) -> StateDict:
+     """
+     Load the state dict of a model from the given file path.
+
+     State dicts are typically only useful to pass them into the `load`
+     function of a specific architecture.
+
+     Throws a `ValueError` if the file extension is not supported.
+     """
+     extension = os.path.splitext(path)[1].lower()
+     if isinstance(device, str):
+         device = torch.device(device)  # make pyright type-checker happy
+
+     state_dict: StateDict
+     if extension == ".pt":
+         try:
+             state_dict = _load_torchscript(path, device)
+         except RuntimeError:
+             # If torchscript loading fails, try loading as a normal state dict
+             try:
+                 pth_state_dict = _load_pth(path, device)
+             except Exception:
+                 pth_state_dict = None
+
+             if pth_state_dict is None:
+                 # the file was likely a torchscript file, but failed to load
+                 # re-raise the original error, so the user knows what went wrong
+                 raise
+
+             state_dict = pth_state_dict
+
+     elif extension == ".pth" or extension == ".ckpt":
+         state_dict = _load_pth(path, device)
+     elif extension == ".safetensors":
+         state_dict = _load_safetensors(path, device)
+     else:
+         raise ValueError(
+             f"Unsupported model file extension {extension}. Please try a supported model type."
+         )
+
+     return canonicalize_state_dict(state_dict)
+
+
+ def _load_pth(path: str | Path, device: Optional[torch.device] = None) -> StateDict:
+     return torch.load(
+         f=path,
+         map_location=device,
+         pickle_module=RestrictedUnpickle,
+     )
+
+
+ def _load_torchscript(
+     path: str | Path, device: Optional[torch.device] = None
+ ) -> StateDict:
+     return torch.jit.load(path, map_location=device).state_dict()
+
+
+ def _load_safetensors(
+     path: str | Path, device: Optional[TorchDevice] = None
+ ) -> StateDict:
+     if device is not None:
+         if isinstance(device, torch.device):
+             device = str(device)
+         return safetensors_load_file(path, device=device)
+     else:
+         return safetensors_load_file(path)
+
+
+ def read_safetensors_metadata(file_path: str | Path) -> dict[str, Any]:
+     if not str(file_path).endswith(".safetensors"):
+         print(f"Error: File '{file_path}' is not a '.safetensors' file.")
+         return {}
+     if not os.path.isfile(file_path):
+         print(f"Error: File '{file_path}' not found.")
+         return {}
+
+     with open(file_path, "rb") as file:
+         header_size_bytes = file.read(METADATA_HEADER_SIZE)
+         header_size = struct.unpack("<Q", header_size_bytes)[0]
+         if header_size is None or header_size == 0:
+             return {}
+         header_bytes = file.read(header_size)
+         header = json.loads(header_bytes)
+
+     return header.get("__metadata__", {})
+
+
+ def find_component_models(
+     state_dict: StateDict,
+     metadata: Optional[dict] = None,
+     registry: dict[str, Type[Architecture]] = None,
+ ) -> dict[str, ComponentMetadata]:
+     """
+     Detect all models present inside of a state dict, and return a dict. The keys of
+     the dict are the architecture's unique identifier that can be instantiated using
+     this state-dict, and the value is the metadata of the corresponding architecture
+     if it were instantiated using this same state-dict + metadata.
+     """
+     components: dict[str, ComponentMetadata] = {}
+
+     if registry is None:
+         from .globals import _ARCHITECTURES
+
+         registry = _ARCHITECTURES
+
+     for arch_id, architecture in registry.items():  # Iterate through all architectures
+         try:
+             checkpoint_metadata = architecture.detect(
+                 state_dict=state_dict, metadata=metadata
+             )
+
+             if checkpoint_metadata is not None:
+                 # this will overwrite previous architectures with the same id
+                 components.update({arch_id: checkpoint_metadata})
+         except Exception as e:
+             print(f"Encountered error running architecture.detect for {arch_id}: {e}")
+
+     return components
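Finally, a short sketch of how load_models.py is typically driven. This is a hedged example: the checkpoint path is a placeholder, and nothing will be detected unless architecture definitions have been registered (the registry argument defaults to the global `_ARCHITECTURES` populated via `globals.update_architectures()`).

```python
# Hedged usage sketch for from_file / find_component_models; the checkpoint
# path below is hypothetical and only .pt/.pth/.ckpt/.safetensors are accepted.
from gen_worker.torch_manager.utils.load_models import (
    from_file,
    load_state_dict_from_file,
    find_component_models,
)

checkpoint = "/workspace/.cozy-creator/models/example.safetensors"  # hypothetical path

# Inspect which registered architectures claim this checkpoint (no weights loaded into modules).
state_dict = load_state_dict_from_file(checkpoint, device="cpu")
print(find_component_models(state_dict))

# Detect, instantiate, and load weights in one call; returns {} if nothing matches.
for arch_id, arch in from_file(checkpoint, device="cpu").items():
    print(arch_id, type(arch).__name__)
```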