gen-worker 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,340 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Local Model Cache - Copies models from NFS to local NVMe for faster loading.
|
|
3
|
+
|
|
4
|
+
Flow:
|
|
5
|
+
1. First job arrives with model_id
|
|
6
|
+
2. Copy that model from NFS → local (prioritized)
|
|
7
|
+
3. Load from local (fast!)
|
|
8
|
+
4. Background: Copy remaining deployment models
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
from .utils.local_cache import LocalModelCache
|
|
12
|
+
|
|
13
|
+
# In worker init
|
|
14
|
+
self.local_cache = LocalModelCache()
|
|
15
|
+
|
|
16
|
+
# Before loading a model
|
|
17
|
+
local_path = await self.local_cache.ensure_local(model_id, source, priority=True)
|
|
18
|
+
|
|
19
|
+
# Start background copying for other models
|
|
20
|
+
asyncio.create_task(self.local_cache.prefetch_models(other_model_ids, sources))
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import os
|
|
24
|
+
import shutil
|
|
25
|
+
import asyncio
|
|
26
|
+
import logging
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Optional, Dict, List, Set
|
|
29
|
+
import hashlib
|
|
30
|
+
import time
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)

# Where local copies are written: the container's own disk (intended to be
# NVMe), so model loads avoid the network filesystem.
LOCAL_CACHE_DIR = "/root/.local-model-cache"

# Network-filesystem (NFS) roots that models are copied *from*.
NFS_COZY_MODELS = "/workspace/.cozy-creator/models"  # Civitai / cozy-managed model dirs
NFS_HF_CACHE = "/workspace/.cache/huggingface/hub"  # Hugging Face hub-style cache root

# Name suffix that marks the FlashPack-converted sibling of a model entry.
FLASHPACK_SUFFIX = ".flashpack"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LocalModelCache:
    """
    Manages a local-disk cache of models copied from NFS for faster loading.

    Models are copied on demand by ``ensure_local``; ``prefetch_models`` can
    warm the cache in the background. Every copy runs while holding a single
    asyncio lock, so one instance performs at most one copy at a time.

    Copies are first written to a ``.partial-`` staging entry inside the cache
    directory and renamed into place only after they complete. An interrupted
    or failed copy therefore never appears as a cached model.
    """

    # Name prefix for in-progress (staging) copies inside the cache directory.
    _PARTIAL_PREFIX = ".partial-"

    def __init__(
        self,
        local_cache_dir: str = LOCAL_CACHE_DIR,
        nfs_cozy_dir: str = NFS_COZY_MODELS,
        nfs_hf_dir: str = NFS_HF_CACHE,
    ):
        """
        Args:
            local_cache_dir: Directory receiving local copies (created if missing).
            nfs_cozy_dir: NFS root searched for Civitai/cozy models.
            nfs_hf_dir: NFS root searched for ``models--{org}--{repo}`` folders.
        """
        self.local_cache_dir = Path(local_cache_dir)
        self.nfs_cozy_dir = Path(nfs_cozy_dir)
        self.nfs_hf_dir = Path(nfs_hf_dir)

        # Model IDs believed to be present locally / with a copy in flight.
        self.cached_models: Set[str] = set()
        self.copying_models: Set[str] = set()
        # Held for the entire duration of each copy (serializes copies).
        self._copy_lock = asyncio.Lock()

        self.local_cache_dir.mkdir(parents=True, exist_ok=True)

        # Scan for already cached models
        self._scan_existing_cache()

        logger.info(f"LocalModelCache initialized at {self.local_cache_dir}")
        logger.info(f"Already cached: {len(self.cached_models)} models")

    def _scan_existing_cache(self):
        """
        Populate ``cached_models`` from directories already in the local cache.

        Directory names are expected to look like ``{name}--{hash}``, optionally
        followed by the FlashPack suffix; the part before the last ``--`` is
        recorded. NOTE(review): this is a best-effort reconstruction. Civitai
        names have ``/`` replaced by ``-`` and HF folders look like
        ``models--org--repo``, so recorded IDs may not equal a caller's
        ``model_id``.

        Leftover staging entries from interrupted copies are deleted. This
        assumes no other process is copying into the same cache directory.
        """
        if not self.local_cache_dir.exists():
            return

        for item in self.local_cache_dir.iterdir():
            if item.name.startswith(self._PARTIAL_PREFIX):
                # Incomplete copy from an earlier run: never a valid model.
                self._remove_path(item)
                continue
            if not item.is_dir():
                continue
            name = item.name.removesuffix(FLASHPACK_SUFFIX)
            if "--" in name:
                self.cached_models.add(name.rsplit("--", 1)[0])

    async def ensure_local(
        self,
        model_id: str,
        source: str,
        priority: bool = False
    ) -> Optional[Path]:
        """
        Ensure a model is in the local cache, copying it from NFS if needed.

        Args:
            model_id: Model identifier.
            source: Source string from pipeline_defs. ``"hf:<repo_id>"``
                selects the HF cache root; anything else the Civitai/cozy root.
            priority: If True, this is the active job's model. An in-flight
                copy of it is awaited, and the copy is retried here if that
                copy did not produce the model. If False and a copy is already
                in flight, returns None immediately.

        Returns:
            Path to the local model (FlashPack dir, HF cache dir or safetensors
            file), or None if it is not (yet) available locally.
        """
        local_path = self._get_local_path(model_id, source)
        if local_path is not None and local_path.exists():
            logger.info(f"✓ Model {model_id} already in local cache")
            self.cached_models.add(model_id)
            return local_path

        if model_id in self.copying_models:
            if not priority:
                # Non-priority, just return None (will use NFS)
                return None

            logger.info(f"⏳ Waiting for {model_id} copy to complete...")
            while model_id in self.copying_models:
                await asyncio.sleep(0.5)

            # The other copy may have failed; only hand out a path that exists.
            local_path = self._get_local_path(model_id, source)
            if local_path is not None and local_path.exists():
                self.cached_models.add(model_id)
                return local_path

        return await self._copy_to_local(model_id, source, priority)

    async def _copy_to_local(
        self,
        model_id: str,
        source: str,
        priority: bool
    ) -> Optional[Path]:
        """
        Copy a model from NFS into the local cache.

        ``_copy_lock`` is held for the whole copy, so copies run one at a time.
        ``priority`` is accepted for interface symmetry and is not used here.

        Returns:
            The local path on success, otherwise None (the failure is logged).
        """
        async with self._copy_lock:
            # Double-check after acquiring lock: the previous holder may have
            # copied this very model while we were waiting.
            local_path = self._get_local_path(model_id, source)
            if local_path is not None and local_path.exists():
                self.cached_models.add(model_id)
                return local_path

            if model_id in self.copying_models:
                return None

            self.copying_models.add(model_id)

            try:
                nfs_path = self._find_nfs_path(model_id, source)
                if not nfs_path:
                    logger.error(f"❌ Model {model_id} not found on NFS")
                    return None

                local_dest = self._get_local_dest(model_id, source, nfs_path)

                logger.info(f"📦 Copying {model_id} to local cache...")
                logger.info(f" From: {nfs_path}")
                logger.info(f" To: {local_dest}")

                start_time = time.monotonic()

                # HF hub caches typically link snapshots/ -> blobs/ with relative
                # symlinks; keeping them as links copies each blob only once
                # instead of materializing it twice. (Verify for FlashPack dirs.)
                preserve_symlinks = source.startswith("hf:")
                await asyncio.to_thread(
                    self._copy_atomically, nfs_path, local_dest, preserve_symlinks
                )

                elapsed = time.monotonic() - start_time
                size_gb = self._get_size_gb(local_dest)
                speed = size_gb / elapsed if elapsed > 0 else 0

                logger.info(f"✅ Copied {model_id} in {elapsed:.1f}s ({size_gb:.2f}GB @ {speed:.1f}GB/s)")

                self.cached_models.add(model_id)
                return local_dest

            except Exception as e:
                logger.exception(f"❌ Failed to copy {model_id}: {e}")
                return None

            finally:
                self.copying_models.discard(model_id)

    def _copy_atomically(self, src: Path, dest: Path, preserve_symlinks: bool) -> None:
        """
        Copy ``src`` (file or directory) to ``dest`` through a staging path.

        Data is written to ``.partial-<name>`` in the cache directory and moved
        to ``dest`` with ``os.replace`` only once the copy finished. ``dest``
        therefore never exists half-written. The staging entry is removed if
        anything fails. This is blocking and meant to run in a worker thread.

        Args:
            src: File or directory on NFS.
            dest: Final location inside the local cache; must not exist yet.
            preserve_symlinks: Passed to ``shutil.copytree(symlinks=...)``.
        """
        staging = self.local_cache_dir / f"{self._PARTIAL_PREFIX}{dest.name}"
        # Clear leftovers from an earlier interrupted attempt.
        self._remove_path(staging)
        try:
            if src.is_dir():
                shutil.copytree(str(src), str(staging), symlinks=preserve_symlinks)
            else:
                shutil.copy2(str(src), str(staging))
            dest.parent.mkdir(parents=True, exist_ok=True)
            # Same filesystem as dest, so this rename is atomic.
            os.replace(staging, dest)
        except BaseException:
            self._remove_path(staging)
            raise

    @staticmethod
    def _remove_path(path: Path) -> None:
        """Best-effort delete of a file, symlink or directory tree; logs on failure."""
        try:
            if path.is_dir() and not path.is_symlink():
                shutil.rmtree(path)
            else:
                path.unlink(missing_ok=True)
        except OSError as e:
            logger.warning(f"Could not remove {path}: {e}")

    def _find_nfs_path(self, model_id: str, source: str) -> Optional[Path]:
        """Find model path on NFS, preferring FlashPack.

        ``hf:``-prefixed sources are looked up in the HF cache root; all other
        sources in the Civitai/cozy root.
        """
        if source.startswith("hf:"):
            return self._find_hf_nfs_path(source[3:])
        else:
            return self._find_civitai_nfs_path(model_id, source)

    def _find_hf_nfs_path(self, repo_id: str) -> Optional[Path]:
        """Find a HuggingFace model on NFS (``models--{org}--{repo}`` folder).

        Returns the FlashPack variant if present, else the regular folder,
        else None.
        """
        folder_name = f"models--{repo_id.replace('/', '--')}"

        # Check for FlashPack version first
        flashpack_path = self.nfs_hf_dir / (folder_name + FLASHPACK_SUFFIX)
        if flashpack_path.exists():
            return flashpack_path

        # Fall back to regular HF cache
        hf_path = self.nfs_hf_dir / folder_name
        if hf_path.exists():
            return hf_path

        return None

    def _find_civitai_nfs_path(self, model_id: str, source: str) -> Optional[Path]:
        """Find a Civitai model on NFS.

        Looks for ``{model_id with '/'->'-'}--*`` entries in the cozy root and
        returns, in order of preference: the FlashPack sibling of the original
        directory, the first ``*.safetensors`` file inside it, or the directory
        itself. Entries are sorted so the choice is deterministic.
        """
        safe_name = model_id.replace("/", "-")

        matches = sorted(self.nfs_cozy_dir.glob(f"{safe_name}--*"))
        if not matches:
            return None

        # The glob also matches the ``.flashpack`` sibling; prefer the original.
        originals = [p for p in matches if not p.name.endswith(FLASHPACK_SUFFIX)]
        if not originals:
            # Only a FlashPack variant exists; use it directly.
            return matches[0]

        original_dir = originals[0]

        # Check for FlashPack version first
        flashpack_path = original_dir.parent / (original_dir.name + FLASHPACK_SUFFIX)
        if flashpack_path.exists():
            return flashpack_path

        # Fall back to safetensors
        safetensors_files = sorted(original_dir.glob("*.safetensors"))
        if safetensors_files:
            return safetensors_files[0]

        return original_dir

    def _get_local_path(self, model_id: str, source: str) -> Optional[Path]:
        """Get the expected local path for a model (it may not exist yet).

        Returns None when the model cannot be found on NFS.
        """
        nfs_path = self._find_nfs_path(model_id, source)
        if not nfs_path:
            return None

        return self._get_local_dest(model_id, source, nfs_path)

    def _get_local_dest(self, model_id: str, source: str, nfs_path: Path) -> Path:
        """
        Map an NFS model path to its location in the local cache.

        Directories keep their own name at the cache root. Single files are
        nested under their parent directory's name. Different models can ship
        identically named files (e.g. two ``model.safetensors``), which would
        otherwise collide at the cache root.
        """
        if nfs_path.is_dir():
            return self.local_cache_dir / nfs_path.name
        return self.local_cache_dir / nfs_path.parent.name / nfs_path.name

    def _get_size_gb(self, path: Path) -> float:
        """Return the size of a file or directory tree in GiB.

        Symlinks inside a tree are skipped so linked content (e.g. an HF
        snapshot entry pointing at a blob) is not counted twice.
        """
        if path.is_file():
            return path.stat().st_size / (1024 ** 3)

        total = sum(
            f.stat().st_size
            for f in path.rglob('*')
            if f.is_file() and not f.is_symlink()
        )
        return total / (1024 ** 3)

    async def prefetch_models(
        self,
        model_ids: List[str],
        sources: Dict[str, str]
    ):
        """
        Background task to prefetch models to local cache.

        Failures for one model are logged and do not stop the remaining
        prefetches (a background task has no caller to report errors to).

        Args:
            model_ids: List of model IDs to prefetch
            sources: Dict mapping model_id → source string
        """
        logger.info(f"🔄 Starting background prefetch for {len(model_ids)} models")

        for model_id in model_ids:
            if model_id in self.cached_models:
                continue

            source = sources.get(model_id)
            if not source:
                logger.warning(f"No source found for {model_id}, skipping prefetch")
                continue

            try:
                await self.ensure_local(model_id, source, priority=False)
            except Exception:
                logger.exception(f"Prefetch of {model_id} failed")

            # Small delay between copies to avoid overwhelming I/O
            await asyncio.sleep(0.1)

        logger.info(f"✅ Background prefetch complete. Cached: {len(self.cached_models)} models")

    def get_local_path_if_cached(self, model_id: str, source: str) -> Optional[Path]:
        """
        Get local path only if model is already cached.
        Does not trigger a copy.

        Args:
            model_id: Model identifier
            source: Source string

        Returns:
            Local path if cached, None otherwise
        """
        if model_id not in self.cached_models:
            return None

        local_path = self._get_local_path(model_id, source)
        if local_path and local_path.exists():
            return local_path

        # Was in set but not on disk - remove from set
        self.cached_models.discard(model_id)
        return None

    def is_cached(self, model_id: str) -> bool:
        """Check whether the model ID is recorded in ``cached_models`` (no disk check)."""
        return model_id in self.cached_models

    def get_cache_stats(self) -> Dict:
        """Get cache statistics (count, IDs, on-disk size, in-flight copies).

        Symlinks are excluded from the size so linked content is counted once.
        """
        total_size = 0
        if self.local_cache_dir.exists():
            total_size = sum(
                f.stat().st_size
                for f in self.local_cache_dir.rglob('*')
                if f.is_file() and not f.is_symlink()
            )

        return {
            "cached_models": len(self.cached_models),
            "models": list(self.cached_models),
            "total_size_gb": total_size / (1024 ** 3),
            "cache_dir": str(self.local_cache_dir),
            "currently_copying": list(self.copying_models),
        }
|