gen-worker 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
@@ -0,0 +1,340 @@
1
+ """
2
+ Local Model Cache - Copies models from NFS to local NVMe for faster loading.
3
+
4
+ Flow:
5
+ 1. First job arrives with model_id
6
+ 2. Copy that model from NFS → local (prioritized)
7
+ 3. Load from local (fast!)
8
+ 4. Background: Copy remaining deployment models
9
+
10
+ Usage:
11
+ from .utils.local_cache import LocalModelCache
12
+
13
+ # In worker init
14
+ self.local_cache = LocalModelCache()
15
+
16
+ # Before loading a model
17
+ local_path = await self.local_cache.ensure_local(model_id, source, priority=True)
18
+
19
+ # Start background copying for other models
20
+ asyncio.create_task(self.local_cache.prefetch_models(other_model_ids, sources))
21
+ """
22
+
23
+ import os
24
+ import shutil
25
+ import asyncio
26
+ import logging
27
+ from pathlib import Path
28
+ from typing import Optional, Dict, List, Set
29
+ import hashlib
30
+ import time
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Local cache location (container disk - NVMe)
35
+ LOCAL_CACHE_DIR = "/root/.local-model-cache"
36
+
37
+ # NFS locations
38
+ NFS_COZY_MODELS = "/workspace/.cozy-creator/models"
39
+ NFS_HF_CACHE = "/workspace/.cache/huggingface/hub"
40
+
41
+ # FlashPack suffix
42
+ FLASHPACK_SUFFIX = ".flashpack"
43
+
44
+
45
+ class LocalModelCache:
46
+ """
47
+ Manages local NVMe cache of models for faster loading.
48
+
49
+ Models are copied from NFS to local disk on-demand, with the
50
+ currently requested model getting priority.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ local_cache_dir: str = LOCAL_CACHE_DIR,
56
+ nfs_cozy_dir: str = NFS_COZY_MODELS,
57
+ nfs_hf_dir: str = NFS_HF_CACHE,
58
+ ):
59
+ self.local_cache_dir = Path(local_cache_dir)
60
+ self.nfs_cozy_dir = Path(nfs_cozy_dir)
61
+ self.nfs_hf_dir = Path(nfs_hf_dir)
62
+
63
+ # Track what's cached and what's being copied
64
+ self.cached_models: Set[str] = set()
65
+ self.copying_models: Set[str] = set()
66
+ self._copy_lock = asyncio.Lock()
67
+
68
+ # Create local cache directory
69
+ self.local_cache_dir.mkdir(parents=True, exist_ok=True)
70
+
71
+ # Scan for already cached models
72
+ self._scan_existing_cache()
73
+
74
+ logger.info(f"LocalModelCache initialized at {self.local_cache_dir}")
75
+ logger.info(f"Already cached: {len(self.cached_models)} models")
76
+
77
+ def _scan_existing_cache(self):
78
+ """Scan local cache for already copied models"""
79
+ if not self.local_cache_dir.exists():
80
+ return
81
+
82
+ for item in self.local_cache_dir.iterdir():
83
+ if item.is_dir():
84
+ # Extract model_id from directory name
85
+ # Format: {model_id}--{hash} or {model_id}--{hash}.flashpack
86
+ name = item.name
87
+ if FLASHPACK_SUFFIX in name:
88
+ name = name.replace(FLASHPACK_SUFFIX, "")
89
+ if "--" in name:
90
+ model_id = name.rsplit("--", 1)[0]
91
+ self.cached_models.add(model_id)
92
+
93
+ async def ensure_local(
94
+ self,
95
+ model_id: str,
96
+ source: str,
97
+ priority: bool = False
98
+ ) -> Optional[Path]:
99
+ """
100
+ Ensure model is in local cache, copying from NFS if needed.
101
+
102
+ Args:
103
+ model_id: Model identifier
104
+ source: Source string from pipeline_defs
105
+ priority: If True, this is the active job's model (copy immediately)
106
+
107
+ Returns:
108
+ Path to local model (FlashPack dir or safetensors file), or None if failed
109
+ """
110
+ # Check if already cached
111
+ local_path = self._get_local_path(model_id, source)
112
+ if local_path and local_path.exists():
113
+ logger.info(f"✓ Model {model_id} already in local cache")
114
+ self.cached_models.add(model_id)
115
+ return local_path
116
+
117
+ # Check if currently being copied
118
+ if model_id in self.copying_models:
119
+ if priority:
120
+ # Wait for copy to complete
121
+ logger.info(f"⏳ Waiting for {model_id} copy to complete...")
122
+ while model_id in self.copying_models:
123
+ await asyncio.sleep(0.5)
124
+ return self._get_local_path(model_id, source)
125
+ else:
126
+ # Non-priority, just return None (will use NFS)
127
+ return None
128
+
129
+ # Need to copy
130
+ return await self._copy_to_local(model_id, source, priority)
131
+
132
+ async def _copy_to_local(
133
+ self,
134
+ model_id: str,
135
+ source: str,
136
+ priority: bool
137
+ ) -> Optional[Path]:
138
+ """Copy model from NFS to local cache"""
139
+
140
+ async with self._copy_lock:
141
+ # Double-check after acquiring lock
142
+ local_path = self._get_local_path(model_id, source)
143
+ if local_path and local_path.exists():
144
+ return local_path
145
+
146
+ if model_id in self.copying_models:
147
+ return None
148
+
149
+ self.copying_models.add(model_id)
150
+
151
+ try:
152
+ # Find source on NFS
153
+ nfs_path = self._find_nfs_path(model_id, source)
154
+ if not nfs_path:
155
+ logger.error(f"❌ Model {model_id} not found on NFS")
156
+ return None
157
+
158
+ # Determine local destination
159
+ local_dest = self._get_local_dest(model_id, source, nfs_path)
160
+
161
+ logger.info(f"📦 Copying {model_id} to local cache...")
162
+ logger.info(f" From: {nfs_path}")
163
+ logger.info(f" To: {local_dest}")
164
+
165
+ start_time = time.time()
166
+
167
+ # Copy the model
168
+ if nfs_path.is_dir():
169
+ await asyncio.to_thread(
170
+ shutil.copytree,
171
+ str(nfs_path),
172
+ str(local_dest),
173
+ dirs_exist_ok=True
174
+ )
175
+ else:
176
+ local_dest.parent.mkdir(parents=True, exist_ok=True)
177
+ await asyncio.to_thread(
178
+ shutil.copy2,
179
+ str(nfs_path),
180
+ str(local_dest)
181
+ )
182
+
183
+ elapsed = time.time() - start_time
184
+ size_gb = self._get_size_gb(local_dest)
185
+ speed = size_gb / elapsed if elapsed > 0 else 0
186
+
187
+ logger.info(f"✅ Copied {model_id} in {elapsed:.1f}s ({size_gb:.2f}GB @ {speed:.1f}GB/s)")
188
+
189
+ self.cached_models.add(model_id)
190
+ return local_dest
191
+
192
+ except Exception as e:
193
+ logger.error(f"❌ Failed to copy {model_id}: {e}")
194
+ return None
195
+
196
+ finally:
197
+ self.copying_models.discard(model_id)
198
+
199
+ def _find_nfs_path(self, model_id: str, source: str) -> Optional[Path]:
200
+ """Find model path on NFS, preferring FlashPack"""
201
+
202
+ if source.startswith("hf:"):
203
+ return self._find_hf_nfs_path(source[3:])
204
+ else:
205
+ return self._find_civitai_nfs_path(model_id, source)
206
+
207
+ def _find_hf_nfs_path(self, repo_id: str) -> Optional[Path]:
208
+ """Find HuggingFace model on NFS"""
209
+ folder_name = f"models--{repo_id.replace('/', '--')}"
210
+
211
+ # Check for FlashPack version first
212
+ flashpack_path = self.nfs_hf_dir / (folder_name + FLASHPACK_SUFFIX)
213
+ if flashpack_path.exists():
214
+ return flashpack_path
215
+
216
+ # Fall back to regular HF cache
217
+ hf_path = self.nfs_hf_dir / folder_name
218
+ if hf_path.exists():
219
+ return hf_path
220
+
221
+ return None
222
+
223
+ def _find_civitai_nfs_path(self, model_id: str, source: str) -> Optional[Path]:
224
+ """Find Civitai model on NFS"""
225
+ safe_name = model_id.replace("/", "-")
226
+
227
+ # Find the model directory
228
+ matching_dirs = list(self.nfs_cozy_dir.glob(f"{safe_name}--*"))
229
+ if not matching_dirs:
230
+ return None
231
+
232
+ original_dir = matching_dirs[0]
233
+
234
+ # Check for FlashPack version first
235
+ flashpack_path = original_dir.parent / (original_dir.name + FLASHPACK_SUFFIX)
236
+ if flashpack_path.exists():
237
+ return flashpack_path
238
+
239
+ # Fall back to safetensors
240
+ safetensors_files = list(original_dir.glob("*.safetensors"))
241
+ if safetensors_files:
242
+ return safetensors_files[0]
243
+
244
+ return original_dir
245
+
246
+ def _get_local_path(self, model_id: str, source: str) -> Optional[Path]:
247
+ """Get expected local path for a model"""
248
+ nfs_path = self._find_nfs_path(model_id, source)
249
+ if not nfs_path:
250
+ return None
251
+
252
+ return self._get_local_dest(model_id, source, nfs_path)
253
+
254
+ def _get_local_dest(self, model_id: str, source: str, nfs_path: Path) -> Path:
255
+ """Get local destination path matching NFS structure"""
256
+ # Keep the same directory/file name, just change base path
257
+ return self.local_cache_dir / nfs_path.name
258
+
259
+ def _get_size_gb(self, path: Path) -> float:
260
+ """Get size of path in GB"""
261
+ if path.is_file():
262
+ return path.stat().st_size / (1024 ** 3)
263
+
264
+ total = sum(f.stat().st_size for f in path.rglob('*') if f.is_file())
265
+ return total / (1024 ** 3)
266
+
267
+ async def prefetch_models(
268
+ self,
269
+ model_ids: List[str],
270
+ sources: Dict[str, str]
271
+ ):
272
+ """
273
+ Background task to prefetch models to local cache.
274
+
275
+ Args:
276
+ model_ids: List of model IDs to prefetch
277
+ sources: Dict mapping model_id → source string
278
+ """
279
+ logger.info(f"🔄 Starting background prefetch for {len(model_ids)} models")
280
+
281
+ for model_id in model_ids:
282
+ if model_id in self.cached_models:
283
+ continue
284
+
285
+ source = sources.get(model_id)
286
+ if not source:
287
+ logger.warning(f"No source found for {model_id}, skipping prefetch")
288
+ continue
289
+
290
+ await self.ensure_local(model_id, source, priority=False)
291
+
292
+ # Small delay between copies to avoid overwhelming I/O
293
+ await asyncio.sleep(0.1)
294
+
295
+ logger.info(f"✅ Background prefetch complete. Cached: {len(self.cached_models)} models")
296
+
297
+ def get_local_path_if_cached(self, model_id: str, source: str) -> Optional[Path]:
298
+ """
299
+ Get local path only if model is already cached.
300
+ Does not trigger a copy.
301
+
302
+ Args:
303
+ model_id: Model identifier
304
+ source: Source string
305
+
306
+ Returns:
307
+ Local path if cached, None otherwise
308
+ """
309
+ if model_id not in self.cached_models:
310
+ return None
311
+
312
+ local_path = self._get_local_path(model_id, source)
313
+ if local_path and local_path.exists():
314
+ return local_path
315
+
316
+ # Was in set but not on disk - remove from set
317
+ self.cached_models.discard(model_id)
318
+ return None
319
+
320
+ def is_cached(self, model_id: str) -> bool:
321
+ """Check if model is in local cache"""
322
+ return model_id in self.cached_models
323
+
324
+ def get_cache_stats(self) -> Dict:
325
+ """Get cache statistics"""
326
+ total_size = 0
327
+ if self.local_cache_dir.exists():
328
+ total_size = sum(
329
+ f.stat().st_size
330
+ for f in self.local_cache_dir.rglob('*')
331
+ if f.is_file()
332
+ )
333
+
334
+ return {
335
+ "cached_models": len(self.cached_models),
336
+ "models": list(self.cached_models),
337
+ "total_size_gb": total_size / (1024 ** 3),
338
+ "cache_dir": str(self.local_cache_dir),
339
+ "currently_copying": list(self.copying_models),
340
+ }