gen-worker 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,2059 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import gc
|
|
3
|
+
import logging
|
|
4
|
+
import importlib
|
|
5
|
+
from enum import Enum
|
|
6
|
+
import traceback
|
|
7
|
+
from typing import Optional, Any, Dict, List, Tuple, Union, Type, Set
|
|
8
|
+
import psutil
|
|
9
|
+
from collections import OrderedDict
|
|
10
|
+
import time
|
|
11
|
+
import sys
|
|
12
|
+
import threading
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import diffusers
|
|
16
|
+
import numpy as np
|
|
17
|
+
import asyncio
|
|
18
|
+
|
|
19
|
+
from diffusers import (
|
|
20
|
+
DiffusionPipeline,
|
|
21
|
+
FluxInpaintPipeline,
|
|
22
|
+
FluxPipeline,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from diffusers.loaders import FromSingleFileMixin
|
|
26
|
+
from huggingface_hub.constants import HF_HUB_CACHE
|
|
27
|
+
from huggingface_hub.file_download import repo_folder_name
|
|
28
|
+
from huggingface_hub.utils import EntryNotFoundError
|
|
29
|
+
from accelerate import utils as accelerate_utils
|
|
30
|
+
|
|
31
|
+
from .utils.config import get_config, get_environment
|
|
32
|
+
from .utils.globals import (
|
|
33
|
+
# get_hf_model_manager,
|
|
34
|
+
get_architectures,
|
|
35
|
+
get_available_torch_device,
|
|
36
|
+
set_available_torch_device,
|
|
37
|
+
get_model_downloader,
|
|
38
|
+
)
|
|
39
|
+
from .utils.model_downloader import ModelSource
|
|
40
|
+
from .utils.load_models import load_state_dict_from_file
|
|
41
|
+
from .utils.base_types.config import PipelineConfig
|
|
42
|
+
from .utils import diffusers_fix
|
|
43
|
+
|
|
44
|
+
from ..model_interface import ModelManagementInterface, DownloaderType
|
|
45
|
+
|
|
46
|
+
from .utils.flashpack_loader import FlashPackLoader
|
|
47
|
+
|
|
48
|
+
# Configure logging
|
|
49
|
+
logger = logging.getLogger(__name__)
|
|
50
|
+
logging.basicConfig(
|
|
51
|
+
level=logging.INFO,
|
|
52
|
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
53
|
+
stream=sys.stdout,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
# Memory-budget constants, all in GB.

# Headroom subtracted from currently available VRAM before deciding a model fits on the GPU.
VRAM_SAFETY_MARGIN_GB: float = 3.5
# Subtracted from total device VRAM to obtain the manager's `max_vram` budget.
DEFAULT_MAX_VRAM_BUFFER_GB: float = 2.0
# RAM headroom; NOTE(review): not referenced in the visible part of this module.
RAM_SAFETY_MARGIN_GB: float = 10.0
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Component names listed per diffusers pipeline class (keys are diffusers class names).
# NOTE(review): not referenced in the visible part of this module.
MODEL_COMPONENTS = {
    "FluxPipeline": [
        "vae", "transformer", "text_encoder", "text_encoder_2",
        "scheduler", "tokenizer", "tokenizer_2",
    ],
    "StableDiffusionXLPipeline": [
        "vae", "unet", "text_encoder", "text_encoder_2",
        "scheduler", "tokenizer", "tokenizer_2",
    ],
    "StableDiffusionPipeline": [
        "vae", "unet", "text_encoder", "scheduler", "tokenizer",
    ],
}
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_pipeline_class(
    class_name: Union[str, Tuple[str, str], List[str]],
) -> Tuple[Type[DiffusionPipeline], Optional[str]]:
    """Get the appropriate pipeline class based on class_name configuration.

    Learn about custom-pipelines here: https://huggingface.co/docs/diffusers/v0.6.0/en/using-diffusers/custom_pipelines
    Note that from_single_file does not support custom pipelines, only from_pretrained does.

    Args:
        class_name: Either a string naming a class exported by ``diffusers``, or a
            ``(package, class)`` pair naming a class in an importable package. The pair
            may be a tuple or a two-element list (the shape JSON/YAML configs produce).

    Returns:
        Tuple of (pipeline class to use, ``custom_pipeline`` value to pass as a kwarg
        to the pipeline, or ``None``). A string that is not a ``diffusers`` export is
        treated as a custom pipeline name: ``(DiffusionPipeline, class_name)``.

    Raises:
        TypeError: ``class_name`` names a ``diffusers`` attribute that is not a
            ``DiffusionPipeline`` subclass.
        ValueError: a ``(package, class)`` sequence does not have exactly two items.
        ImportError / AttributeError: a ``(package, class)`` pair cannot be resolved.
    """
    # class_name is in the form of [package, class]. Config files parsed from JSON/YAML
    # yield a list here, so accept lists as well as tuples.
    if isinstance(class_name, (tuple, list)):
        # Load from custom package
        package, cls = class_name
        module = importlib.import_module(package)
        return (getattr(module, cls), None)

    # Try loading class_name as a diffusers class
    try:
        pipeline_class = getattr(importlib.import_module("diffusers"), class_name)
    except (ImportError, AttributeError):
        # Assume the class name is the name of a custom pipeline
        logger.info(f"'{class_name}' is not a diffusers class; treating it as a custom pipeline")
        return (DiffusionPipeline, class_name)

    # isinstance guard: issubclass() itself raises an unhelpful TypeError for non-classes
    # (e.g. a diffusers function of the same name).
    if not (isinstance(pipeline_class, type) and issubclass(pipeline_class, DiffusionPipeline)):
        raise TypeError(f"{class_name} does not inherit from DiffusionPipeline")
    return (pipeline_class, None)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class LRUCache:
    """
    Recency tracker for models held on the GPU and on the CPU.

    Keeps two independent ``OrderedDict`` maps, ``gpu_cache`` and ``cpu_cache``,
    each mapping ``model_id -> time.time()`` of the last recorded access. Key order
    is recency order: the first key is the least recently used.
    """

    def __init__(self):
        """Start with both recency maps empty."""
        self.gpu_cache: OrderedDict = OrderedDict()  # model_id -> last_used_timestamp
        self.cpu_cache: OrderedDict = OrderedDict()  # model_id -> last_used_timestamp

    def _cache_for(self, cache_type: str) -> OrderedDict:
        """Select the GPU map for ``"gpu"``; every other value selects the CPU map."""
        return self.gpu_cache if cache_type == "gpu" else self.cpu_cache

    def access(self, model_id: str, cache_type: str = "gpu") -> None:
        """
        Mark a model as the most recently used entry of one map.

        Args:
            model_id: Unique identifier for the model
            cache_type: ``"gpu"`` for the GPU map; any other value means the CPU map
        """
        recency = self._cache_for(cache_type)
        recency[model_id] = time.time()
        # Existing keys keep their slot on assignment, so push them to the MRU end.
        recency.move_to_end(model_id)

    def remove(self, model_id: str, cache_type: str = "gpu") -> None:
        """
        Stop tracking a model; unknown ids are ignored.

        Args:
            model_id: Unique identifier for the model to forget
            cache_type: ``"gpu"`` for the GPU map; any other value means the CPU map
        """
        self._cache_for(cache_type).pop(model_id, None)

    def get_lru_models(
        self, cache_type: str = "gpu", count: Optional[int] = None
    ) -> List[str]:
        """
        List tracked model ids from least to most recently used.

        Args:
            cache_type: ``"gpu"`` for the GPU map; any other value means the CPU map
            count: If given, only the first ``count`` ids (ordinary slice semantics)

        Returns:
            Model ids, least recently used first
        """
        ordered_ids = list(self._cache_for(cache_type))
        if count is None:
            return ordered_ids
        return ordered_ids[:count]
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class DefaultModelManager:
|
|
181
|
+
"""
|
|
182
|
+
Manages loading, unloading and memory allocation of machine learning models.
|
|
183
|
+
|
|
184
|
+
Handles dynamic movement of models between GPU and CPU memory based on
|
|
185
|
+
available resources and usage patterns. Implements optimization strategies
|
|
186
|
+
for efficient memory usage and model performance.
|
|
187
|
+
"""
|
|
188
|
+
|
|
189
|
+
def __init__(self):
    """Create a manager with no resident models and memory budgets measured from this host."""
    # ---- model registries -------------------------------------------------
    self.current_model: Optional[str] = None
    self.base_pipelines: Dict[str, DiffusionPipeline] = {}
    self.loaded_models: Dict[str, DiffusionPipeline] = {}  # pipelines placed on the GPU
    self.cpu_models: Dict[str, DiffusionPipeline] = {}  # CPU-offloaded pipelines
    self.model_sizes: Dict[str, float] = {}  # estimated size per model, GB
    self.model_types: Dict[str, torch.dtype] = {}
    self.loaded_model: Optional[DiffusionPipeline] = None

    self.pipeline_locks: Dict[str, threading.Lock] = {}

    # ---- memory budget (GB) ----------------------------------------------
    self.vram_usage: float = 0
    self.ram_usage: float = 0
    bytes_per_gb = 1024**3
    self.max_vram: float = self._get_total_vram() - DEFAULT_MAX_VRAM_BUFFER_GB
    self.system_ram: float = psutil.virtual_memory().total / bytes_per_gb

    # ---- state flags ------------------------------------------------------
    self.is_in_device: bool = False
    self.is_startup_load: bool = False
    self.environment: str = os.getenv("ENVIRONMENT", "dev")

    # ---- collaborators and caches ----------------------------------------
    self.model_downloader = get_model_downloader()
    self.cache_dir = HF_HUB_CACHE
    self.lru_cache = LRUCache()
    self.flashpack_loader = FlashPackLoader()

    # Serving scope; replaced by process_supported_models_config().
    self.allowed_model_ids: Optional[Set[str]] = None
    self.supports_all_known_models: bool = False  # explicit allow-list until told otherwise
    # Serializes on-demand model downloads performed by load().
    self._download_lock = asyncio.Lock()

    set_available_torch_device("cuda" if torch.cuda.is_available() else "cpu")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
async def process_supported_models_config(
    self,
    supported_model_ids: List[str],
    downloader_instance: Optional[DownloaderType] = None,
) -> None:
    """
    Apply the scheduler-provided list of models this worker may serve.

    - ``["*"]`` (exactly one wildcard entry): every model defined in the config may be
      loaded on demand; ``allowed_model_ids`` becomes ``None``.
    - ``[]``: no dynamically loaded models; ``allowed_model_ids`` becomes an empty set.
    - anything else: only the listed ids are allowed.

    Args:
        supported_model_ids: Model ids sent by the scheduler
        downloader_instance: Unused; ``self.model_downloader`` is used instead
    """
    logger.info(f"DMM: Processing scheduler-provided supported_model_ids: {supported_model_ids}")

    if len(supported_model_ids) == 1 and supported_model_ids[0] == "*":  # "*" means all
        self.supports_all_known_models = True
        self.allowed_model_ids = None
        logger.info("DMM: Configured to support ALL known models (on-demand) due to '*' from scheduler.")
        return

    self.supports_all_known_models = False
    if supported_model_ids:
        self.allowed_model_ids = set(supported_model_ids)
        logger.info(f"DMM: Configured to support specific models: {self.allowed_model_ids}")
    else:
        self.allowed_model_ids = set()
        logger.info("DMM: Configured with an empty supported model list from scheduler (supports no dynamic models).")
|
|
243
|
+
|
|
244
|
+
async def load_model_into_vram(self, model_id: str) -> bool:
    """
    Load ``model_id`` via ``load()`` and make it the manager's current model.

    Returns:
        True if ``load()`` produced a pipeline, False otherwise.
    """
    logger.info(f"DefaultModelManager: Request to load '{model_id}' into VRAM.")
    pipeline = await self.load(model_id)
    if not pipeline:
        return False

    self.base_pipelines[model_id] = pipeline
    self.loaded_model = pipeline
    self.current_model = model_id
    # Kept for backward compatibility (will be removed in the future).
    # NOTE(review): load() may have placed the model in cpu_models (CPU offload); registering
    # it here too makes get_vram_loaded_models() report it as VRAM-resident. Confirm intent.
    self.loaded_models[model_id] = pipeline
    return True
|
|
254
|
+
|
|
255
|
+
async def get_active_pipeline(self, model_id: str) -> Optional[Any]:
    """
    Return the pipeline a worker task should run for ``model_id``.

    The model is first made resident with ``load()``. The instance handed back comes
    from ``get_pipeline_for_task``, which in this version returns the shared base
    pipeline rather than a per-task copy.

    Returns:
        The pipeline, or None if loading failed or no pipeline is registered afterwards.
    """
    logger.info(f"DefaultModelManager: Request to get active pipeline for '{model_id}'.")

    if not await self.load(model_id):
        logger.error(f"Failed to load base pipeline for {model_id}")
        return None

    task_pipeline = self.get_pipeline_for_task(model_id, "worker_task")
    if task_pipeline:
        return task_pipeline

    logger.error(f"Failed to create thread-safe pipeline for {model_id}")
    return None
|
|
269
|
+
|
|
270
|
+
def get_vram_loaded_models(self) -> List[str]:
    """
    Return the ids registered in ``loaded_models``, in insertion order.

    ``loaded_models`` holds the pipelines placed on the GPU. ``load_model_into_vram``
    also registers models here for backward compatibility.
    """
    # Plain list(...): calling a subscripted builtin (list[str](...)) needs Python 3.9+,
    # whereas this module otherwise sticks to typing.List/Dict.
    return list(self.loaded_models)
|
|
273
|
+
|
|
274
|
+
# def get_pipeline_for_task(self, model_id: str, run_id: str = None) -> Optional[DiffusionPipeline]:
|
|
275
|
+
# """
|
|
276
|
+
# Create a task-specific pipeline instance with fresh scheduler state.
|
|
277
|
+
# This prevents thread safety issues in concurrent executions.
|
|
278
|
+
|
|
279
|
+
# This is the KEY FIX for the IndexError: index 29 is out of bounds issue.
|
|
280
|
+
# """
|
|
281
|
+
# try:
|
|
282
|
+
# # Get the base pipeline (shared components)
|
|
283
|
+
# base_pipeline = self.get_base_pipeline(model_id)
|
|
284
|
+
# if not base_pipeline:
|
|
285
|
+
# logger.error(f"No base pipeline found for model {model_id}")
|
|
286
|
+
# return None
|
|
287
|
+
|
|
288
|
+
# # Create a fresh scheduler for this task to prevent state corruption
|
|
289
|
+
# fresh_scheduler = base_pipeline.scheduler.from_config(
|
|
290
|
+
# base_pipeline.scheduler.config
|
|
291
|
+
# )
|
|
292
|
+
|
|
293
|
+
# # Create task-specific pipeline with shared components but fresh scheduler
|
|
294
|
+
# task_pipeline = base_pipeline.__class__.from_pipe(
|
|
295
|
+
# base_pipeline,
|
|
296
|
+
# scheduler=fresh_scheduler
|
|
297
|
+
# )
|
|
298
|
+
|
|
299
|
+
# logger.info(f"[ModelManager] Created thread-safe pipeline for task {run_id} model {model_id}")
|
|
300
|
+
# return task_pipeline
|
|
301
|
+
|
|
302
|
+
# except Exception as e:
|
|
303
|
+
# logger.error(f"Error creating task-specific pipeline for {model_id}: {e}")
|
|
304
|
+
# return None
|
|
305
|
+
|
|
306
|
+
def get_pipeline_for_task(self, model_id: str, run_id: Optional[str] = None) -> Optional[DiffusionPipeline]:
    """
    Return the pipeline instance a task should use for ``model_id``.

    This returns the object from ``get_base_pipeline`` itself. No per-task copy and no
    fresh scheduler are created. Tasks running the same model concurrently therefore
    share one pipeline, including its scheduler state.
    NOTE(review): confirm callers never run one model concurrently, or restore per-task
    scheduler isolation (e.g. ``from_pipe`` with a fresh scheduler).

    Args:
        model_id: Model identifier
        run_id: Identifier of the calling task; currently unused

    Returns:
        The base pipeline, or None (logged as an error) when none is registered.
    """
    base_pipeline = self.get_base_pipeline(model_id)
    if not base_pipeline:
        logger.error(f"No base pipeline found for model {model_id}")
        return None

    logger.info(f"[ModelManager] Returning base pipeline for model {model_id}")
    return base_pipeline
|
|
319
|
+
|
|
320
|
+
def get_base_pipeline(self, model_id: str) -> Optional[DiffusionPipeline]:
    """
    Find the registered pipeline object for ``model_id``.

    Registries are searched in priority order: ``base_pipelines``, then
    ``loaded_models`` (kept for backward compatibility), then ``cpu_models``.
    The first hit wins.

    Returns:
        The pipeline, or None when ``model_id`` is in none of the registries.
    """
    for registry in (self.base_pipelines, self.loaded_models, self.cpu_models):
        if model_id in registry:
            return registry[model_id]
    return None
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
# If the scheduler is not set in `pipeline_defs`, then we'll rely on diffusers to pick a default
# scheduler.
def _setup_scheduler(self, pipeline: DiffusionPipeline, model_id: str) -> None:
    """
    Replace ``pipeline.scheduler`` with the scheduler configured for ``model_id``.

    The scheduler entry is ``components.scheduler`` of the model's definition in
    ``get_config().pipeline_defs``. When it names a ``class_name`` exported by
    ``diffusers``, a new scheduler is built with ``from_config``. Its inputs are the
    pipeline's current scheduler config plus the entry's ``kwargs``, minus keys that
    are not scheduler settings (``do_classifier_free_guidance``, ``guidance_scale``;
    matched case-insensitively).

    Model definitions and their nested sections may be mappings or attribute-style
    config objects; ``_load_model_by_source`` accepts both, so this does too. A scheduler
    component given as a bare source string names no class and is ignored.

    Errors while building the scheduler are logged; the pipeline then keeps its
    existing scheduler.
    """
    def _field(obj: Any, key: str) -> Any:
        # Read `key` from a mapping-like object (has .get) or an attribute-style object.
        if hasattr(obj, "get"):
            return obj.get(key)
        return getattr(obj, key, None)

    config = get_config()
    model_config = config.pipeline_defs.get(model_id)
    if not model_config:
        logger.error(f"Model {model_id} not found in configuration")
        return

    components = _field(model_config, "components")
    scheduler_config = _field(components, "scheduler") if components else None
    # Previously a string entry here crashed on `.get` outside the try below.
    if not scheduler_config or isinstance(scheduler_config, str):
        return

    scheduler_class = _field(scheduler_config, "class_name")
    if not scheduler_class:
        logger.warning(f"No scheduler class_name configured for {model_id}; keeping the default scheduler")
        return
    scheduler_kwargs = _field(scheduler_config, "kwargs") or {}

    # Pipeline-call arguments that are sometimes put in scheduler kwargs by mistake.
    # Stored lowercase so they can be compared directly against k.lower().
    INVALID_SCHEDULER_CONFIG_KEYS = {"do_classifier_free_guidance", "guidance_scale"}

    try:
        # Only pass kwargs that are valid for scheduler initialization
        filtered_kwargs = {
            k: v for k, v in scheduler_kwargs.items()
            if k.lower() not in INVALID_SCHEDULER_CONFIG_KEYS
        }

        removed = set(scheduler_kwargs.keys()) - set(filtered_kwargs.keys())
        if removed:
            logger.warning(f"Removed invalid scheduler config keys for {model_id}: {removed}")

        pipeline.scheduler = getattr(diffusers, scheduler_class).from_config(
            pipeline.scheduler.config, **filtered_kwargs
        )
        logger.info(f"Successfully set scheduler to {scheduler_class}")
    except Exception as e:
        logger.error(f"Error setting scheduler for {model_id}: {e}")
|
|
382
|
+
|
|
383
|
+
def _get_memory_info(self) -> Tuple[float, float]:
|
|
384
|
+
"""
|
|
385
|
+
Get current memory availability.
|
|
386
|
+
|
|
387
|
+
Returns:
|
|
388
|
+
Tuple containing:
|
|
389
|
+
- Available RAM in GB
|
|
390
|
+
- Available VRAM in GB
|
|
391
|
+
"""
|
|
392
|
+
ram = self._get_available_ram()
|
|
393
|
+
vram = self._get_available_vram()
|
|
394
|
+
return ram, vram
|
|
395
|
+
|
|
396
|
+
def _get_available_ram(self) -> float:
    """Return the system RAM currently available (psutil's ``available`` figure), in GB."""
    bytes_per_gb = 1024**3
    memory = psutil.virtual_memory()
    return memory.available / bytes_per_gb
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def _get_total_vram(self) -> float:
|
|
407
|
+
"""
|
|
408
|
+
Get total VRAM available on the system.
|
|
409
|
+
|
|
410
|
+
Returns:
|
|
411
|
+
Total VRAM in GB, or 0 if no CUDA device is available
|
|
412
|
+
"""
|
|
413
|
+
if torch.cuda.is_available():
|
|
414
|
+
return torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
|
415
|
+
return 0
|
|
416
|
+
|
|
417
|
+
def _get_available_vram(self) -> float:
|
|
418
|
+
"""
|
|
419
|
+
Get currently available VRAM.
|
|
420
|
+
|
|
421
|
+
Returns:
|
|
422
|
+
Available VRAM in GB
|
|
423
|
+
"""
|
|
424
|
+
if torch.cuda.is_available():
|
|
425
|
+
available_vram_gb = (
|
|
426
|
+
torch.cuda.get_device_properties(0).total_memory
|
|
427
|
+
- torch.cuda.memory_allocated()
|
|
428
|
+
) / (1024**3)
|
|
429
|
+
logger.debug(f"Available VRAM: {available_vram_gb:.2f} GB")
|
|
430
|
+
return available_vram_gb
|
|
431
|
+
return 100
|
|
432
|
+
|
|
433
|
+
def _need_optimization(self, model_size: float) -> bool:
    """
    Decide whether the CPU-offload path is viable for a model.

    ``_load_new_model`` consults this once the model cannot be placed directly on the
    GPU; a True result sends it down the CPU-offload (optimization) path.

    Args:
        model_size: Estimated size of the model in GB

    Returns:
        True if ``model_size`` is smaller than total system RAM (``self.system_ram``),
        False otherwise.
    """
    # Diagnostics go through the module logger instead of bare print() to stdout.
    logger.debug(f"model_size: {model_size} GB")
    logger.debug(f"system_ram: {self.system_ram} GB")
    return model_size < self.system_ram
|
|
446
|
+
|
|
447
|
+
def _can_fit_gpu(self, model_size: float) -> bool:
    """
    Check whether a model fits in the GPU memory that is free right now.

    Usable memory is ``_get_available_vram()`` minus ``VRAM_SAFETY_MARGIN_GB`` headroom.

    Args:
        model_size: Estimated size of the model in GB

    Returns:
        True if ``model_size`` is at most the usable memory, False otherwise.
    """
    # Query VRAM once so the logged figure and the decision use the same measurement
    # (the original called _get_available_vram() twice).
    available_inference_memory = self._get_available_vram() - VRAM_SAFETY_MARGIN_GB
    logger.debug(f"model_size: {model_size} GB")
    logger.debug(f"available_inference_memory: {available_inference_memory} GB")
    return model_size <= available_inference_memory
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
async def _ensure_model_files_on_disk(self, model_id: str, model_config_dict: Dict[str, Any]) -> bool:
    """
    Make sure a model's files (and its components' files) are on local disk.

    Called by ``load()`` before the model is loaded from disk; this is the on-demand
    download point. All work happens inside one ``self.model_downloader`` context.

    1. Main files (``model_config_dict["source"]``): checked with ``is_downloaded``,
       downloaded if missing, then checked again.
    2. Components (``model_config_dict["components"]``, a dict whose values are either
       a source string or a dict with a ``"source"`` key): each is downloaded under the
       id ``"<model_id>::<component>"``. Components are not verified afterwards; this
       relies on ``download_model`` being idempotent.

    Returns:
        True when the main files are verified and every component download finished
        without raising; False on a missing source, failed verification, or any
        download exception (logged).
    """
    logger.info(f"DMM._ensure_model_files_on_disk: Checking/downloading '{model_id}'...")

    async with self.model_downloader as downloader:  # Use context manager for session
        main_source_str = model_config_dict.get("source")
        if not main_source_str:
            logger.error(f"DMM: No source found for model '{model_id}' in its config. Cannot download.")
            return False  # Cannot proceed without a source

        is_main_downloaded, _ = await downloader.is_downloaded(model_id, model_config=model_config_dict)
        if is_main_downloaded:
            logger.info(f"DMM: Main files for '{model_id}' already on disk.")
        else:
            logger.info(f"DMM: Main files for '{model_id}' (source: {main_source_str}) not on disk. Downloading...")
            try:
                await downloader.download_model(model_id, ModelSource(main_source_str))
                is_main_downloaded, _ = await downloader.is_downloaded(model_id, model_config=model_config_dict)  # Verify
            except Exception as e:
                logger.exception(f"DMM: Exception during download of main files for '{model_id}': {e}")
                return False
            if not is_main_downloaded:
                logger.error(f"DMM: Download verification FAILED for main files of '{model_id}'.")
                return False
            logger.info(f"DMM: Main files for '{model_id}' downloaded and verified.")

        components = model_config_dict.get("components")
        if not (components and isinstance(components, dict)):
            return True

        for comp_name, comp_details in components.items():
            if isinstance(comp_details, dict):
                comp_source_str = comp_details.get("source")
            elif isinstance(comp_details, str):
                comp_source_str = comp_details
            else:
                comp_source_str = None
            if not comp_source_str:
                continue

            comp_dl_id = f"{model_id}::{comp_name}"
            # TODO: Need a robust way to check if component is downloaded via downloader.is_downloaded
            logger.info(f"DMM: Ensuring component '{comp_name}' for '{model_id}' (source: {comp_source_str}) is downloaded...")
            try:
                await downloader.download_model(comp_dl_id, ModelSource(comp_source_str))
            except Exception as e:
                logger.exception(f"DMM: Exception downloading component '{comp_name}' for '{model_id}': {e}")
                return False  # Stop at the first failing component
            logger.info(f"DMM: Component '{comp_name}' for '{model_id}' download attempt finished.")

    return True
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
async def load(
    self, model_id: str, gpu: Optional[int] = None, pipe_type: Optional[str] = None
) -> Optional[DiffusionPipeline]:
    """
    Load a model into memory, handling placement and optimization.

    Steps:
    - Refuse models this worker instance is not configured to serve.
    - Return the pipeline immediately if it is already resident (GPU or CPU-offloaded).
    - Otherwise, while holding ``self._download_lock``, ensure its files are on disk
      and load it with ``_load_new_model``.

    The lock is held across both the download and the load, and residency is
    re-checked after acquiring it. Two concurrent requests for the same model would
    otherwise both pass the first residency check and load the model twice, racing
    on VRAM accounting.

    Args:
        model_id: Identifier for the model to load
        gpu: Optional GPU device number
        pipe_type: Optional pipeline type specification

    Returns:
        Loaded pipeline or None if loading failed
    """
    logger.info(f"Processing model {model_id}")

    if not self.supports_all_known_models:
        if self.allowed_model_ids is None:  # Should have been set if not supports_all
            logger.error(f"DMM.load: Inconsistent state - worker not configured for all models, but allowed_model_ids is None. Denying load for '{model_id}'.")
            return None
        if model_id not in self.allowed_model_ids:
            logger.error(f"DMM.load: Model '{model_id}' is NOT in the allowed list for this worker instance. Allowed: {self.allowed_model_ids}. Denying load.")
            return None

    # If self.supports_all_known_models is True, any model_id is principally allowed,
    # contingent on it being defined in the global pipeline_defs (checked below).
    logger.debug(f"DMM.load: Model '{model_id}' is permitted for this worker instance.")

    resident = self._get_resident_pipeline(model_id)
    if resident is not None:
        return resident

    # On-Demand model download
    cfg = get_config()
    if not cfg or not cfg.pipeline_defs:
        logger.error("DefaultModelManager: Global config or pipeline_defs not loaded. Cannot download models.")
        return None

    model_config_dict = cfg.pipeline_defs.get(model_id)
    if not model_config_dict:
        logger.error(f"Model {model_id} not found in configuration")
        return None

    async with self._download_lock:
        # Another coroutine may have finished loading this model while we waited.
        resident = self._get_resident_pipeline(model_id)
        if resident is not None:
            return resident

        if not await self._ensure_model_files_on_disk(model_id, model_config_dict):
            logger.error(f"Failed to ensure model files for {model_id} are on disk")
            return None

        # Load new model while still holding the lock (see docstring).
        return await self._load_new_model(model_id, gpu, pipe_type)


def _get_resident_pipeline(self, model_id: str) -> Optional[DiffusionPipeline]:
    """Return ``model_id``'s pipeline if it is already GPU- or CPU-resident, refreshing its LRU entry; else None."""
    if model_id in self.loaded_models:
        logger.info(f"Model {model_id} already loaded in GPU")
        self.lru_cache.access(model_id, "gpu")
        return self.loaded_models[model_id]

    # CPU-offloaded models
    if model_id in self.cpu_models:
        logger.info(f"Model {model_id} is CPU-offloaded and ready")
        self.lru_cache.access(model_id, "cpu")
        return self.cpu_models[model_id]

    return None
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
async def _load_new_model(
    self, model_id: str, gpu: Optional[int] = None, pipe_type: Optional[str] = None
) -> Optional[DiffusionPipeline]:
    """
    Load a new model that isn't currently in memory.

    Placement is attempted in this order:
    1. Directly onto the GPU when ``_can_fit_gpu`` says the estimate fits now.
    2. If the estimate would fit on an otherwise-empty GPU and other models are
       resident, free space with ``_free_space_for_model`` and retry the GPU load.
    3. Outside ``prod``, if ``_need_optimization`` allows it, unload every resident
       model and load this one with CPU-offload optimizations into ``cpu_models``.

    After a successful load, the other models allowed on this worker are prefetched
    in the background (``_schedule_deployment_prefetch``).

    Args:
        model_id: Identifier for the model
        gpu: Optional GPU device number (currently unused)
        pipe_type: Optional pipeline type specification (deprecated, use class_name in config; currently unused)

    Returns:
        Loaded pipeline or None if loading failed. Exceptions are logged with their
        traceback and reported as None.
    """
    try:
        config = get_config()
        model_config = config.pipeline_defs.get(model_id)

        if not model_config:
            logger.error(f"Model {model_id} not found in configuration")
            return None

        # add 1GB for safety margin (model inference overhead in memory)
        estimated_size = await self._get_model_size(model_config, model_id) + 1.0
        logger.info(f"estimated_size for model {model_id}: {estimated_size} GB")

        # 1. try direct gpu load
        if self._can_fit_gpu(estimated_size):
            return await self._load_new_model_to_gpu(model_id, model_config, estimated_size, config)

        # 2. With max_vram as initialised in __init__ (total - DEFAULT_MAX_VRAM_BUFFER_GB),
        #    this compares against total VRAM - VRAM_SAFETY_MARGIN_GB, i.e. "fits on an empty GPU".
        fits_on_empty_gpu = estimated_size <= self.max_vram - (VRAM_SAFETY_MARGIN_GB - DEFAULT_MAX_VRAM_BUFFER_GB)
        has_resident_models = len(self.loaded_models) > 0 or len(self.cpu_models) > 0
        if fits_on_empty_gpu and has_resident_models:
            # if not enough space, try to make space
            self._free_space_for_model(estimated_size)
            if self._can_fit_gpu(estimated_size):
                return await self._load_new_model_to_gpu(model_id, model_config, estimated_size, config)

        # 3. if still not enough VRAM, apply optimization (CPU offload)
        if self.environment != "prod" and self._need_optimization(estimated_size):
            logger.info("Unloading all models for large model loading")

            # .get(): a model without a size entry must not abort the whole load with KeyError.
            for resident_id in list(self.loaded_models.keys()):
                self._unload_model_for_space(resident_id, self.model_sizes.get(resident_id, 0.0), "gpu")
            for resident_id in list(self.cpu_models.keys()):
                self._unload_model_for_space(resident_id, self.model_sizes.get(resident_id, 0.0), "cpu")

            pipeline = await self._load_model_by_source(model_id, model_config)
            if pipeline is None:
                return None

            self._setup_scheduler(pipeline, model_id)
            logger.info(f"Applying optimizations for {model_id}")
            self.apply_optimizations(pipeline, model_id, True)

            self.cpu_models[model_id] = pipeline
            self.model_sizes[model_id] = estimated_size
            self.lru_cache.access(model_id, "cpu")
            self._schedule_deployment_prefetch(model_id, config)
            return pipeline

        logger.error(f"Insufficient memory to load model {model_id}")
        return None

    except Exception as e:
        logger.exception(f"Error loading new model {model_id}: {e}")
        return None


async def _load_new_model_to_gpu(
    self, model_id: str, model_config: Any, estimated_size: float, config: Any
) -> Optional[DiffusionPipeline]:
    """Load ``model_id`` from its source, set its scheduler, move it to the GPU and register it as GPU-resident."""
    pipeline = await self._load_model_by_source(model_id, model_config)
    if pipeline is None:
        return None

    self._setup_scheduler(pipeline, model_id)
    if not self._move_model_to_gpu(pipeline, model_id):
        logger.error(f"Failed to move {model_id} to GPU")
        return None

    self.loaded_models[model_id] = pipeline
    self.model_sizes[model_id] = estimated_size
    self.vram_usage += estimated_size
    self.lru_cache.access(model_id, "gpu")
    self._schedule_deployment_prefetch(model_id, config)
    return pipeline


def _schedule_deployment_prefetch(self, model_id: str, config: Any) -> None:
    """Start a background FlashPack prefetch of the other allowed models (needs a local cache and >1 allowed model)."""
    if not (hasattr(self, 'flashpack_loader') and self.flashpack_loader.local_cache):
        return
    if not (self.allowed_model_ids and len(self.allowed_model_ids) > 1):
        return

    sources = {}
    for mid in self.allowed_model_ids:
        if mid == model_id:
            continue
        model_cfg = config.pipeline_defs.get(mid)
        if model_cfg:
            sources[mid] = model_cfg.get("source") if isinstance(model_cfg, dict) else model_cfg.source
    if not sources:
        return

    logger.info(f"🔄 Starting background prefetch for {len(sources)} other models")
    task = asyncio.create_task(
        self.flashpack_loader.prefetch_deployment_models(
            list(self.allowed_model_ids), sources, exclude_model_id=model_id
        )
    )
    # The event loop only keeps weak references to tasks; hold a strong one until the
    # prefetch finishes so it cannot be garbage-collected mid-run.
    background_tasks = getattr(self, "_background_tasks", None)
    if background_tasks is None:
        background_tasks = set()
        self._background_tasks = background_tasks
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
|
|
732
|
+
|
|
733
|
+
async def _load_model_by_source(
    self, model_id: str, model_config: PipelineConfig) -> Optional[DiffusionPipeline]:
    """
    Load a model based on its source configuration.

    Resolution order:
      1. A FlashPack artifact, if the FlashPack loader reports one for this
         model and loading it succeeds.
      2. The local cache: ``ensure_local`` may return a local copy, in which
         case the source is rewritten to a local file (``file``) or a local
         directory (``hf``).
      3. The standard loaders, chosen by source prefix: ``hf``,
         ``file``/``ct``, or an ``http://``/``https://`` URL.

    Args:
        model_id: Model identifier
        model_config: Model configuration, a PipelineConfig or a plain dict
            with ``source`` and ``class_name`` keys

    Returns:
        Loaded pipeline, or None if the model is unavailable or loading
        failed. Errors are logged, not raised.
    """
    try:
        if isinstance(model_config, dict):
            source = model_config.get("source")
            class_name = model_config.get("class_name")
        else:
            source = model_config.source
            class_name = model_config.class_name

        # Sources look like "<prefix>:<path>"; reject malformed ones with a
        # clear message instead of failing on the unpacking below.
        if not source or ":" not in source:
            logger.error(f"Invalid model source for {model_id}: {source!r}")
            return None
        prefix, path = source.split(":", 1)

        # The FlashPack loader is optional (elsewhere in this class its
        # presence is checked with hasattr), so don't assume it exists.
        flashpack_loader = getattr(self, "flashpack_loader", None)

        # === LOCAL CACHE + FLASHPACK ===
        if flashpack_loader is not None:
            flashpack_path = flashpack_loader.get_flashpack_path(model_id, source)
            if flashpack_path:
                (pipeline_class, _) = get_pipeline_class(class_name)
                pipeline = await flashpack_loader.load_from_flashpack(
                    model_id, flashpack_path, pipeline_class
                )
                if pipeline:
                    return pipeline
                logger.warning(f"FlashPack loading failed for {model_id}, falling back")

            # No FlashPack - copy the weights to the local cache anyway
            if flashpack_loader.local_cache:
                local_path = await flashpack_loader.local_cache.ensure_local(
                    model_id, source, priority=True
                )
                if local_path:
                    logger.info(f"📂 Loading from local cache: {local_path}")
                    path = str(local_path)
                    # A single file goes through the single-file loader; a
                    # directory is treated as a diffusers (HF) model folder.
                    if local_path.suffix == ".safetensors" or local_path.is_file():
                        prefix = "file"
                    else:
                        prefix = "hf"

        # === STANDARD LOADING (fallback when FlashPack is absent or failed) ===
        if prefix == "hf":
            is_downloaded, variant = await self.model_downloader.is_downloaded(
                model_id, model_config
            )
            if not is_downloaded:
                logger.info(f"Model {model_id} not downloaded")
                return None

            # Fall back to the repo's model index when class_name is unset
            if class_name is None or class_name == "":
                model_index = await self.model_downloader.get_diffusers_multifolder_components(
                    path
                )
                if model_index and "_class_name" in model_index:
                    class_name = model_index["_class_name"]
                else:
                    logger.error(f"Unknown diffusers class_name for {model_id}")
                    return None

            return await self.load_huggingface_model(
                model_id, path, class_name, variant, model_config
            )
        elif prefix in ("file", "ct"):
            return await self.load_single_file_model(model_id, path, prefix, class_name)
        elif source.startswith(("http://", "https://")):
            # Civitai / direct-download models
            source_obj = ModelSource(source)
            download_state = await self.model_downloader.is_downloaded(model_id)
            # The hf branch above unpacks is_downloaded() as an
            # (is_downloaded, variant) pair. A non-empty tuple is always
            # truthy, so extract the flag before testing it.
            if isinstance(download_state, tuple):
                download_state = download_state[0] if download_state else False
            if not download_state:
                logger.info(f"Model {model_id} not downloaded")
                return None

            cached_path = await self.model_downloader._get_cache_path(
                model_id, source_obj
            )
            if not os.path.exists(cached_path):
                logger.error(f"Cached model file not found at {cached_path}")
                return None

            return await self.load_single_file_model(
                model_id, cached_path, "file", class_name
            )
        else:
            logger.error(f"Unsupported model source prefix: {prefix}")
            return None
    except Exception as e:
        logger.error("Error loading model {} from source: {}".format(model_id, str(e)))
        return None
|
|
837
|
+
|
|
838
|
+
def _free_space_for_model(self, model_size: float) -> None:
    """
    Evict GPU models in least-recently-used order until ``model_size`` GB
    fits in available VRAM (after subtracting VRAM_SAFETY_MARGIN_GB).

    Eviction stops as soon as either the estimated freed space covers the
    shortfall or the measured available VRAM is sufficient. A warning is
    logged if the model still does not fit afterwards.

    Args:
        model_size: Size in GB of the model that needs to fit.
    """
    available_vram = self._get_available_vram() - VRAM_SAFETY_MARGIN_GB
    if available_vram >= model_size:
        return

    space_needed = model_size - available_vram
    logger.info(f"Need to free {space_needed:.2f} GB of VRAM")

    # Snapshot the LRU order first: unloading removes entries from the cache.
    lru_gpu_models = list(self.lru_cache.get_lru_models("gpu"))

    freed_space = 0.0
    for model_id_to_unload in lru_gpu_models:
        if freed_space >= space_needed:
            break

        # .get(): a model without a recorded size is still evicted rather
        # than aborting the whole eviction loop with a KeyError.
        size = self.model_sizes.get(model_id_to_unload, 0.0)
        freed_space += self._unload_model_for_space(model_id_to_unload, size, "gpu")

        # Size estimates can differ from reality; trust the measured VRAM.
        if self._get_available_vram() - VRAM_SAFETY_MARGIN_GB >= model_size:
            return

    remaining = self._get_available_vram() - VRAM_SAFETY_MARGIN_GB
    if remaining < model_size:
        logger.warning(
            f"Only {remaining:.2f} GB of VRAM available after eviction; "
            f"{model_size:.2f} GB requested"
        )
|
|
860
|
+
|
|
861
|
+
|
|
862
|
+
def _unload_model_for_space(
    self, model_id: str, model_size: float, device: str
) -> float:
    """
    Unload a model to free GPU or CPU memory.

    With ``device == "gpu"``, the model is removed from ``loaded_models``:
    its hooks are removed, its known sub-components are deleted, and
    ``vram_usage`` is reduced by ``model_size``. In every case the model is
    also removed from ``cpu_models`` if present there. LRU entries are
    dropped and memory is flushed. Errors are logged, and the tracking
    state is still cleaned up.

    Args:
        model_id: Model identifier
        model_size: Size of the model in GB (used for ``vram_usage`` accounting)
        device: Device to free ("gpu" or "cpu")

    Returns:
        ``model_size`` if the model was resident on ``device`` when called,
        otherwise 0.0 (nothing was freed there).
    """
    logger.info(f"Unloading {model_id} from memory")

    resident = model_id in (self.loaded_models if device == "gpu" else self.cpu_models)
    vram_released = False  # ensures vram_usage is decremented at most once

    try:
        if device == "gpu" and model_id in self.loaded_models:
            pipeline = self.loaded_models.pop(model_id)
            # Account for the release right after the pop, so the error
            # path below can never subtract the same size a second time.
            self.vram_usage -= model_size
            vram_released = True

            if hasattr(pipeline, "remove_all_hooks"):
                logger.info(f"Removing all hooks for model {model_id}")
                pipeline.remove_all_hooks()

            # Explicitly drop the pipeline's references to its sub-components
            for attr in (
                "vae",
                "unet",
                "text_encoder",
                "text_encoder_2",
                "tokenizer",
                "scheduler",
                "transformer",
                "tokenizer_2",
                "text_encoder_3",
                "tokenizer_3",
            ):
                if getattr(pipeline, attr, None) is not None:
                    delattr(pipeline, attr)

            self.lru_cache.remove(model_id, "gpu")
            del pipeline

        if model_id in self.cpu_models:
            logger.info(f"Unloading {model_id} from CPU memory")
            cpu_pipeline = self.cpu_models.pop(model_id)

            if hasattr(cpu_pipeline, "remove_all_hooks"):
                logger.info(f"Removing all hooks for CPU-offloaded model {model_id}")
                cpu_pipeline.remove_all_hooks()

            self.lru_cache.remove(model_id, "cpu")
            del cpu_pipeline

        # Force garbage collection and memory clearing
        self.flush_memory()

    except Exception as e:
        logger.error(
            "Error during model unloading for {}: {}".format(model_id, str(e))
        )
        # Best-effort cleanup so tracking state does not keep a half-unloaded model.
        if model_id in self.loaded_models:
            self.loaded_models.pop(model_id)
            if not vram_released:
                self.vram_usage -= model_size
        self.cpu_models.pop(model_id, None)
        self.lru_cache.remove(model_id, "gpu")
        self.lru_cache.remove(model_id, "cpu")
        self.flush_memory()

    return model_size if resident else 0.0
|
|
941
|
+
|
|
942
|
+
def _move_model_to_gpu(self, pipeline: DiffusionPipeline, model_id: str) -> bool:
    """
    Move a pipeline onto the device returned by get_available_torch_device().

    VAE tiling and slicing are enabled first when the pipeline exposes them.
    OmniGen pipelines are left untouched and reported as success.

    Args:
        pipeline: The pipeline to move (``.to`` is called on it; the return
            value of ``.to`` is not kept)
        model_id: Identifier of the model, used in log messages

    Returns:
        True on success, False if a RuntimeError (e.g. out of memory) occurred.
    """
    # Skip if it's an OmniGen pipeline
    if pipeline.__class__.__name__ == "OmniGenPipeline":
        return True

    try:
        device = get_available_torch_device()
        logger.info(f"Moving model {model_id} to device {device}")

        # Enable VAE tiling and slicing where the pipeline supports them
        if hasattr(pipeline, "enable_vae_tiling"):
            logger.info("Enabling vae tiling")
            pipeline.enable_vae_tiling()
        if hasattr(pipeline, "enable_vae_slicing"):
            logger.info("Enabling vae slicing")
            pipeline.enable_vae_slicing()

        pipeline.to(device=device)
        logger.info(f"Done moving model {model_id} to {device}")

        self.flush_memory()
        return True

    except RuntimeError as e:
        if "out of memory" in str(e):
            logger.error(f"GPU out of memory while moving model {model_id} to GPU")
        else:
            logger.error(
                "Runtime error moving model {} to GPU: {}".format(model_id, str(e))
            )
        self.flush_memory()
        return False
|
|
985
|
+
|
|
986
|
+
async def load_huggingface_model(
    self,
    model_id: str,
    repo_id: str,
    class_name: Union[str, Tuple[str, str]],
    variant: Optional[str] = None,
    model_config: Optional[PipelineConfig] = None,
) -> Optional[DiffusionPipeline]:
    """
    Load a diffusers pipeline from a HuggingFace repo id or local directory.

    Args:
        model_id: Model identifier. Also selects the dtype: ids containing
            "flux" load in bfloat16, all others in float16.
        repo_id: HuggingFace repository ID or local model directory
        class_name: Pipeline class name or (package, class) tuple
        variant: Optional weight variant; "" is treated as no variant
        model_config: Optional model configuration used to build extra
            ``from_pretrained`` kwargs (components, custom pipeline)

    Returns:
        Loaded pipeline or None if loading failed. On success the pipeline
        is also stored as ``self.loaded_model`` / ``self.current_model``.
    """
    try:
        # Normalise "" to None *before* building component kwargs, so the
        # components receive the same variant as the main pipeline. diffusers
        # treats any non-None variant as a weight-file suffix.
        if variant == "":
            variant = None
        pipeline_kwargs = await self._prepare_pipeline_kwargs(model_config, variant)

        # TODO: make dtype selection more robust than a substring match.
        torch_dtype = torch.bfloat16 if "flux" in model_id.lower() else torch.float16

        # Get appropriate pipeline class
        (pipeline_class, custom_pipeline) = get_pipeline_class(class_name)
        if custom_pipeline is not None:
            pipeline_kwargs["custom_pipeline"] = custom_pipeline

        # Log only the kwarg names: values may be fully loaded components
        # whose repr is very large.
        logger.debug(
            f"from_pretrained repo_id={repo_id}, torch_dtype={torch_dtype}, "
            f"local_files_only=True, variant={variant}, "
            f"custom_pipeline={custom_pipeline}, kwargs={sorted(pipeline_kwargs)}"
        )
        try:
            pipeline = await asyncio.to_thread(
                pipeline_class.from_pretrained,
                repo_id,
                torch_dtype=torch_dtype,
                local_files_only=True,
                variant=variant,
                **pipeline_kwargs,
            )
        except EntryNotFoundError as e:
            failed_custom = pipeline_kwargs.pop("custom_pipeline", None)
            if failed_custom is None:
                # There is no custom pipeline to drop, so a retry would fail
                # the same way; surface the real error instead.
                raise
            logger.warning(f"Custom pipeline '{failed_custom}' not found: {e}")
            logger.warning("Falling back to the default pipeline...")
            # NOTE(review): unlike the first attempt, this call does not pass
            # local_files_only=True, so it may contact the Hub. Confirm this is intended.
            pipeline = await asyncio.to_thread(
                pipeline_class.from_pretrained,
                repo_id,
                variant=variant,
                torch_dtype=torch_dtype,
                **pipeline_kwargs,
            )

        self.flush_memory()
        self.loaded_model = pipeline
        self.current_model = model_id
        logger.info(f"Model {model_id} loaded successfully")
        return pipeline

    except Exception as e:
        traceback.print_exc()
        logger.error("Failed to load model {}: {}".format(model_id, str(e)))
        return None
|
|
1055
|
+
|
|
1056
|
+
async def _prepare_pipeline_kwargs(
|
|
1057
|
+
self, model_config: Optional[PipelineConfig], variant: Optional[str] = None
|
|
1058
|
+
) -> Dict[str, Any]:
|
|
1059
|
+
"""
|
|
1060
|
+
Prepare kwargs for pipeline initialization.
|
|
1061
|
+
|
|
1062
|
+
Args:
|
|
1063
|
+
model_config: Model configuration from PipelineConfig
|
|
1064
|
+
|
|
1065
|
+
Returns:
|
|
1066
|
+
Dictionary of pipeline initialization arguments
|
|
1067
|
+
"""
|
|
1068
|
+
if not model_config:
|
|
1069
|
+
return {}
|
|
1070
|
+
|
|
1071
|
+
pipeline_kwargs = {}
|
|
1072
|
+
try:
|
|
1073
|
+
if isinstance(model_config, dict):
|
|
1074
|
+
class_name = model_config.get("class_name")
|
|
1075
|
+
components = model_config.get("components")
|
|
1076
|
+
custom_pipeline = model_config.get("custom_pipeline")
|
|
1077
|
+
else:
|
|
1078
|
+
class_name = model_config.class_name
|
|
1079
|
+
components = model_config.components
|
|
1080
|
+
|
|
1081
|
+
if custom_pipeline:
|
|
1082
|
+
pipeline_kwargs["custom_pipeline"] = custom_pipeline
|
|
1083
|
+
|
|
1084
|
+
if components:
|
|
1085
|
+
for key, component in components.items():
|
|
1086
|
+
main_model_source = model_config.get("source")
|
|
1087
|
+
# Check if component and also id component has .source attribute
|
|
1088
|
+
if component:
|
|
1089
|
+
component_source = component.get("source", None)
|
|
1090
|
+
if component_source:
|
|
1091
|
+
pipeline_kwargs[key] = await self._prepare_component(
|
|
1092
|
+
main_model_source, component, class_name, key, variant
|
|
1093
|
+
)
|
|
1094
|
+
|
|
1095
|
+
return pipeline_kwargs
|
|
1096
|
+
|
|
1097
|
+
except Exception as e:
|
|
1098
|
+
logger.error("Error preparing pipeline kwargs: {}".format(str(e)))
|
|
1099
|
+
return {}
|
|
1100
|
+
|
|
1101
|
+
async def _prepare_component(
|
|
1102
|
+
self,
|
|
1103
|
+
main_model_source: str,
|
|
1104
|
+
component: PipelineConfig,
|
|
1105
|
+
model_class_name: Optional[Union[str, Tuple[str, str]]],
|
|
1106
|
+
key: str,
|
|
1107
|
+
variant: Optional[str] = None,
|
|
1108
|
+
) -> Any:
|
|
1109
|
+
"""
|
|
1110
|
+
Prepare a model component based on its configuration.
|
|
1111
|
+
|
|
1112
|
+
Args:
|
|
1113
|
+
component: Component configuration
|
|
1114
|
+
model_class_name: Class name of the parent model
|
|
1115
|
+
key: Component key
|
|
1116
|
+
|
|
1117
|
+
Returns:
|
|
1118
|
+
Loaded component or None if loading failed
|
|
1119
|
+
"""
|
|
1120
|
+
try:
|
|
1121
|
+
if isinstance(component, dict):
|
|
1122
|
+
source = component.get("source")
|
|
1123
|
+
else:
|
|
1124
|
+
source = component.source
|
|
1125
|
+
|
|
1126
|
+
if not source.endswith((".safetensors", ".bin", ".ckpt", ".pt")):
|
|
1127
|
+
# check if the url has more than 2 forward slashes. If it does, the last one is the subfolder, the source is the first part
|
|
1128
|
+
# e.g. hf:cozy-creator/Flux.1-schnell-8bit/transformer this will be, source = hf:cozy-creator/Flux.1-schnell-8bit, subfolder = transformer
|
|
1129
|
+
if source.count("/") > 1:
|
|
1130
|
+
repo_id = "/".join(source.split("/")[:-1])
|
|
1131
|
+
subfolder = source.split("/")[-1]
|
|
1132
|
+
return await self._load_diffusers_component(
|
|
1133
|
+
main_model_source.replace("hf:", ""),
|
|
1134
|
+
repo_id.replace("hf:", ""),
|
|
1135
|
+
subfolder,
|
|
1136
|
+
variant,
|
|
1137
|
+
)
|
|
1138
|
+
else:
|
|
1139
|
+
return await self._load_diffusers_component(
|
|
1140
|
+
main_model_source.replace("hf:", ""),
|
|
1141
|
+
source.replace("hf:", ""),
|
|
1142
|
+
variant,
|
|
1143
|
+
)
|
|
1144
|
+
else:
|
|
1145
|
+
return self._load_custom_component(source, model_class_name, key)
|
|
1146
|
+
except Exception as e:
|
|
1147
|
+
logger.error("Error preparing component {}: {}".format(key, str(e)))
|
|
1148
|
+
return None
|
|
1149
|
+
|
|
1150
|
+
async def load_single_file_model(
    self,
    model_id: str,
    path: str,
    prefix: str,
    class_name: Optional[Union[str, Tuple[str, str]]] = None,
) -> Optional[DiffusionPipeline]:
    """
    Load a model from a single checkpoint file.

    Args:
        model_id: Model identifier
        path: Path to the model file. Used as-is when ``prefix == "file"``;
            otherwise resolved with ``_get_model_path``.
        prefix: Source prefix ("file" or "ct")
        class_name: Pipeline class name or (package, class) tuple; required

    Returns:
        Loaded pipeline or None if loading failed. On success the pipeline
        is also stored as ``self.loaded_model`` / ``self.current_model``.
    """
    logger.info(f"Loading single file model {model_id}")

    # TODO: we could try inferring the class using our old detect_model code here!
    if class_name is None or class_name == "":
        logger.error("Model class_name must be specified for single file models")
        return None

    try:
        # Resolved inside the try so that a failure while resolving the class
        # is logged like any other load failure instead of propagating.
        (pipeline_class, _) = get_pipeline_class(class_name)
        logger.debug(f"Model path: {path}")

        if prefix == "file":
            model_path = path
        else:
            model_path = self._get_model_path(path, prefix)
            if not os.path.exists(model_path):
                logger.error(f"Model file not found: {model_path}")
                return None

        pipeline = await self._load_pipeline_from_file(
            pipeline_class, model_path, model_id, class_name
        )
        if not pipeline:
            return None

        self.loaded_model = pipeline
        self.current_model = model_id
        return pipeline

    except Exception as e:
        logger.error("Error loading single file model: {}".format(str(e)))
        return None
|
|
1200
|
+
|
|
1201
|
+
def _get_model_path(self, path: str, prefix: str) -> str:
    """
    Map a source path onto the configured models directory.

    For ``ct`` sources whose path contains "http" (a URL), only the last
    "/"-separated segment is kept as the file name. Every other path is
    joined to ``models_path`` unchanged.

    Args:
        path: Path portion of the model source
        prefix: Source prefix (e.g. "ct")

    Returns:
        The joined filesystem path.
    """
    # "https" always contains "http", so one substring test covers both.
    looks_like_url = prefix == "ct" and "http" in path
    file_name = path.rsplit("/", 1)[-1] if looks_like_url else path
    models_dir = get_config().models_path
    return os.path.join(models_dir, file_name)
|
|
1205
|
+
|
|
1206
|
+
async def _load_pipeline_from_file(
    self,
    pipeline_class: Type[DiffusionPipeline],
    model_path: str,
    model_id: str,
    class_name: Optional[Union[str, Tuple[str, str]]],
) -> Optional[DiffusionPipeline]:
    """
    Load a pipeline from a single checkpoint file.

    Pipeline classes that subclass FromSingleFileMixin are loaded with
    ``_load_from_single_file``. All others are assembled component by
    component with ``_load_custom_architecture``.

    Args:
        pipeline_class: Class to instantiate pipeline
        model_path: Path to model file
        model_id: Model identifier
        class_name: Pipeline class name or (package, class) tuple

    Returns:
        Loaded pipeline or None if loading failed (the error is logged).
    """
    try:
        if issubclass(pipeline_class, FromSingleFileMixin):
            loader = self._load_from_single_file
            loader_args = (pipeline_class, model_path, model_id)
        else:
            # The custom-architecture loader takes the bare class name
            # (second element of a (package, class) tuple).
            class_str = class_name[1] if isinstance(class_name, tuple) else class_name
            loader = self._load_custom_architecture
            loader_args = (pipeline_class, model_path, class_str)

        # Both loaders are synchronous and read the whole checkpoint from
        # disk. Run them in a worker thread so the event loop stays
        # responsive, the same way from_pretrained is offloaded.
        return await asyncio.to_thread(loader, *loader_args)
    except Exception as e:
        logger.error("Error loading pipeline from file: {}".format(str(e)))
        return None
|
|
1239
|
+
|
|
1240
|
+
def _load_from_single_file(
    self, pipeline_class: Any, path: str, model_id: str
) -> DiffusionPipeline:
    """
    Build a pipeline from one checkpoint file via ``from_single_file``.

    The dtype is picked from the model id: ids containing "flux"
    (case-insensitive) load in bfloat16, everything else in float16.

    Args:
        pipeline_class: Pipeline class providing ``from_single_file``
        path: Path to the checkpoint file
        model_id: Model identifier (drives the dtype choice)

    Returns:
        The loaded pipeline instance.

    Raises:
        Exception: Whatever the load raised, after it has been logged.
    """
    try:
        is_flux = "flux" in model_id.lower()
        dtype = torch.bfloat16 if is_flux else torch.float16
        loaded = pipeline_class.from_single_file(path, torch_dtype=dtype)
        logger.info(f"Successfully loaded single file model {model_id}")
        return loaded
    except Exception as exc:
        logger.error("Error loading model from single file: {}".format(str(exc)))
        raise
|
|
1274
|
+
|
|
1275
|
+
def _load_custom_architecture(
    self, pipeline_class: Any, path: str, type: str
) -> DiffusionPipeline:
    """
    Assemble a pipeline from a single checkpoint using registered architectures.

    The checkpoint's state dict is read once and an empty ``pipeline_class()``
    is created. Every component listed in ``MODEL_COMPONENTS[type]`` (except
    "scheduler", "tokenizer" and "tokenizer_2") is then loaded through the
    architecture registered as ``core_extension_1.<type>_<component>`` and
    attached to the pipeline under the component's name. Components whose
    architecture is missing or fails to load are logged and skipped.

    Args:
        pipeline_class: The pipeline class to instantiate
        path: Path to the model file
        type: Model type key into MODEL_COMPONENTS (name kept for callers)

    Returns:
        The assembled pipeline instance.

    Raises:
        Exception: If reading the state dict, creating the pipeline, or the
            MODEL_COMPONENTS lookup fails (logged, then re-raised).
    """
    not_loaded_here = ("scheduler", "tokenizer", "tokenizer_2")
    try:
        state_dict = load_state_dict_from_file(path)
        pipeline = pipeline_class()

        for name in MODEL_COMPONENTS[type]:
            if name in not_loaded_here:
                continue

            arch_key = f"core_extension_1.{type}_{name}"
            arch_cls = get_architectures().get(arch_key)
            if not arch_cls:
                logger.error(f"Architecture not found for {arch_key}")
                continue

            try:
                arch = arch_cls()
                arch.load(state_dict)
                setattr(pipeline, name, arch.model)
                logger.debug(f"Loaded component {name} for {type}")
            except Exception as component_error:
                # Best effort: one broken component does not abort the rest.
                logger.error(
                    "Error loading component {}: {}".format(name, str(component_error))
                )

        logger.info("Successfully loaded custom architecture model")
        return pipeline
    except Exception as e:
        logger.error("Error loading custom architecture: {}".format(str(e)))
        raise
|
|
1339
|
+
|
|
1340
|
+
def apply_optimizations(
    self,
    pipeline: DiffusionPipeline,
    model_id: str,
    force_full_optimization: bool = False,
) -> None:
    """
    Apply memory and performance optimizations to a pipeline.

    The full optimization list (see ``_get_full_optimizations``) is applied
    when the model's recorded size exceeds ``self.max_vram`` or when
    ``force_full_optimization`` is set. Nothing is done for OmniGen
    pipelines, or when ``self.loaded_model`` is set and ``self.is_in_device``
    is true.

    Args:
        pipeline: The pipeline to optimize
        model_id: Model identifier
        force_full_optimization: Whether to force optimization
    """
    # Skip if it's an OmniGen pipeline
    if pipeline.__class__.__name__ == "OmniGenPipeline":
        return

    if self.loaded_model is not None and self.is_in_device:
        logger.info(f"Model {model_id} already optimized")
        return

    device = get_available_torch_device()

    # Only optimize if the model is bigger than total VRAM (or forced)
    model_size = self.model_sizes.get(model_id, 0)
    if model_size > self.max_vram or force_full_optimization:
        logger.info(f"Applying optimizations for {model_id}")
        # Build the list for *this* pipeline. self.loaded_model can refer to
        # a different model, e.g. after a FlashPack load that does not set it.
        optimizations = self._get_full_optimizations(pipeline)
        self._apply_optimization_list(pipeline, optimizations, device)
    else:
        logger.info(f"No optimization needed for {model_id}")

def _get_full_optimizations(
    self, pipeline: Optional[DiffusionPipeline] = None
) -> List[Tuple[str, str, Dict[str, Any]]]:
    """
    Get the list of optimizations to apply to a pipeline.

    Args:
        pipeline: Pipeline being optimized; decides whether xformers
            attention is included. Defaults to ``self.loaded_model`` when
            omitted (backward compatible).

    Returns:
        List of (method_name, display_name, kwargs) tuples; method_name is
        called on the pipeline with kwargs.
    """
    target = self.loaded_model if pipeline is None else pipeline

    optimizations = [
        ("enable_vae_slicing", "VAE Sliced", {}),
        ("enable_vae_tiling", "VAE Tiled", {}),
        (
            "enable_model_cpu_offload",
            "CPU Offloading",
            {"device": get_available_torch_device()},
        ),
    ]

    # xformers memory-efficient attention is not added for Flux pipelines.
    if not isinstance(target, (FluxPipeline, FluxInpaintPipeline)):
        optimizations.append(
            (
                "enable_xformers_memory_efficient_attention",
                "Memory Efficient Attention",
                {},
            )
        )

    return optimizations
|
|
1405
|
+
|
|
1406
|
+
def _apply_optimization_list(
|
|
1407
|
+
self,
|
|
1408
|
+
pipeline: DiffusionPipeline,
|
|
1409
|
+
optimizations: List[Tuple[str, str, Dict[str, Any]]],
|
|
1410
|
+
device: torch.device,
|
|
1411
|
+
) -> None:
|
|
1412
|
+
"""
|
|
1413
|
+
Apply a list of optimizations to a pipeline.
|
|
1414
|
+
|
|
1415
|
+
Args:
|
|
1416
|
+
pipeline: The pipeline to optimize
|
|
1417
|
+
optimizations: List of optimization specifications
|
|
1418
|
+
device: Target device
|
|
1419
|
+
"""
|
|
1420
|
+
device_type = device if isinstance(device, str) else device.type
|
|
1421
|
+
|
|
1422
|
+
if device_type == "mps":
|
|
1423
|
+
setattr(torch, "mps", torch.backends.mps)
|
|
1424
|
+
|
|
1425
|
+
for opt_func, opt_name, kwargs in optimizations:
|
|
1426
|
+
try:
|
|
1427
|
+
getattr(pipeline, opt_func)(**kwargs)
|
|
1428
|
+
logger.info("{} enabled".format(opt_name))
|
|
1429
|
+
except Exception as e:
|
|
1430
|
+
logger.error("Error enabling {}: {}".format(opt_name, str(e)))
|
|
1431
|
+
|
|
1432
|
+
if device_type == "mps":
|
|
1433
|
+
delattr(torch, "mps")
|
|
1434
|
+
|
|
1435
|
+
def flush_memory(self) -> None:
    """
    Run garbage collection and release cached accelerator memory.

    CUDA: empties the caching allocator, collects IPC handles and
    synchronizes. MPS: calls ``torch.mps.empty_cache()`` when the installed
    torch provides it.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()
        logger.info("GPU memory flushed successfully")
    elif torch.backends.mps.is_available():
        # The MPS allocator also caches freed blocks. torch.mps.empty_cache
        # only exists in newer torch releases, hence the getattr guards.
        mps_module = getattr(torch, "mps", None)
        empty_cache = getattr(mps_module, "empty_cache", None)
        if empty_cache is not None:
            empty_cache()
|
|
1445
|
+
|
|
1446
|
+
async def initialize_startup_models(self, model_ids: List[str]) -> None:
    """
    Initialize models at startup in a random order.

    Models missing from ``pipeline_defs``, or whose size cannot be estimated,
    are skipped. Each remaining model's estimated size includes a 1 GB
    inference-overhead margin. Models are loaded and warmed up in random
    order until the next one would not fit in the currently available VRAM
    (minus VRAM_SAFETY_MARGIN_GB). ``self.is_startup_load`` is True for the
    duration of the call.

    Args:
        model_ids: Identifiers of the models to load at startup.
    """
    if not model_ids:
        logger.info("No models configured for enabled models")
        return

    self.is_startup_load = True
    try:
        pipeline_defs = get_config().pipeline_defs
        model_sizes: Dict[str, float] = {}

        for model_id in model_ids:
            # .get(): a missing definition is skipped instead of raising
            # KeyError and aborting the whole startup load.
            model_config = pipeline_defs.get(model_id)
            if not model_config:
                logger.warning(f"Model {model_id} not found in pipeline_defs")
                continue

            try:
                # add 1GB for safety margin (model inference overhead in memory)
                model_sizes[model_id] = (
                    await self._get_model_size(model_config, model_id) + 1.0
                )
            except Exception as e:
                logger.error(f"Error getting model size for {model_id}: {e}")
                continue

        # Shuffle only models with a known size; any other id would raise a
        # KeyError on the size lookup below.
        rng = np.random.default_rng()
        random_models = rng.permutation(list(model_sizes)).tolist()

        logger.info(f"Loading models in random order: {random_models}")
        total_loaded = 0
        total_size = 0.0

        for model_id in random_models:
            estimated_size = model_sizes[model_id]

            # Stop at the first model that would exceed available VRAM
            available = self._get_available_vram() - VRAM_SAFETY_MARGIN_GB
            if estimated_size > available:
                logger.info(
                    f"Stopping model loading: Next model {model_id} "
                    f"({estimated_size:.2f} GB) would exceed available inference Memory "
                    f"({available:.2f} GB)"
                )
                break

            try:
                pipeline = await self.load(model_id)
                if pipeline is None:
                    logger.warning(f"Failed to load model {model_id}")
                    continue

                total_loaded += 1
                total_size += estimated_size
                logger.info(
                    f"Successfully loaded model {model_id} "
                    f"({total_loaded}/{len(random_models)}). "
                    f"Total VRAM used: {total_size:.2f} GB"
                )

                logger.info(f"Warming up model {model_id}")
                await self.warmup_pipeline(model_id)
            except Exception as e:
                logger.error(f"Error loading model {model_id}: {e}")

        logger.info(
            f"Startup loading complete. Loaded and warmed up {total_loaded}/{len(random_models)} "
            f"models using {total_size:.2f}GB/{self.max_vram - VRAM_SAFETY_MARGIN_GB:.2f}GB available VRAM"
        )
    except Exception as e:
        logger.error(f"Error initializing startup models: {e}")
    finally:
        self.is_startup_load = False
|
|
1524
|
+
|
|
1525
|
+
async def warmup_pipeline(self, model_id: str) -> None:
    """
    Warm up a pipeline by running a short test inference.

    Uses the pipeline from ``loaded_models``. If the model is not there, it
    is loaded via ``self.load``, except during startup loading
    (``self.is_startup_load``), where a missing model is only reported.
    Errors during the test inference are logged, not raised. Completion is
    logged only when the inference actually ran.

    Args:
        model_id: Model identifier
    """
    if model_id in self.loaded_models:
        pipeline = self.loaded_models[model_id]
    elif self.is_startup_load:
        logger.warning(f"Model {model_id} is not loaded")
        return
    else:
        logger.info(f"Loading model {model_id} for warm-up")
        pipeline = await self.load(model_id)
        if pipeline is None:
            logger.warning(f"Failed to load model {model_id} for warm-up")
            return

    logger.info(f"Warming up pipeline for model {model_id}")

    warmed_up = False
    try:
        with torch.no_grad():
            if isinstance(pipeline, DiffusionPipeline) and callable(pipeline):
                pipeline(
                    prompt="This is a warm-up prompt",
                    num_inference_steps=4,
                    output_type="pil",
                )
                warmed_up = True
            else:
                logger.warning(
                    f"Unsupported pipeline type for warm-up: {type(pipeline)}"
                )
    except Exception as e:
        logger.error(f"Error during warm-up for model {model_id}: {str(e)}")

    self.flush_memory()
    if warmed_up:
        logger.info(f"Warm-up completed for model {model_id}")
|
|
1564
|
+
|
|
1565
|
+
def get_all_model_ids(self) -> List[str]:
    """
    Return every model ID defined in the configuration.

    Returns:
        The keys of ``get_config().pipeline_defs`` as a new list.
    """
    pipeline_defs = get_config().pipeline_defs
    return [*pipeline_defs.keys()]
|
|
1574
|
+
|
|
1575
|
+
def get_enabled_models(self) -> List[str]:
    """
    Return the models that should be warmed up and used for generation.

    Returns:
        ``get_config().enabled_models`` exactly as stored in the
        configuration (the same object, not a copy).
    """
    enabled = get_config().enabled_models
    return enabled
|
|
1584
|
+
|
|
1585
|
+
def unload(self, model_id: str) -> None:
    """
    Unload a model from every memory tier and drop its bookkeeping.

    The model is removed from the GPU store (``loaded_models``) and from the
    CPU store (``cpu_models``) if present in either. For each tier, the
    matching usage counter (``vram_usage`` / ``ram_usage``) is reduced by the
    model's recorded size and its LRU entry for that tier is removed. If the
    model is the current one, the current-model references are cleared. Its
    size and type records are then deleted and ``flush_memory`` is called.

    Any exception is logged, not raised.

    Args:
        model_id: Model identifier to unload
    """
    try:
        # GPU tier
        if model_id in self.loaded_models:
            freed_gb = self.model_sizes.get(model_id, 0)
            self.loaded_models.pop(model_id)
            self.vram_usage -= freed_gb
            self.lru_cache.remove(model_id, "gpu")
            logger.info(f"Model {model_id} unloaded from GPU")

        # CPU tier
        if model_id in self.cpu_models:
            freed_gb = self.model_sizes.get(model_id, 0)
            self.cpu_models.pop(model_id)
            self.ram_usage -= freed_gb
            self.lru_cache.remove(model_id, "cpu")
            logger.info(f"Model {model_id} unloaded from CPU")

        # Drop the "current model" pointer if it refers to this model.
        if model_id == self.current_model:
            self.loaded_model = None
            self.current_model = None

        # Forget size/type tracking (no-op when absent).
        self.model_sizes.pop(model_id, None)
        self.model_types.pop(model_id, None)

        self.flush_memory()
    except Exception as e:
        logger.error(f"Error unloading model {model_id}: {str(e)}")
|
|
1634
|
+
|
|
1635
|
+
def is_loaded(self, model_id: str) -> bool:
    """
    Report whether ``model_id`` is the current model and has a pipeline.

    Only the single "current" model (``self.current_model`` /
    ``self.loaded_model``) is considered here; ``self.loaded_models`` is not
    consulted.

    Args:
        model_id: Model identifier

    Returns:
        True if ``model_id`` is the current model and ``self.loaded_model``
        is set, False otherwise.
    """
    if self.loaded_model is None:
        return False
    return model_id == self.current_model
|
|
1646
|
+
|
|
1647
|
+
def get_model(self, model_id: str) -> Optional[DiffusionPipeline]:
    """
    Return the current pipeline when ``model_id`` is the loaded current model.

    Args:
        model_id: Model identifier

    Returns:
        ``self.loaded_model`` if ``self.is_loaded(model_id)`` is true,
        otherwise None.
    """
    return self.loaded_model if self.is_loaded(model_id) else None
|
|
1660
|
+
|
|
1661
|
+
def get_model_device(self, model_id: str) -> Optional[torch.device]:
    """
    Get the device where a model is loaded.

    Args:
        model_id: Model identifier

    Returns:
        The ``.device`` of the pipeline returned by ``get_model``, or None
        when ``get_model`` returns None.
    """
    model = self.get_model(model_id)
    # Compare against None explicitly: a pipeline/module object may define
    # __len__ or __bool__, and its truthiness must not hide a loaded model.
    if model is None:
        return None
    return model.device
|
|
1673
|
+
|
|
1674
|
+
|
|
1675
|
+
async def _get_model_size(
    self, model_config: PipelineConfig, model_id: str
) -> float:
    """
    Estimate the size of a model, in GB.

    Dispatches on the ``source`` prefix of the model configuration:

    - ``http://`` / ``https://``: size of the file at the model downloader's
      cache path. Falls back to 7.0 GB if the file is missing or the lookup
      fails.
    - ``file:``: size of the local path, or 0 if it does not exist.
    - ``hf:``: HuggingFace cache size via ``_calculate_repo_size``.

    A missing or unsupported source is logged as an error and reported as 0.

    Args:
        model_config: Model configuration from ``pipeline_defs`` (a
            ``PipelineConfig`` or a plain dict with the same keys)
        model_id: The model identifier (key from pipeline_defs)

    Returns:
        Size in GB
    """
    default_size_gb = 7.0  # fallback when an HTTP download's size is unknown

    if isinstance(model_config, dict):
        source = model_config.get("source")
    else:
        source = model_config.source

    # Guard against configs without a source; previously this raised
    # AttributeError on ``None.startswith``.
    if not isinstance(source, str) or not source:
        logger.error(f"Unsupported source type for size calculation: {source}")
        return 0

    if source.startswith(("http://", "https://")):
        # For downloaded HTTP(S) models, get size from cache
        try:
            source_obj = ModelSource(source)
            cache_path = await self.model_downloader._get_cache_path(
                model_id, source_obj
            )
            logger.debug(f"Cache Path: {cache_path}")
            if os.path.exists(cache_path):
                return os.path.getsize(cache_path) / (1024**3)
            logger.warning(
                f"Cache file not found for {model_id}, assuming default size"
            )
            return default_size_gb
        except Exception as e:
            logger.error(f"Error getting cached model size: {e}")
            return default_size_gb

    if source.startswith("file:"):
        # Strip only the leading scheme, not every occurrence of "file:".
        path = source[len("file:"):]
        return os.path.getsize(path) / (1024**3) if os.path.exists(path) else 0

    if source.startswith("hf:"):
        repo_id = source[len("hf:"):]
        return self._calculate_repo_size(repo_id, model_config)

    logger.error(f"Unsupported source type for size calculation: {source}")
    return 0
|
|
1724
|
+
|
|
1725
|
+
def _calculate_repo_size(self, repo_id: str, model_config: PipelineConfig) -> float:
    """
    Calculate the total size of a HuggingFace repository, in GB.

    Starts from the size of the repo's cached snapshot
    (``_get_size_for_repo``). For every component in the config that is a
    dict with its own ``source``, it adds the value returned by
    ``_calculate_component_size``.

    Args:
        repo_id: Repository identifier (without the ``hf:`` prefix)
        model_config: Model configuration (``PipelineConfig`` or dict)

    Returns:
        Total size in GB
    """
    total_bytes = self._get_size_for_repo(repo_id)

    if isinstance(model_config, dict):
        components = model_config.get("components")
    else:
        components = model_config.components

    if components:
        for key, component in components.items():
            if isinstance(component, dict) and "source" in component:
                total_bytes += self._calculate_component_size(
                    component, repo_id, key
                )

    total_size_gb = total_bytes / (1024**3)
    # Report through the logger only; a duplicate print() polluted stdout.
    logger.debug(f"Total size: {total_size_gb:.2f} GB")
    return total_size_gb
|
|
1755
|
+
|
|
1756
|
+
def _calculate_component_size(
    self, component: Dict[str, Any], repo_id: str, key: str
) -> float:
    """
    Calculate how much a component override changes a repository's size.

    The base repository size already includes the base repo's own ``key``
    entry. This returns the size of the override minus the size of that
    ``key`` entry in ``repo_id``. The result is therefore a delta in bytes
    and may be negative.

    The override location is parsed from ``component["source"]``:

    - The repo is the first two ``/``-separated segments, with ``hf:`` removed.
    - If the source ends with a weights extension (``.safetensors``, ``.bin``,
      ``.ckpt``, ``.pt``), the override is that file. Its full path inside the
      repo is used, e.g. ``vae/model.safetensors``.
    - Otherwise the override is the entry named ``key`` in that repo.

    Args:
        component: Component configuration (must contain ``"source"``)
        repo_id: Base repository identifier
        key: Component key

    Returns:
        Size difference in bytes (override minus replaced component)
    """
    weight_extensions = (".safetensors", ".bin", ".ckpt", ".pt")
    component_source = component["source"]
    parts = component_source.split("/")

    if len(parts) > 2:
        component_repo = "/".join(parts[0:2]).replace("hf:", "")
    else:
        component_repo = component_source.replace("hf:", "")

    if not component_source.endswith(weight_extensions):
        component_name = key
    elif len(parts) > 2:
        # Keep the whole path inside the repo (not just the basename) so a
        # nested file such as "vae/model.safetensors" resolves correctly
        # under the snapshot folder.
        component_name = "/".join(parts[2:])
    else:
        component_name = parts[-1]

    override_size = self._get_size_for_repo(component_repo, component_name)
    replaced_size = self._get_size_for_repo(repo_id, key)
    return override_size - replaced_size
|
|
1791
|
+
|
|
1792
|
+
def _get_size_for_repo(
    self, repo_id: str, component_name: Optional[str] = None
) -> int:
    """
    Get the cached size of a HuggingFace repository or one of its components.

    Locates the repo's cache folder under ``self.cache_dir`` (named via
    ``repo_folder_name``) and resolves the snapshot from ``refs/main``. It then
    sizes either the whole snapshot or ``<snapshot>/<component_name>`` with
    ``_calculate_folder_size``. The ``"scheduler"`` component is treated as
    having no weights.

    Args:
        repo_id: Repository identifier (without the ``hf:`` prefix)
        component_name: Optional subfolder or file path inside the snapshot

    Returns:
        Size in bytes. Returns 0 if the storage folder or commit hash is
        missing (the folder-size helper also returns 0 for a missing path).
    """
    if component_name == "scheduler":
        return 0

    # Debug output goes through the logger instead of print() to stdout.
    logger.debug(f"Getting size for {repo_id} {component_name}")
    storage_folder = os.path.join(
        self.cache_dir, repo_folder_name(repo_id=repo_id, repo_type="model")
    )

    if not os.path.exists(storage_folder):
        logger.warning(f"Storage folder for {repo_id} not found")
        return 0

    commit_hash = self._get_commit_hash(storage_folder)
    if not commit_hash:
        return 0

    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
    if component_name:
        snapshot_folder = os.path.join(snapshot_folder, component_name)

    return self._calculate_folder_size(snapshot_folder)
|
|
1826
|
+
|
|
1827
|
+
def _get_commit_hash(self, storage_folder: str) -> Optional[str]:
    """
    Read the commit hash that ``refs/main`` points to in a repo cache folder.

    Args:
        storage_folder: Path to the repository storage folder

    Returns:
        The stripped contents of ``<storage_folder>/refs/main``. Returns None
        (after logging) if the file is missing or cannot be read.
    """
    refs_file = os.path.join(storage_folder, "refs", "main")
    if os.path.exists(refs_file):
        try:
            with open(refs_file, "r") as handle:
                raw_ref = handle.read()
        except Exception as exc:
            logger.error("Error reading commit hash: {}".format(str(exc)))
            return None
        return raw_ref.strip()

    logger.warning("No commit hash found for {}".format(storage_folder))
    return None
|
|
1847
|
+
|
|
1848
|
+
def _calculate_folder_size(self, folder: str) -> int:
    """
    Calculate the total size of model weight files under a path.

    If ``folder`` is not a directory, the size of that path is returned, or 0
    if it does not exist. For a directory, the first variant in
    ``bf16``, ``fp8``, ``fp16``, ``""`` with at least one matching file
    (``_check_variant_files``) is selected. Only files matching that variant
    (``_is_valid_file``) are summed.

    Files directly inside ``folder`` are skipped when ``folder`` also has
    subdirectories. A folder with no subdirectories, such as a single
    component folder (``<snapshot>/vae``) or a flat single-file repo, keeps
    its weights at the top level, so those files are counted. Previously they
    were always skipped, which made such folders report 0.

    Args:
        folder: Path to a folder (or file)

    Returns:
        Total size in bytes
    """
    if not os.path.isdir(folder):
        return os.path.getsize(folder) if os.path.exists(folder) else 0

    selected_variant = next(
        (
            v
            for v in ("bf16", "fp8", "fp16", "")
            if self._check_variant_files(folder, v)
        ),
        None,
    )
    if selected_variant is None:
        return 0

    total_size = 0
    for root, dirs, files in os.walk(folder):
        # NOTE(review): top-level files next to per-component subfolders are
        # presumably skipped to avoid double-counting a single-file checkpoint
        # alongside the split weights. Confirm.
        if root == folder and dirs:
            continue
        for file in files:
            if self._is_valid_file(file, selected_variant):
                total_size += os.path.getsize(os.path.join(root, file))

    return total_size
|
|
1883
|
+
|
|
1884
|
+
def _check_variant_files(self, folder: str, variant: str) -> bool:
    """
    Report whether any file under ``folder`` (recursively) matches a variant.

    Args:
        folder: Path to the folder to walk
        variant: Variant string passed through to ``_is_valid_file``

    Returns:
        True as soon as one file satisfies ``_is_valid_file``, False if none
        does (including when ``folder`` does not exist).
    """
    return any(
        self._is_valid_file(name, variant)
        for _root, _subdirs, names in os.walk(folder)
        for name in names
    )
|
|
1899
|
+
|
|
1900
|
+
def _is_valid_file(self, file: str, variant: str) -> bool:
    """
    Check whether a filename is a model weights file of a given variant.

    A weights file ends in ``.safetensors``, ``.bin`` or ``.ckpt``. When
    ``variant`` is non-empty (e.g. ``"fp16"``), the filename must also
    contain that string. An empty variant accepts any weights file.
    Previously, any filename containing the variant matched, including
    non-weight files such as ``*.fp16.safetensors.index.json``.

    Args:
        file: Filename to check
        variant: Variant to check for ("" for no variant)

    Returns:
        Boolean indicating if file is a valid weights file of that variant
    """
    if not file.endswith((".safetensors", ".bin", ".ckpt")):
        return False
    return not variant or variant in file
|
|
1918
|
+
|
|
1919
|
+
|
|
1920
|
+
async def _load_diffusers_component(
    self,
    main_model_repo: str,
    component_repo: str,
    component_name: Optional[str] = None,
    variant: Optional[str] = None,
) -> Any:
    """
    Load a single diffusers component (e.g. a VAE or text encoder).

    The component's class is resolved from the main model's index, as returned
    by ``model_downloader.get_diffusers_multifolder_components``. The index
    maps component names to ``(module_path, class_name)`` pairs. The weights
    are then loaded from ``component_repo`` with ``from_pretrained`` in a
    worker thread.

    Args:
        main_model_repo: Repository whose model index describes the component
        component_repo: Repository to load the component weights from
        component_name: Name of the component; also passed as ``subfolder``
        variant: Optional weights variant passed to ``from_pretrained``

    Returns:
        Loaded component

    Raises:
        ValueError: If the model index or the component entry is missing.
        Any error from importing or loading is logged and re-raised.
    """
    try:
        model_index = (
            await self.model_downloader.get_diffusers_multifolder_components(
                main_model_repo
            )
        )
        if model_index is None:
            raise ValueError(f"model_index does not exist for {main_model_repo}")

        component_info = model_index.get(component_name)
        if not component_info:
            raise ValueError(f"Invalid component info for {component_name}")

        module_path, class_name = component_info
        module = importlib.import_module(module_path)
        model_class = getattr(module, class_name)

        # Repos whose name contains "flux" load in bfloat16; all others in float16.
        load_kwargs = {
            "torch_dtype": torch.bfloat16
            if "flux" in component_repo.lower()
            else torch.float16,
        }
        if component_name:
            load_kwargs["subfolder"] = component_name
        if variant:
            load_kwargs["variant"] = variant

        # from_pretrained does blocking I/O; run it off the event loop.
        return await asyncio.to_thread(
            model_class.from_pretrained, component_repo, **load_kwargs
        )

    except Exception as e:
        logger.error(
            "Error loading component {} from {}: {}".format(
                component_name, component_repo, str(e)
            )
        )
        raise
|
|
1999
|
+
|
|
2000
|
+
def _load_custom_component(
    self, repo_id: str, category: str, component_name: str
) -> Any:
    """
    Build a custom (non-diffusers) component from a cached weights file.

    The state dict is read from the path given by
    ``_get_component_file_path(repo_id)``. The architecture class is looked
    up in ``get_architectures()`` under the key
    ``core_extension_1.<category>_<component_name>`` (both lower-cased). That
    class is instantiated, loaded with the state dict, and its ``.model`` is
    returned.

    Args:
        repo_id: Repository identifier (including the weights filename)
        category: Component category
        component_name: Name of the component

    Returns:
        The loaded model object

    Raises:
        ValueError: If no architecture is registered for the key.
        Any error is logged and re-raised.
    """
    try:
        weights_path = self._get_component_file_path(repo_id)
        weights = load_state_dict_from_file(weights_path)

        registry = get_architectures()
        lookup_key = f"core_extension_1.{category.lower()}_{component_name.lower()}"
        arch_cls = registry.get(lookup_key)
        if not arch_cls:
            raise ValueError(f"Architecture not found for {lookup_key}")

        arch = arch_cls()
        arch.load(weights)
        return arch.model
    except Exception as e:
        logger.error("Error loading custom component: {}".format(str(e)))
        raise
|
|
2034
|
+
|
|
2035
|
+
def _get_component_file_path(self, repo_id: str) -> str:
    """
    Resolve the on-disk path of a single-file component.

    ``repo_id`` is ``[hf:]<repo path>/<weights file>``. The cache folder named
    by ``repo_folder_name`` for ``<repo path>`` is tried first. If it exists,
    the file is returned under ``snapshots/<commit>``, where ``<commit>`` is
    read by ``_get_commit_hash`` (empty string when unavailable). Otherwise a
    plain ``<cache_dir>/<repo path>`` folder is used and the file is returned
    directly inside it.

    Args:
        repo_id: Repository identifier including the weights filename

    Returns:
        Path to component file

    Raises:
        FileNotFoundError: If neither cache folder exists.
    """
    repo_path = os.path.dirname(repo_id.replace("hf:", ""))
    weights_file = repo_id.split("/")[-1]

    hub_folder = os.path.join(
        self.cache_dir, repo_folder_name(repo_id=repo_path, repo_type="model")
    )
    if os.path.exists(hub_folder):
        snapshot = self._get_commit_hash(hub_folder) or ""
        return os.path.join(hub_folder, "snapshots", snapshot, weights_file)

    plain_folder = os.path.join(self.cache_dir, repo_path)
    if not os.path.exists(plain_folder):
        raise FileNotFoundError(f"Cache folder for {repo_id} not found")
    return os.path.join(plain_folder, weights_file)
|