gen-worker 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
gen_worker/torch_manager/manager.py
@@ -0,0 +1,2059 @@
1
+ import os
2
+ import gc
3
+ import logging
4
+ import importlib
5
+ from enum import Enum
6
+ import traceback
7
+ from typing import Optional, Any, Dict, List, Tuple, Union, Type, Set
8
+ import psutil
9
+ from collections import OrderedDict
10
+ import time
11
+ import sys
12
+ import threading
13
+
14
+ import torch
15
+ import diffusers
16
+ import numpy as np
17
+ import asyncio
18
+
19
+ from diffusers import (
20
+ DiffusionPipeline,
21
+ FluxInpaintPipeline,
22
+ FluxPipeline,
23
+ )
24
+
25
+ from diffusers.loaders import FromSingleFileMixin
26
+ from huggingface_hub.constants import HF_HUB_CACHE
27
+ from huggingface_hub.file_download import repo_folder_name
28
+ from huggingface_hub.utils import EntryNotFoundError
29
+ from accelerate import utils as accelerate_utils
30
+
31
+ from .utils.config import get_config, get_environment
32
+ from .utils.globals import (
33
+ # get_hf_model_manager,
34
+ get_architectures,
35
+ get_available_torch_device,
36
+ set_available_torch_device,
37
+ get_model_downloader,
38
+ )
39
+ from .utils.model_downloader import ModelSource
40
+ from .utils.load_models import load_state_dict_from_file
41
+ from .utils.base_types.config import PipelineConfig
42
+ from .utils import diffusers_fix
43
+
44
+ from ..model_interface import ModelManagementInterface, DownloaderType
45
+
46
+ from .utils.flashpack_loader import FlashPackLoader
47
+
48
+ # Configure logging
49
+ logger = logging.getLogger(__name__)
50
+ logging.basicConfig(
51
+ level=logging.INFO,
52
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
53
+ stream=sys.stdout,
54
+ )
55
+
56
+ # Constants
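+ # VRAM_SAFETY_MARGIN_GB: headroom reserved for inference activations when checking whether a model fits on the GPU.
+ # DEFAULT_MAX_VRAM_BUFFER_GB: subtracted from total VRAM when computing this manager's max_vram budget.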
57
+ VRAM_SAFETY_MARGIN_GB = 3.5
58
+ DEFAULT_MAX_VRAM_BUFFER_GB = 2.0
59
+ RAM_SAFETY_MARGIN_GB = 10.0
60
+
61
+
62
+ # Keys correspond to diffusers pipeline classes
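+ # Values are the component attribute names assembled by _load_custom_architecture when building a pipeline from a raw state dict.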
63
+ MODEL_COMPONENTS = {
64
+ "FluxPipeline": [
65
+ "vae",
66
+ "transformer",
67
+ "text_encoder",
68
+ "text_encoder_2",
69
+ "scheduler",
70
+ "tokenizer",
71
+ "tokenizer_2",
72
+ ],
73
+ "StableDiffusionXLPipeline": [
74
+ "vae",
75
+ "unet",
76
+ "text_encoder",
77
+ "text_encoder_2",
78
+ "scheduler",
79
+ "tokenizer",
80
+ "tokenizer_2",
81
+ ],
82
+ "StableDiffusionPipeline": [
83
+ "vae",
84
+ "unet",
85
+ "text_encoder",
86
+ "scheduler",
87
+ "tokenizer",
88
+ ],
89
+ }
90
+
91
+
92
+ def get_pipeline_class(
93
+ class_name: Union[str, Tuple[str, str]],
94
+ ) -> Tuple[Type[DiffusionPipeline], Optional[str]]:
95
+ """Get the appropriate pipeline class based on class_name configuration.
96
+ Learn about custom-pipelines here: https://huggingface.co/docs/diffusers/v0.6.0/en/using-diffusers/custom_pipelines
97
+ Note that from_single_file does not support custom pipelines, only from_pretrained does.
98
+
99
+ Args:
100
+ class_name: Either a string naming a diffusers class, or a tuple of (package, class)
101
+
102
+ Returns:
103
+ Tuple of (Pipeline class to use, custom_pipeline to include as a kwarg to the pipeline)
104
+ """
105
+ # class_name is in the form of [package, class]
106
+ if isinstance(class_name, tuple):
107
+ # Load from custom package
108
+ package, cls = class_name
109
+ module = importlib.import_module(package)
110
+ return (getattr(module, cls), None)
111
+
112
+ # Try loading class_name as a diffusers class
113
+ try:
114
+ pipeline_class = getattr(importlib.import_module("diffusers"), class_name)
115
+ if not issubclass(pipeline_class, DiffusionPipeline):
116
+ print("custompipeline2", class_name)
117
+ raise TypeError(f"{class_name} does not inherit from DiffusionPipeline")
118
+ return (pipeline_class, None)
119
+
120
+ except (ImportError, AttributeError):
121
+ # Assume the class name is the name of a custom pipeline
122
+ print("custompipeline1", class_name)
123
+ return (DiffusionPipeline, class_name)
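+ # Illustrative resolutions of get_pipeline_class (comments only, not part of the package; "my_pkg" is a hypothetical module):
+ #   get_pipeline_class("FluxPipeline")              -> (diffusers.FluxPipeline, None)
+ #   get_pipeline_class(("my_pkg", "MyPipeline"))    -> (my_pkg.MyPipeline, None)
+ #   get_pipeline_class("some_community_pipeline")   -> (DiffusionPipeline, "some_community_pipeline")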
124
+
125
+
126
+ class LRUCache:
127
+ """
128
+ Least Recently Used (LRU) Cache for tracking model usage.
129
+
130
+ Maintains separate tracking for GPU and CPU cached models using timestamps
131
+ to determine usage patterns and inform memory management decisions.
132
+ """
133
+
134
+ def __init__(self):
135
+ """Initialize empty GPU and CPU caches."""
136
+ self.gpu_cache: OrderedDict = OrderedDict() # model_id -> last_used_timestamp
137
+ self.cpu_cache: OrderedDict = OrderedDict() # model_id -> last_used_timestamp
138
+
139
+ def access(self, model_id: str, cache_type: str = "gpu") -> None:
140
+ """
141
+ Record access of a model, updating its position in the LRU cache.
142
+
143
+ Args:
144
+ model_id: Unique identifier for the model
145
+ cache_type: Type of cache to update ("gpu" or "cpu")
146
+ """
147
+ cache = self.gpu_cache if cache_type == "gpu" else self.cpu_cache
148
+ cache.pop(model_id, None) # Remove if exists
149
+ cache[model_id] = time.time() # Add to end (most recently used)
150
+
151
+ def remove(self, model_id: str, cache_type: str = "gpu") -> None:
152
+ """
153
+ Remove a model from cache tracking.
154
+
155
+ Args:
156
+ model_id: Unique identifier for the model to remove
157
+ cache_type: Type of cache to remove from ("gpu" or "cpu")
158
+ """
159
+ cache = self.gpu_cache if cache_type == "gpu" else self.cpu_cache
160
+ cache.pop(model_id, None)
161
+
162
+ def get_lru_models(
163
+ self, cache_type: str = "gpu", count: Optional[int] = None
164
+ ) -> List[str]:
165
+ """
166
+ Get least recently used models.
167
+
168
+ Args:
169
+ cache_type: Type of cache to query ("gpu" or "cpu")
170
+ count: Optional number of models to return. If None, returns all models
171
+
172
+ Returns:
173
+ List of model IDs ordered by least recently used first
174
+ """
175
+ cache = self.gpu_cache if cache_type == "gpu" else self.cpu_cache
176
+ model_list = list(cache.keys())
177
+ return model_list if count is None else model_list[:count]
178
+
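+ # Minimal usage sketch for the LRUCache above (comments only, not part of the package):
+ #   cache = LRUCache()
+ #   cache.access("model-a", "gpu")
+ #   cache.access("model-b", "gpu")
+ #   cache.get_lru_models("gpu", count=1)  # -> ["model-a"], the least recently used model
+ #   cache.remove("model-a", "gpu")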
179
+
180
+ class DefaultModelManager:
181
+ """
182
+ Manages loading, unloading and memory allocation of machine learning models.
183
+
184
+ Handles dynamic movement of models between GPU and CPU memory based on
185
+ available resources and usage patterns. Implements optimization strategies
186
+ for efficient memory usage and model performance.
187
+ """
188
+
189
+ def __init__(self):
190
+ """Initialize the model memory manager with empty states and default configurations."""
191
+ # Model tracking
192
+ self.current_model: Optional[str] = None
193
+ self.base_pipelines: Dict[str, DiffusionPipeline] = {}
194
+ self.loaded_models: Dict[str, DiffusionPipeline] = {}
195
+ self.cpu_models: Dict[str, DiffusionPipeline] = {}
196
+ self.model_sizes: Dict[str, float] = {}
197
+ self.model_types: Dict[str, torch.dtype] = {}
198
+ self.loaded_model: Optional[DiffusionPipeline] = None
199
+
200
+ self.pipeline_locks: Dict[str, threading.Lock] = {}
201
+
202
+ # Memory tracking
203
+ self.vram_usage: float = 0
204
+ self.ram_usage: float = 0
205
+ self.max_vram: float = self._get_total_vram() - DEFAULT_MAX_VRAM_BUFFER_GB
206
+ self.system_ram: float = psutil.virtual_memory().total / (1024**3)
207
+
208
+ # State flags
209
+ self.is_in_device: bool = False
210
+ self.is_startup_load: bool = False
211
+ self.environment: str = os.getenv("ENVIRONMENT", "dev")
212
+
213
+ # Managers and caches
214
+ self.model_downloader = get_model_downloader()
215
+ self.cache_dir = HF_HUB_CACHE
216
+ self.lru_cache = LRUCache()
217
+ self.flashpack_loader = FlashPackLoader()
218
+
219
+ self.allowed_model_ids: Optional[Set[str]] = None
220
+ self.supports_all_known_models: bool = False # Default to explicit list
221
+ self._download_lock = asyncio.Lock() # Lock for ensuring only one prioritized download happens at a time
222
+
223
+ device = "cuda" if torch.cuda.is_available() else "cpu"
224
+ set_available_torch_device(device)
225
+
226
+
227
+ async def process_supported_models_config(
228
+ self, supported_model_ids: List[str], downloader_instance: Optional[DownloaderType] = None # Can ignore if using self.model_downloader
229
+ ) -> None:
230
+ logger.info(f"DMM: Processing scheduler-provided supported_model_ids: {supported_model_ids}")
231
+ if len(supported_model_ids) == 1 and supported_model_ids[0] == "*": # "*" means all
232
+ self.supports_all_known_models = True
233
+ self.allowed_model_ids = None
234
+ logger.info("DMM: Configured to support ALL known models (on-demand) due to '*' from scheduler.")
235
+ elif not supported_model_ids: # Empty list means no specific dynamic models for this worker config
236
+ self.supports_all_known_models = False
237
+ self.allowed_model_ids = set()
238
+ logger.info("DMM: Configured with an empty supported model list from scheduler (supports no dynamic models).")
239
+ else: # Specific list
240
+ self.supports_all_known_models = False
241
+ self.allowed_model_ids = set(supported_model_ids)
242
+ logger.info(f"DMM: Configured to support specific models: {self.allowed_model_ids}")
243
+
244
+ async def load_model_into_vram(self, model_id: str) -> bool:
245
+ logger.info(f"DefaultModelManager: Request to load '{model_id}' into VRAM.")
246
+ pipeline = await self.load(model_id)
247
+ if pipeline:
248
+ self.base_pipelines[model_id] = pipeline
249
+ self.loaded_model = pipeline
250
+ self.current_model = model_id
251
+ self.loaded_models[model_id] = pipeline # this is for backward compatibility (will be removed in the future)
252
+ return True
253
+ return False
254
+
255
+ async def get_active_pipeline(self, model_id: str) -> Optional[Any]:
256
+ logger.info(f"DefaultModelManager: Request to get active pipeline for '{model_id}'.")
257
+ base_pipeline = await self.load(model_id)
258
+ if not base_pipeline:
259
+ logger.error(f"Failed to load base pipeline for {model_id}")
260
+ return None
261
+
262
+ # Create a task-specific pipeline instance with fresh scheduler state
263
+ task_pipeline = self.get_pipeline_for_task(model_id, "worker_task")
264
+ if not task_pipeline:
265
+ logger.error(f"Failed to create thread-safe pipeline for {model_id}")
266
+ return None
267
+
268
+ return task_pipeline
269
+
270
+ def get_vram_loaded_models(self) -> List[str]:
271
+ # This should return a list of model_ids that are currently in VRAM
272
+ return list(self.loaded_models.keys())
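+ # Illustrative call sequence for this manager (comments only, not part of the package;
+ # assumes an asyncio event loop and a populated pipeline_defs config; the model id is hypothetical):
+ #   manager = DefaultModelManager()
+ #   await manager.process_supported_models_config(["*"])      # "*" allows all known models
+ #   ok = await manager.load_model_into_vram("some-model-id")  # returns True on success
+ #   pipe = await manager.get_active_pipeline("some-model-id") # DiffusionPipeline or None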
273
+
274
+ # def get_pipeline_for_task(self, model_id: str, run_id: str = None) -> Optional[DiffusionPipeline]:
275
+ # """
276
+ # Create a task-specific pipeline instance with fresh scheduler state.
277
+ # This prevents thread safety issues in concurrent executions.
278
+
279
+ # This is the KEY FIX for the IndexError: index 29 is out of bounds issue.
280
+ # """
281
+ # try:
282
+ # # Get the base pipeline (shared components)
283
+ # base_pipeline = self.get_base_pipeline(model_id)
284
+ # if not base_pipeline:
285
+ # logger.error(f"No base pipeline found for model {model_id}")
286
+ # return None
287
+
288
+ # # Create a fresh scheduler for this task to prevent state corruption
289
+ # fresh_scheduler = base_pipeline.scheduler.from_config(
290
+ # base_pipeline.scheduler.config
291
+ # )
292
+
293
+ # # Create task-specific pipeline with shared components but fresh scheduler
294
+ # task_pipeline = base_pipeline.__class__.from_pipe(
295
+ # base_pipeline,
296
+ # scheduler=fresh_scheduler
297
+ # )
298
+
299
+ # logger.info(f"[ModelManager] Created thread-safe pipeline for task {run_id} model {model_id}")
300
+ # return task_pipeline
301
+
302
+ # except Exception as e:
303
+ # logger.error(f"Error creating task-specific pipeline for {model_id}: {e}")
304
+ # return None
305
+
306
+ def get_pipeline_for_task(self, model_id: str, run_id: Optional[str] = None) -> Optional[DiffusionPipeline]:
307
+ """
308
+ Return the base pipeline directly - no recreation needed.
309
+ Scheduler state issues are rare with proper usage.
310
+ """
311
+ base_pipeline = self.get_base_pipeline(model_id)
312
+ if not base_pipeline:
313
+ logger.error(f"No base pipeline found for model {model_id}")
314
+ return None
315
+
316
+ # Return the base pipeline directly - it's thread-safe enough for inference
317
+ logger.info(f"[ModelManager] Returning base pipeline for model {model_id}")
318
+ return base_pipeline
319
+
320
+ def get_base_pipeline(self, model_id: str) -> Optional[DiffusionPipeline]:
321
+ """
322
+ Get the base pipeline for component sharing.
323
+ This contains the loaded model components but should NOT be used directly for inference.
324
+ """
325
+ # Check base pipelines first
326
+ if model_id in self.base_pipelines:
327
+ return self.base_pipelines[model_id]
328
+
329
+ # Check loaded models (for backward compatibility)
330
+ if model_id in self.loaded_models:
331
+ return self.loaded_models[model_id]
332
+
333
+ # Check CPU models
334
+ if model_id in self.cpu_models:
335
+ return self.cpu_models[model_id]
336
+
337
+ return None
338
+
339
+
340
+ # If the scheduler is not set in `pipeline_defs`, then we'll rely on diffusers to pick a default
341
+ # scheduler.
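+ # Hypothetical shape of a pipeline_defs entry consumed here (keys taken from the lookups below):
+ #   "my-model": {"components": {"scheduler": {"class_name": "EulerDiscreteScheduler",
+ #                                             "kwargs": {"timestep_spacing": "trailing"}}}}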
342
+ def _setup_scheduler(self, pipeline: DiffusionPipeline, model_id: str) -> None:
343
+ """Setup scheduler from component config"""
344
+ config = get_config()
345
+ model_config = config.pipeline_defs.get(model_id)
346
+ if not model_config:
347
+ logger.error(f"Model {model_id} not found in configuration")
348
+ return
349
+ components = model_config.get("components", {})
350
+ scheduler_config = components.get("scheduler", {}) if components else {}
351
+ if not scheduler_config:
352
+ return
353
+
354
+ scheduler_class = scheduler_config.get("class_name")
355
+ scheduler_kwargs = scheduler_config.get("kwargs", {})
356
+
357
+ try:
358
+ INVALID_SCHEDULER_CONFIG_KEYS = {
359
+ 'do_classifier_free_guidance',
360
+ 'guidance_scale',
361
+ }
362
+
363
+ # Only pass kwargs that are valid for scheduler initialization
364
+ filtered_kwargs = {
365
+ k: v for k, v in scheduler_kwargs.items()
366
+ if k.lower() not in {key.lower() for key in INVALID_SCHEDULER_CONFIG_KEYS}
367
+ }
368
+
369
+ if len(scheduler_kwargs) != len(filtered_kwargs):
370
+ removed = set(scheduler_kwargs.keys()) - set(filtered_kwargs.keys())
371
+ logger.warning(f"Removed invalid scheduler config keys for {model_id}: {removed}")
372
+
373
+ new_scheduler = getattr(diffusers, scheduler_class).from_config(
374
+ pipeline.scheduler.config, **filtered_kwargs
375
+ )
376
+
377
+ pipeline.scheduler = new_scheduler
378
+
379
+ print(f"Successfully set scheduler to {scheduler_class}")
380
+ except Exception as e:
381
+ logger.error(f"Error setting scheduler for {model_id}: {e}")
382
+
383
+ def _get_memory_info(self) -> Tuple[float, float]:
384
+ """
385
+ Get current memory availability.
386
+
387
+ Returns:
388
+ Tuple containing:
389
+ - Available RAM in GB
390
+ - Available VRAM in GB
391
+ """
392
+ ram = self._get_available_ram()
393
+ vram = self._get_available_vram()
394
+ return ram, vram
395
+
396
+ def _get_available_ram(self) -> float:
397
+ """
398
+ Get available system RAM in GB.
399
+
400
+ Returns:
401
+ Available RAM in GB
402
+ """
403
+ return psutil.virtual_memory().available / (1024**3)
404
+
405
+
406
+ def _get_total_vram(self) -> float:
407
+ """
408
+ Get total VRAM available on the system.
409
+
410
+ Returns:
411
+ Total VRAM in GB, or 0 if no CUDA device is available
412
+ """
413
+ if torch.cuda.is_available():
414
+ return torch.cuda.get_device_properties(0).total_memory / (1024**3)
415
+ return 0
416
+
417
+ def _get_available_vram(self) -> float:
418
+ """
419
+ Get currently available VRAM.
420
+
421
+ Returns:
422
+ Available VRAM in GB
423
+ """
424
+ if torch.cuda.is_available():
425
+ available_vram_gb = (
426
+ torch.cuda.get_device_properties(0).total_memory
427
+ - torch.cuda.memory_allocated()
428
+ ) / (1024**3)
429
+ logger.debug(f"Available VRAM: {available_vram_gb:.2f} GB")
430
+ return available_vram_gb
431
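+ # No CUDA device: report a large placeholder value so VRAM-fit checks do not block loading.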
+ return 100
432
+
433
+ def _need_optimization(self, model_size: float) -> bool:
434
+ """
435
+ Check whether a model that cannot fit in VRAM can still be handled via optimization (CPU offload), i.e. whether it fits in system RAM.
436
+
437
+ Args:
438
+ model_size: Size of the model in GB
439
+
440
+ Returns:
441
+ True if the model fits in system RAM (so optimization is viable), False otherwise
442
+ """
443
+ print(f"model_size: {model_size} GB")
444
+ print(f"system_ram: {self.system_ram} GB")
445
+ return model_size < self.system_ram
446
+
447
+ def _can_fit_gpu(self, model_size: float) -> bool:
448
+ """
449
+ Check if a model can fit in GPU memory.
450
+
451
+ Args:
452
+ model_size: Size of the model in GB
453
+
454
+ Returns:
455
+ True if the model can fit in GPU memory, False otherwise
456
+ """
457
+ print(f"model_size: {model_size} GB")
458
+ print(f"available_inference_memory: {self._get_available_vram() - VRAM_SAFETY_MARGIN_GB} GB")
459
+ return (model_size <= (self._get_available_vram() - VRAM_SAFETY_MARGIN_GB))
460
+
461
+
462
+ async def _ensure_model_files_on_disk(self, model_id: str, model_config_dict: Dict[str, Any]) -> bool:
463
+ """
464
+ Internal helper: Ensures a specific model and its components are downloaded.
465
+ This is called by self.load() before attempting to load from disk.
466
+ This method is now the primary point for on-demand downloads.
467
+ Returns True if all necessary files are confirmed/downloaded, False otherwise.
468
+ """
469
+ logger.info(f"DMM._ensure_model_files_on_disk: Checking/downloading '{model_id}'...")
470
+ all_parts_ok = True
471
+ async with self.model_downloader as downloader: # Use context manager for session
472
+ # 1. Main model source
473
+ main_source_str = model_config_dict.get("source")
474
+ if main_source_str:
475
+ is_main_downloaded, _ = await downloader.is_downloaded(model_id, model_config=model_config_dict)
476
+ if not is_main_downloaded:
477
+ logger.info(f"DMM: Main files for '{model_id}' (source: {main_source_str}) not on disk. Downloading...")
478
+ try:
479
+ await downloader.download_model(model_id, ModelSource(main_source_str))
480
+ is_main_downloaded, _ = await downloader.is_downloaded(model_id, model_config=model_config_dict) # Verify
481
+ if not is_main_downloaded:
482
+ logger.error(f"DMM: Download verification FAILED for main files of '{model_id}'.")
483
+ all_parts_ok = False
484
+ else:
485
+ logger.info(f"DMM: Main files for '{model_id}' downloaded and verified.")
486
+ except Exception as e:
487
+ logger.exception(f"DMM: Exception during download of main files for '{model_id}': {e}")
488
+ all_parts_ok = False
489
+ else:
490
+ logger.info(f"DMM: Main files for '{model_id}' already on disk.")
491
+ else:
492
+ logger.error(f"DMM: No source found for model '{model_id}' in its config. Cannot download.")
493
+ return False # Cannot proceed without a source
494
+
495
+ # 2. Components (if main download was okay or not needed, and components exist)
496
+ if all_parts_ok and model_config_dict.get("components") and isinstance(model_config_dict["components"], dict):
497
+ for comp_name, comp_details in model_config_dict["components"].items():
498
+ comp_source_str = None
499
+ if isinstance(comp_details, dict): comp_source_str = comp_details.get("source")
500
+ elif isinstance(comp_details, str): comp_source_str = comp_details
501
+
502
+ if comp_source_str:
503
+ comp_dl_id = f"{model_id}::{comp_name}"
504
+ # TODO: Need a robust way to check if component is downloaded via downloader.is_downloaded
505
+ # For now, we rely on downloader.download_model being idempotent.
506
+ logger.info(f"DMM: Ensuring component '{comp_name}' for '{model_id}' (source: {comp_source_str}) is downloaded...")
507
+ try:
508
+ await downloader.download_model(comp_dl_id, ModelSource(comp_source_str))
509
+ # Add verification for component if possible
510
+ logger.info(f"DMM: Component '{comp_name}' for '{model_id}' download attempt finished.")
511
+ except Exception as e:
512
+ logger.exception(f"DMM: Exception downloading component '{comp_name}' for '{model_id}': {e}")
513
+ all_parts_ok = False; break # Stop if a component fails
514
+ return all_parts_ok
515
+
516
+
517
+ async def load(
518
+ self, model_id: str, gpu: Optional[int] = None, pipe_type: Optional[str] = None
519
+ ) -> Optional[DiffusionPipeline]:
520
+ """
521
+ Load a model into memory, handling placement and optimization.
522
+
523
+ This method implements the main model loading logic, including:
524
+ - Checking if model is already loaded
525
+ - Managing memory allocation between GPU and CPU
526
+ - Applying optimizations as needed
527
+ - Handling model movement between devices
528
+
529
+ Args:
530
+ model_id: Identifier for the model to load
531
+ gpu: Optional GPU device number
532
+ pipe_type: Optional pipeline type specification
533
+
534
+ Returns:
535
+ Loaded pipeline or None if loading failed
536
+ """
537
+ logger.info(f"Processing model {model_id}")
538
+
539
+ if not self.supports_all_known_models:
540
+ if self.allowed_model_ids is None: # Should have been set if not supports_all
541
+ logger.error(f"DMM.load: Inconsistent state - worker not configured for all models, but allowed_model_ids is None. Denying load for '{model_id}'.")
542
+ return None
543
+ if model_id not in self.allowed_model_ids:
544
+ logger.error(f"DMM.load: Model '{model_id}' is NOT in the allowed list for this worker instance. Allowed: {self.allowed_model_ids}. Denying load.")
545
+ return None
546
+
547
+ # If self.supports_all_known_models is True, any model_id is allowed in principle, contingent on it being defined in the global pipeline_defs.
548
+
549
+ logger.debug(f"DMM.load: Model '{model_id}' is permitted for this worker instance.")
550
+
551
+ # Check existing loaded models
552
+ if model_id in self.loaded_models:
553
+ logger.info(f"Model {model_id} already loaded in GPU")
554
+ self.lru_cache.access(model_id, "gpu")
555
+ return self.loaded_models[model_id]
556
+
557
+ # CPU-offloaded models
558
+ if model_id in self.cpu_models:
559
+ logger.info(f"Model {model_id} is CPU-offloaded and ready")
560
+ self.lru_cache.access(model_id, "cpu")
561
+ return self.cpu_models[model_id]
562
+
563
+
564
+ # On-Demand model download
565
+ cfg = get_config()
566
+ if not cfg or not cfg.pipeline_defs:
567
+ logger.error("DefaultModelManager: Global config or pipeline_defs not loaded. Cannot download models.")
568
+ return None
569
+
570
+ model_config_dict = cfg.pipeline_defs.get(model_id)
571
+ if not model_config_dict:
572
+ logger.error(f"Model {model_id} not found in configuration")
573
+ return None
574
+
575
+ async with self._download_lock: # Ensure only one download process at a time for this model manager instance
576
+ if not await self._ensure_model_files_on_disk(model_id, model_config_dict):
577
+ logger.error(f"Failed to ensure model files for {model_id} are on disk")
578
+ return None
579
+
580
+ # Load new model
581
+ return await self._load_new_model(model_id, gpu, pipe_type)
582
+
583
+
584
+ async def _load_new_model(
585
+ self, model_id: str, gpu: Optional[int] = None, pipe_type: Optional[str] = None
586
+ ) -> Optional[DiffusionPipeline]:
587
+ """
588
+ Load a new model that isn't currently in memory.
589
+
590
+ Handles the complete loading process including:
591
+ - Configuration validation
592
+ - Memory allocation
593
+ - Model loading
594
+ - Optimization application
595
+
596
+ Args:
597
+ model_id: Identifier for the model
598
+ gpu: Optional GPU device number
599
+ pipe_type: Optional pipeline type specification (deprecated, use class_name in config)
600
+
601
+ Returns:
602
+ Loaded pipeline or None if loading failed
603
+ """
604
+ try:
605
+ config = get_config()
606
+ model_config = config.pipeline_defs.get(model_id)
607
+
608
+ if not model_config:
609
+ logger.error(f"Model {model_id} not found in configuration")
610
+ return None
611
+
612
+ # Prepare memory
613
+ estimated_size = await self._get_model_size(model_config, model_id) + 1.0 # add 1GB for safety margin (model inference overhead in memory)
614
+
615
+ print(f"estimated_size for model {model_id}: {estimated_size} GB")
616
+
617
+ # try direct gpu load
618
+ if self._can_fit_gpu(estimated_size):
619
+ pipeline = await self._load_model_by_source(model_id, model_config)
620
+ if pipeline is None:
621
+ return None
622
+
623
+ self._setup_scheduler(pipeline, model_id)
624
+ if self._move_model_to_gpu(pipeline, model_id):
625
+ self.loaded_models[model_id] = pipeline
626
+ self.model_sizes[model_id] = estimated_size
627
+ self.vram_usage += estimated_size
628
+ self.lru_cache.access(model_id, "gpu")
629
+ # Background prefetch other deployment models
630
+ if hasattr(self, 'flashpack_loader') and self.flashpack_loader.local_cache:
631
+ if self.allowed_model_ids and len(self.allowed_model_ids) > 1:
632
+ sources = {}
633
+ for mid in self.allowed_model_ids:
634
+ if mid != model_id:
635
+ model_cfg = config.pipeline_defs.get(mid)
636
+ if model_cfg:
637
+ sources[mid] = model_cfg.get("source") if isinstance(model_cfg, dict) else model_cfg.source
638
+ if sources:
639
+ logger.info(f"🔄 Starting background prefetch for {len(sources)} other models")
640
+ asyncio.create_task(
641
+ self.flashpack_loader.prefetch_deployment_models(
642
+ list(self.allowed_model_ids), sources, exclude_model_id=model_id
643
+ )
644
+ )
645
+ return pipeline
646
+ else:
647
+ logger.error(f"Failed to move {model_id} to GPU")
648
+ return None
649
+
650
+ # TODO: Change this ugly condition
651
+ if estimated_size <= self.max_vram - (VRAM_SAFETY_MARGIN_GB - DEFAULT_MAX_VRAM_BUFFER_GB) and (len(self.loaded_models) > 0 or len(self.cpu_models) > 0):
652
+ # if not enough space, try to make space
653
+ self._free_space_for_model(estimated_size)
654
+ if self._can_fit_gpu(estimated_size):
655
+ pipeline = await self._load_model_by_source(model_id, model_config)
656
+ if pipeline is None:
657
+ return None
658
+
659
+ self._setup_scheduler(pipeline, model_id)
660
+ if self._move_model_to_gpu(pipeline, model_id):
661
+ self.loaded_models[model_id] = pipeline
662
+ self.model_sizes[model_id] = estimated_size
663
+ self.vram_usage += estimated_size
664
+ self.lru_cache.access(model_id, "gpu")
665
+ # Background prefetch other deployment models
666
+ if hasattr(self, 'flashpack_loader') and self.flashpack_loader.local_cache:
667
+ if self.allowed_model_ids and len(self.allowed_model_ids) > 1:
668
+ sources = {}
669
+ for mid in self.allowed_model_ids:
670
+ if mid != model_id:
671
+ model_cfg = config.pipeline_defs.get(mid)
672
+ if model_cfg:
673
+ sources[mid] = model_cfg.get("source") if isinstance(model_cfg, dict) else model_cfg.source
674
+ if sources:
675
+ logger.info(f"🔄 Starting background prefetch for {len(sources)} other models")
676
+ asyncio.create_task(
677
+ self.flashpack_loader.prefetch_deployment_models(
678
+ list(self.allowed_model_ids), sources, exclude_model_id=model_id
679
+ )
680
+ )
681
+ return pipeline
682
+ else:
683
+ logger.error(f"Failed to move {model_id} to GPU")
684
+ return None
685
+
686
+ # if still not enough VRAM, apply optimization (CPU offload)
687
+ if self.environment != "prod":
688
+ if self._need_optimization(estimated_size):
689
+ logger.info("Unloading all models for large model loading")
690
+
691
+ for model_id_to_unload in list(self.loaded_models.keys()):
692
+ self._unload_model_for_space(model_id_to_unload, self.model_sizes[model_id_to_unload], "gpu")
693
+ for model_id_to_unload in list(self.cpu_models.keys()):
694
+ self._unload_model_for_space(model_id_to_unload, self.model_sizes[model_id_to_unload], "cpu")
695
+
696
+ pipeline = await self._load_model_by_source(model_id, model_config)
697
+ if pipeline is None:
698
+ return None
699
+
700
+ self._setup_scheduler(pipeline, model_id)
701
+ logger.info(f"Applying optimizations for {model_id}")
702
+ self.apply_optimizations(pipeline, model_id, True)
703
+
704
+ self.cpu_models[model_id] = pipeline
705
+ self.model_sizes[model_id] = estimated_size
706
+ self.lru_cache.access(model_id, "cpu")
707
+ # Background prefetch other deployment models
708
+ if hasattr(self, 'flashpack_loader') and self.flashpack_loader.local_cache:
709
+ if self.allowed_model_ids and len(self.allowed_model_ids) > 1:
710
+ sources = {}
711
+ for mid in self.allowed_model_ids:
712
+ if mid != model_id:
713
+ model_cfg = config.pipeline_defs.get(mid)
714
+ if model_cfg:
715
+ sources[mid] = model_cfg.get("source") if isinstance(model_cfg, dict) else model_cfg.source
716
+ if sources:
717
+ logger.info(f"🔄 Starting background prefetch for {len(sources)} other models")
718
+ asyncio.create_task(
719
+ self.flashpack_loader.prefetch_deployment_models(
720
+ list(self.allowed_model_ids), sources, exclude_model_id=model_id
721
+ )
722
+ )
723
+ return pipeline
724
+
725
+ logger.error(f"Insufficient memory to load model {model_id}")
726
+ return None
727
+
728
+ except Exception as e:
729
+ logger.error("Error loading new model {}: {}".format(model_id, str(e)))
730
+ traceback.print_exc()
731
+ return None
732
+
733
+ async def _load_model_by_source(
734
+ self, model_id: str, model_config: PipelineConfig) -> Optional[DiffusionPipeline]:
735
+ """
736
+ Load a model based on its source configuration.
737
+
738
+ Args:
739
+ model_id: Model identifier
740
+ model_config: Model configuration from PipelineConfig
741
742
+
743
+ Returns:
744
+ Loaded pipeline or None if loading failed
745
+ """
746
+
747
+ if isinstance(model_config, dict):
748
+ source = model_config.get("source")
749
+ class_name = model_config.get("class_name")
750
+ else:
751
+ source = model_config.source
752
+ class_name = model_config.class_name
753
+
754
+ prefix, path = source.split(":", 1)
755
+
756
+ try:
757
+ # === LOCAL CACHE + FLASHPACK ===
758
+ # Try FlashPack first
759
+ flashpack_path = self.flashpack_loader.get_flashpack_path(model_id, source)
760
+ if flashpack_path:
761
+ (pipeline_class, _) = get_pipeline_class(class_name)
762
+ pipeline = await self.flashpack_loader.load_from_flashpack(
763
+ model_id, flashpack_path, pipeline_class
764
+ )
765
+ if pipeline:
766
+ return pipeline
767
+ logger.warning(f"FlashPack loading failed for {model_id}, falling back")
768
+
769
+ # No FlashPack - copy safetensors to local cache anyway
770
+ if self.flashpack_loader.local_cache:
771
+ local_path = await self.flashpack_loader.local_cache.ensure_local(
772
+ model_id, source, priority=True
773
+ )
774
+ if local_path:
775
+ logger.info(f"📂 Loading from local cache: {local_path}")
776
+ if local_path.suffix == ".safetensors" or local_path.is_file():
777
+ path = str(local_path)
778
+ prefix = "file"
779
+ else:
780
+ # It's a directory (HF model)
781
+ path = str(local_path)
782
+ prefix = "hf"
783
+
784
+ # === STANDARD LOADING: Fallback to standard loading if FlashPack fails ===
785
+ if prefix == "hf":
786
+ is_downloaded, variant = await self.model_downloader.is_downloaded(
787
+ model_id,
788
+ model_config
789
+ )
790
+
791
+ if not is_downloaded:
792
+ logger.info(f"Model {model_id} not downloaded")
793
+ return None
794
+
795
+ # Get model index and use as fallback in case class_name is unspecified
796
+ if class_name is None or class_name == "":
797
+ model_index = await self.model_downloader.get_diffusers_multifolder_components(
798
+ path
799
+ )
800
+ if model_index and "_class_name" in model_index:
801
+ class_name = model_index["_class_name"]
802
+ else:
803
+ logger.error(f"Unknown diffusers class_name for {model_id}")
804
+ return None
805
+
806
+ return await self.load_huggingface_model(
807
+ model_id, path, class_name, variant, model_config
808
+ )
809
+ elif prefix in ["file", "ct"]:
810
+ return await self.load_single_file_model(
811
+ model_id, path, prefix, class_name
812
+ )
813
+ elif source.startswith(("http://", "https://")):
814
+ # Handle Civitai/direct download models
815
+ source_obj = ModelSource(source)
816
+ is_downloaded, _ = await self.model_downloader.is_downloaded(model_id)
817
+ if not is_downloaded:
818
+ logger.info(f"Model {model_id} not downloaded")
819
+ return None
820
+
821
+ cached_path = await self.model_downloader._get_cache_path(
822
+ model_id, source_obj
823
+ )
824
+ if not os.path.exists(cached_path):
825
+ logger.error(f"Cached model file not found at {cached_path}")
826
+ return None
827
+
828
+ return await self.load_single_file_model(
829
+ model_id, cached_path, "file", class_name
830
+ )
831
+ else:
832
+ logger.error(f"Unsupported model source prefix: {prefix}")
833
+ return None
834
+ except Exception as e:
835
+ logger.error("Error loading model from source: {}".format(str(e)))
836
+ return None
837
+
838
+ def _free_space_for_model(self, model_size: float) -> None:
839
+ available_vram = self._get_available_vram() - VRAM_SAFETY_MARGIN_GB
840
+ if available_vram >= model_size:
841
+ return
842
+
843
+ space_needed = model_size - available_vram
844
+ logger.info(f"Need to free {space_needed:.2f} GB of VRAM")
845
+
846
+ gpu_models = [
847
+ (mid, self.model_sizes[mid])
848
+ for mid in self.lru_cache.get_lru_models("gpu")
849
+ ]
850
+
851
+ freed_space = 0
852
+ for model_id_to_unload, size in gpu_models:
853
+ if freed_space >= space_needed:
854
+ break
855
+
856
+ freed_space += self._unload_model_for_space(model_id_to_unload, size, "gpu")
857
+
858
+ if self._get_available_vram() - VRAM_SAFETY_MARGIN_GB >= model_size:
859
+ break
860
+
861
+
862
+ def _unload_model_for_space(
863
+ self, model_id: str, model_size: float, device: str
864
+ ) -> float:
865
+ """
866
+ Completely unload a model and free up GPU/CPU memory.
867
+
868
+ Args:
869
+ model_id: Model identifier
870
+ model_size: Size of the model in GB
871
+ device: Device to unload from ("gpu" or "cpu")
872
+ Returns:
873
+ Amount of space freed in GB
874
+ """
875
+ logger.info(f"Unloading {model_id} from memory")
876
+
877
+ try:
878
+ if device == "gpu" and model_id in self.loaded_models:
879
+ pipeline = self.loaded_models.pop(model_id)
880
+
881
+ if hasattr(pipeline, "remove_all_hooks"):
882
+ logger.info(f"Removing all hooks for model {model_id}")
883
+ pipeline.remove_all_hooks()
884
+
885
+ # # Move model to CPU first to clear CUDA memory
886
+ # if hasattr(pipeline, "to"):
887
+ # pipeline.to("cpu", silence_dtype_warnings=True)
888
+
889
+ # Explicitly delete model components
890
+ for attr in [
891
+ "vae",
892
+ "unet",
893
+ "text_encoder",
894
+ "text_encoder_2",
895
+ "tokenizer",
896
+ "scheduler",
897
+ "transformer",
898
+ "tokenizer_2",
899
+ "text_encoder_3",
900
+ "tokenizer_3",
901
+ ]:
902
+ if hasattr(pipeline, attr) and getattr(pipeline, attr) is not None:
903
+ component = getattr(pipeline, attr)
904
+ delattr(pipeline, attr)
905
+ del component
906
+
907
+ # Delete pipeline reference
908
+ # del self.loaded_models[model_id]
909
+ self.vram_usage -= model_size
910
+ self.lru_cache.remove(model_id, "gpu")
911
+ del pipeline
912
+
913
+ if model_id in self.cpu_models:
914
+ logger.info(f"Unloading {model_id} from CPU memory")
915
+ cpu_pipeline = self.cpu_models.pop(model_id)
916
+
917
+ if hasattr(cpu_pipeline, "remove_all_hooks"):
918
+ logger.info(f"Removing all hooks for CPU-offloaded model {model_id}")
919
+ cpu_pipeline.remove_all_hooks()
920
+
921
+ self.lru_cache.remove(model_id, "cpu")
922
+ del cpu_pipeline
923
+
924
+ # Force garbage collection and memory clearing
925
+ self.flush_memory()
926
+
927
+ return model_size
928
+
929
+ except Exception as e:
930
+ logger.error(
931
+ "Error during model unloading for {}: {}".format(model_id, str(e))
932
+ )
933
+ # Still try to clean up references even if error occurs
934
+ self.loaded_models.pop(model_id, None)
935
+ self.cpu_models.pop(model_id, None)
936
+ self.vram_usage -= model_size
937
+ self.lru_cache.remove(model_id, "gpu")
938
+ self.lru_cache.remove(model_id, "cpu")
939
+ self.flush_memory()
940
+ return model_size
941
+
942
+ def _move_model_to_gpu(self, pipeline: DiffusionPipeline, model_id: str) -> bool:
943
+ """
944
+ Safely move a model to GPU memory with proper dtype handling.
945
+
946
+ Args:
947
+ pipeline: The pipeline to move
948
+ model_id: Identifier of the model
949
+
950
+ Returns:
951
+ Boolean indicating success of the operation
952
+ """
953
+ # Skip if it's OmniGen pipeline
954
+ if pipeline.__class__.__name__ == "OmniGenPipeline":
955
+ return True
956
+
957
+ try:
958
+ device = get_available_torch_device()
959
+ print("Moving to device", device)
960
+
961
+ # Check if the pipeline supports VAE tiling and slicing
962
+ if hasattr(pipeline, "enable_vae_tiling"):
963
+ logger.info("Enabling vae tiling")
964
+ pipeline.enable_vae_tiling()
965
+ if hasattr(pipeline, "enable_vae_slicing"):
966
+ logger.info("Enabling vae slicing")
967
+ pipeline.enable_vae_slicing()
968
+
969
+ pipeline = pipeline.to(device=device)
970
+
971
+ print("Done with moving to GPU")
972
+
973
+ self.flush_memory()
974
+ return True
975
+
976
+ except RuntimeError as e:
977
+ if "out of memory" in str(e):
978
+ logger.error(f"GPU out of memory while moving model {model_id} to GPU")
979
+ else:
980
+ logger.error(
981
+ "Runtime error moving model {} to GPU: {}".format(model_id, str(e))
982
+ )
983
+ self.flush_memory()
984
+ return False
985
+
986
+ async def load_huggingface_model(
987
+ self,
988
+ model_id: str,
989
+ repo_id: str,
990
+ class_name: Union[str, Tuple[str, str]],
991
+ variant: Optional[str] = None,
992
+ model_config: Optional[PipelineConfig] = None,
993
+ ) -> Optional[DiffusionPipeline]:
994
+ """
995
+ Load a model from HuggingFace.
996
+
997
+ Args:
998
+ model_id: Model identifier
999
+ repo_id: HuggingFace repository ID
1000
1001
+ class_name: Pipeline class name or (package, class) tuple
1002
+ variant: Optional model variant
1003
+ model_config: Optional model configuration
1004
+
1005
+ Returns:
1006
+ Loaded pipeline or None if loading failed
1007
+ """
1008
+ try:
1009
+ pipeline_kwargs = await self._prepare_pipeline_kwargs(model_config, variant)
1010
+ variant = None if variant == "" else variant
1011
+ # TO DO: make this more robust
1012
+ torch_dtype = (
1013
+ torch.bfloat16 if "flux" in model_id.lower() else torch.float16
1014
+ )
1015
+
1016
+ # Get appropriate pipeline class
1017
+ (pipeline_class, custom_pipeline) = get_pipeline_class(class_name)
1018
+ print("custompipeline", custom_pipeline)
1019
+
1020
+ if custom_pipeline is not None:
1021
+ pipeline_kwargs["custom_pipeline"] = custom_pipeline
1022
+
1023
+ print(
1024
+ f"repo_id={repo_id},torch_dtype={torch_dtype},local_files_only=True,variant={variant},pipeline_kwargs={pipeline_kwargs},"
1025
+ )
1026
+ try:
1027
+ pipeline = await asyncio.to_thread(pipeline_class.from_pretrained,
1028
+ repo_id,
1029
+ torch_dtype=torch_dtype,
1030
+ local_files_only=True,
1031
+ variant=variant,
1032
+ **pipeline_kwargs,
1033
+ )
1034
+ except EntryNotFoundError as e:
1035
+ print(f"Custom pipeline '{custom_pipeline}' not found: {e}")
1036
+ print("Falling back to the default pipeline...")
1037
+ del pipeline_kwargs["custom_pipeline"]
1038
+ pipeline = await asyncio.to_thread(pipeline_class.from_pretrained,
1039
+ repo_id,
1040
+ variant=variant,
1041
+ torch_dtype=torch_dtype,
1042
+ **pipeline_kwargs,
1043
+ )
1044
+
1045
+ self.flush_memory()
1046
+ self.loaded_model = pipeline
1047
+ self.current_model = model_id
1048
+ logger.info(f"Model {model_id} loaded successfully")
1049
+ return pipeline
1050
+
1051
+ except Exception as e:
1052
+ traceback.print_exc()
1053
+ logger.error("Failed to load model {}: {}".format(model_id, str(e)))
1054
+ return None
1055
+
1056
+ async def _prepare_pipeline_kwargs(
1057
+ self, model_config: Optional[PipelineConfig], variant: Optional[str] = None
1058
+ ) -> Dict[str, Any]:
1059
+ """
1060
+ Prepare kwargs for pipeline initialization.
1061
+
1062
+ Args:
1063
+ model_config: Model configuration from PipelineConfig
1064
+
1065
+ Returns:
1066
+ Dictionary of pipeline initialization arguments
1067
+ """
1068
+ if not model_config:
1069
+ return {}
1070
+
1071
+ pipeline_kwargs = {}
1072
+ try:
1073
+ if isinstance(model_config, dict):
1074
+ class_name = model_config.get("class_name")
1075
+ components = model_config.get("components")
1076
+ custom_pipeline = model_config.get("custom_pipeline")
1077
+ else:
1078
+ class_name = model_config.class_name
1079
+ components = model_config.components
+ custom_pipeline = getattr(model_config, "custom_pipeline", None)
1080
+
1081
+ if custom_pipeline:
1082
+ pipeline_kwargs["custom_pipeline"] = custom_pipeline
1083
+
1084
+ if components:
1085
+ for key, component in components.items():
1086
+ main_model_source = model_config.get("source")
1087
+ # Check if component and also id component has .source attribute
1088
+ if component:
1089
+ component_source = component.get("source", None)
1090
+ if component_source:
1091
+ pipeline_kwargs[key] = await self._prepare_component(
1092
+ main_model_source, component, class_name, key, variant
1093
+ )
1094
+
1095
+ return pipeline_kwargs
1096
+
1097
+ except Exception as e:
1098
+ logger.error("Error preparing pipeline kwargs: {}".format(str(e)))
1099
+ return {}
1100
+
1101
+ async def _prepare_component(
1102
+ self,
1103
+ main_model_source: str,
1104
+ component: PipelineConfig,
1105
+ model_class_name: Optional[Union[str, Tuple[str, str]]],
1106
+ key: str,
1107
+ variant: Optional[str] = None,
1108
+ ) -> Any:
1109
+ """
1110
+ Prepare a model component based on its configuration.
1111
+
1112
+ Args:
1113
+ component: Component configuration
1114
+ model_class_name: Class name of the parent model
1115
+ key: Component key
1116
+
1117
+ Returns:
1118
+ Loaded component or None if loading failed
1119
+ """
1120
+ try:
1121
+ if isinstance(component, dict):
1122
+ source = component.get("source")
1123
+ else:
1124
+ source = component.source
1125
+
1126
+ if not source.endswith((".safetensors", ".bin", ".ckpt", ".pt")):
1127
+ # If the source path contains more than one forward slash, the last segment is the subfolder and the rest is the repo id,
1128
+ # e.g. hf:cozy-creator/Flux.1-schnell-8bit/transformer -> repo_id = hf:cozy-creator/Flux.1-schnell-8bit, subfolder = transformer
1129
+ if source.count("/") > 1:
1130
+ repo_id = "/".join(source.split("/")[:-1])
1131
+ subfolder = source.split("/")[-1]
1132
+ return await self._load_diffusers_component(
1133
+ main_model_source.replace("hf:", ""),
1134
+ repo_id.replace("hf:", ""),
1135
+ subfolder,
1136
+ variant,
1137
+ )
1138
+ else:
1139
+ return await self._load_diffusers_component(
1140
+ main_model_source.replace("hf:", ""),
1141
+ source.replace("hf:", ""),
1142
+ variant,
1143
+ )
1144
+ else:
1145
+ return self._load_custom_component(source, model_class_name, key)
1146
+ except Exception as e:
1147
+ logger.error("Error preparing component {}: {}".format(key, str(e)))
1148
+ return None
1149
+
1150
+ async def load_single_file_model(
1151
+ self,
1152
+ model_id: str,
1153
+ path: str,
1154
+ prefix: str,
1155
+ class_name: Optional[Union[str, Tuple[str, str]]] = None,
1156
+ ) -> Optional[DiffusionPipeline]:
1157
+ """
1158
+ Load a model from a single file.
1159
+
1160
+ Args:
1161
+ model_id: Model identifier
1162
+ path: Path to model file
1163
+ prefix: Source prefix (file/ct)
1164
1165
+ class_name: Pipeline class name or (package, class) tuple
1166
+
1167
+ Returns:
1168
+ Loaded pipeline or None if loading failed
1169
+ """
1170
+ logger.info(f"Loading single file model {model_id}")
1171
+
1172
+ # TO DO: we could try inferring the class using our old detect_model code here!
1173
+ if class_name is None or class_name == "":
1174
+ logger.error("Model class_name must be specified for single file models")
1175
+ return None
1176
+
1177
+ (pipeline_class, custom_pipeline) = get_pipeline_class(class_name)
1178
+
1179
+ try:
1180
+ print(f"Model path: {path}")
1181
+ if prefix != "file":
1182
+ model_path = self._get_model_path(path, prefix)
1183
+ if not os.path.exists(model_path):
1184
+ logger.error(f"Model file not found: {model_path}")
1185
+ return None
1186
+ else:
1187
+ model_path = path
1188
+
1189
+ pipeline = await self._load_pipeline_from_file(
1190
+ pipeline_class, model_path, model_id, class_name
1191
+ )
1192
+ if pipeline:
1193
+ self.loaded_model = pipeline
1194
+ self.current_model = model_id
1195
+ return pipeline
1196
+
1197
+ except Exception as e:
1198
+ logger.error("Error loading single file model: {}".format(str(e)))
1199
+ return None
1200
+
1201
+ def _get_model_path(self, path: str, prefix: str) -> str:
1202
+ if prefix == "ct" and ("http" in path or "https" in path):
1203
+ path = path.split("/")[-1]
1204
+ return os.path.join(get_config().models_path, path)
1205
+
1206
+ async def _load_pipeline_from_file(
1207
+ self,
1208
+ pipeline_class: Type[DiffusionPipeline],
1209
+ model_path: str,
1210
+ model_id: str,
1211
+ class_name: Optional[Union[str, Tuple[str, str]]],
1212
+ ) -> Optional[DiffusionPipeline]:
1213
+ """
1214
+ Load a pipeline from a file using appropriate loading method.
1215
+
1216
+ Args:
1217
+ pipeline_class: Class to instantiate pipeline
1218
+ model_path: Path to model file
1219
+ model_id: Model identifier
1220
+ class_name: Pipeline class name or (package, class) tuple
1221
+
1222
+ Returns:
1223
+ Loaded pipeline or None if loading failed
1224
+ """
1225
+ try:
1226
+ if issubclass(pipeline_class, FromSingleFileMixin):
1227
+ return self._load_from_single_file(pipeline_class, model_path, model_id)
1228
+ else:
1229
+ # Get the actual class name string if it's a tuple
1230
+ class_str = (
1231
+ class_name[1] if isinstance(class_name, tuple) else class_name
1232
+ )
1233
+ return self._load_custom_architecture(
1234
+ pipeline_class, model_path, class_str
1235
+ )
1236
+ except Exception as e:
1237
+ logger.error("Error loading pipeline from file: {}".format(str(e)))
1238
+ return None
1239
+
1240
+ def _load_from_single_file(
1241
+ self, pipeline_class: Any, path: str, model_id: str
1242
+ ) -> DiffusionPipeline:
1243
+ """
1244
+ Load a model pipeline from a single file using the FromSingleFileMixin.
1245
+
1246
+ Uses different torch datatypes based on model type:
1247
+ - bfloat16 for Flux models
1248
+ - float16 for other models
1249
+
1250
+ Args:
1251
+ pipeline_class: The pipeline class to instantiate
1252
+ path: Path to the model file
1253
+ model_id: Model identifier (used to determine model type)
1254
+
1255
+ Returns:
1256
+ Loaded pipeline instance
1257
+
1258
+ Raises:
1259
+ Exception: If loading fails
1260
+ """
1261
+ try:
1262
+ # Determine appropriate dtype based on model type
1263
+ torch_dtype = torch.bfloat16 if "flux" in model_id.lower() else torch.float16
1264
+
1265
+ # Load the model with appropriate dtype
1266
+ pipeline = pipeline_class.from_single_file(path, torch_dtype=torch_dtype)
1267
+
1268
+ logger.info(f"Successfully loaded single file model {model_id}")
1269
+ return pipeline
1270
+
1271
+ except Exception as e:
1272
+ logger.error("Error loading model from single file: {}".format(str(e)))
1273
+ raise
1274
+
1275
+ def _load_custom_architecture(
1276
+ self, pipeline_class: Any, path: str, type: str
1277
+ ) -> DiffusionPipeline:
1278
+ """
1279
+ Load a model with custom architecture configuration.
1280
+
1281
+ Handles loading of individual components and assembling them into
1282
+ a complete pipeline.
1283
+
1284
+ Args:
1285
+ pipeline_class: The pipeline class to instantiate
1286
+ path: Path to the model file
1287
+ type: Model type (determines which components to load)
1288
+
1289
+ Returns:
1290
+ Assembled pipeline instance
1291
+
1292
+ Raises:
1293
+ Exception: If loading fails or architecture is not found
1294
+ """
1295
+ try:
1296
+ # Load the complete state dict
1297
+ state_dict = load_state_dict_from_file(path)
1298
+
1299
+ # Create empty pipeline instance
1300
+ pipeline = pipeline_class()
1301
+
1302
+ # Load each component specified for this model type
1303
+ for component_name in MODEL_COMPONENTS[type]:
1304
+ # Skip certain components that don't need loading
1305
+ if component_name in ["scheduler", "tokenizer", "tokenizer_2"]:
1306
+ continue
1307
+
1308
+ # Construct architecture key and get class
1309
+ arch_key = f"core_extension_1.{type}_{component_name}"
1310
+ architecture_class = get_architectures().get(arch_key)
1311
+
1312
+ if not architecture_class:
1313
+ logger.error(f"Architecture not found for {arch_key}")
1314
+ continue
1315
+
1316
+ try:
1317
+ # Initialize and load the component
1318
+ architecture = architecture_class()
1319
+ architecture.load(state_dict)
1320
+
1321
+ # Set the component in the pipeline
1322
+ setattr(pipeline, component_name, architecture.model)
1323
+ logger.debug(f"Loaded component {component_name} for {type}")
1324
+
1325
+ except Exception as component_error:
1326
+ logger.error(
1327
+ "Error loading component {}: {}".format(
1328
+ component_name, str(component_error)
1329
+ )
1330
+ )
1331
+ continue
1332
+
1333
+ logger.info("Successfully loaded custom architecture model")
1334
+ return pipeline
1335
+
1336
+ except Exception as e:
1337
+ logger.error("Error loading custom architecture: {}".format(str(e)))
1338
+ raise
1339
+
1340
+ def apply_optimizations(
1341
+ self,
1342
+ pipeline: DiffusionPipeline,
1343
+ model_id: str,
1344
+ force_full_optimization: bool = False,
1345
+ ) -> None:
1346
+ """
1347
+ Apply memory and performance optimizations to a pipeline.
1348
+ Uses existing optimization functions when model exceeds total VRAM.
1349
+
1350
+ Args:
1351
+ pipeline: The pipeline to optimize
1352
+ model_id: Model identifier
1353
+ force_full_optimization: Whether to force optimization
1354
+ """
1355
+ # Skip if it's OmniGen pipeline
1356
+ if pipeline.__class__.__name__ == "OmniGenPipeline":
1357
+ return
1358
+
1359
+ if self.loaded_model is not None and self.is_in_device:
1360
+ logger.info(f"Model {model_id} already optimized")
1361
+ return
1362
+
1363
+ device = get_available_torch_device()
1364
+
1365
+ # Only optimize if model is bigger than total VRAM
1366
+ model_size = self.model_sizes.get(model_id, 0)
1367
+ if model_size > self.max_vram or force_full_optimization:
1368
+ logger.info(f"Applying optimizations for {model_id}")
1369
+
1370
+ # Get list of optimizations
1371
+ optimizations = self._get_full_optimizations()
1372
+
1373
+ # Apply the optimizations
1374
+ self._apply_optimization_list(pipeline, optimizations, device)
1375
+ else:
1376
+ logger.info(f"No optimization needed for {model_id}")
1377
+
1378
+ def _get_full_optimizations(self) -> List[Tuple[str, str, Dict[str, Any]]]:
1379
+ """
1380
+ Get list of all available optimizations.
1381
+
1382
+ Returns:
1383
+ List of tuples containing (optimization_function, name, parameters)
1384
+ """
1385
+ optimizations = [
1386
+ ("enable_vae_slicing", "VAE Sliced", {}),
1387
+ ("enable_vae_tiling", "VAE Tiled", {}),
1388
+ (
1389
+ "enable_model_cpu_offload",
1390
+ "CPU Offloading",
1391
+ {"device": get_available_torch_device()},
1392
+ ),
1393
+ ]
1394
+
1395
+ if not isinstance(self.loaded_model, (FluxPipeline, FluxInpaintPipeline)):
1396
+ optimizations.append(
1397
+ (
1398
+ "enable_xformers_memory_efficient_attention",
1399
+ "Memory Efficient Attention",
1400
+ {},
1401
+ )
1402
+ )
1403
+
1404
+ return optimizations
1405
+
1406
+ def _apply_optimization_list(
1407
+ self,
1408
+ pipeline: DiffusionPipeline,
1409
+ optimizations: List[Tuple[str, str, Dict[str, Any]]],
1410
+ device: torch.device,
1411
+ ) -> None:
1412
+ """
1413
+ Apply a list of optimizations to a pipeline.
1414
+
1415
+ Args:
1416
+ pipeline: The pipeline to optimize
1417
+ optimizations: List of optimization specifications
1418
+ device: Target device
1419
+ """
1420
+ device_type = device if isinstance(device, str) else device.type
1421
+
1422
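+ # Temporarily alias torch.mps to torch.backends.mps while the optimizations run; the alias is removed again after the loop.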
+ if device_type == "mps":
1423
+ setattr(torch, "mps", torch.backends.mps)
1424
+
1425
+ for opt_func, opt_name, kwargs in optimizations:
1426
+ try:
1427
+ getattr(pipeline, opt_func)(**kwargs)
1428
+ logger.info("{} enabled".format(opt_name))
1429
+ except Exception as e:
1430
+ logger.error("Error enabling {}: {}".format(opt_name, str(e)))
1431
+
1432
+ if device_type == "mps":
1433
+ delattr(torch, "mps")
1434
+
1435
+ def flush_memory(self) -> None:
1436
+ """Clear unused memory from GPU and perform garbage collection."""
1437
+ gc.collect()
1438
+ if torch.cuda.is_available():
1439
+ torch.cuda.empty_cache()
1440
+ torch.cuda.ipc_collect()
1441
+ torch.cuda.synchronize()
1442
+ logger.info("GPU memory flushed successfully")
1443
+ elif torch.backends.mps.is_available():
1444
+ pass # MPS doesn't need explicit cache clearing
1445
+
1446
+ async def initialize_startup_models(self, model_ids: List[str]) -> None:
1447
+ """Inititalize models at startup with randomization"""
1448
+ if not model_ids:
1449
+ logger.info("No models configured for enabled models")
1450
+ return
1451
+
1452
+ self.is_startup_load = True
1453
+ try:
1454
+ model_configs = {}
1455
+ model_sizes = {}
1456
+
1457
+ for model_id in model_ids:
1458
+ model_config = get_config().pipeline_defs.get(model_id)
1459
+ if not model_config:
1460
+ logger.warning(f"Model {model_id} not found in pipeline_defs")
1461
+ continue
1462
+
1463
+ try:
1464
+ size = await self._get_model_size(model_config, model_id) + 1.0 # add 1GB for safety margin (model inference overhead in memory)
1465
+ model_sizes[model_id] = size
1466
+ model_configs[model_id] = model_config
1467
+ except Exception as e:
1468
+ logger.error(f"Error getting model size for {model_id}: {e}")
1469
+ continue
1470
+
1471
+ # Create randomized order of models
1472
+ available_models = list(model_ids)
1473
+ # available_vram = self._get_available_vram() - VRAM_SAFETY_MARGIN_GB
1474
+
1475
+ rng = np.random.default_rng()
1476
+ random_models = rng.permutation(available_models).tolist()
1477
+
1478
+ logger.info(f"Loading models in random order: {random_models}")
1479
+ total_loaded = 0
1480
+ total_size = 0
1481
+
1482
+ for model_id in random_models:
1483
+ estimated_size = model_sizes[model_id]
1484
+
1485
+ # Check if exceed available VRAM
1486
+ if estimated_size > (
1487
+ self._get_available_vram() - VRAM_SAFETY_MARGIN_GB
1488
+ ):
1489
+ logger.info(
1490
+ f"Stopping model loading: Next model {model_id} "
1491
+ f"({estimated_size:.2f} GB) would exceed available inference Memory "
1492
+ f"({self._get_available_vram() - VRAM_SAFETY_MARGIN_GB:.2f} GB)"
1493
+ )
1494
+ break
1495
+
1496
+ try:
1497
+ pipeline = await self.load(model_id)
1498
+ if pipeline is not None:
1499
+ total_loaded += 1
1500
+ total_size += estimated_size
1501
+ logger.info(
1502
+ f"Successfully loaded model {model_id} "
1503
+ f"({total_loaded}/{len(random_models)}). "
1504
+ f"Total VRAM used: {total_size:.2f} GB"
1505
+ )
1506
+
1507
+ # warm up the model
1508
+ logger.info(f"Warming up model {model_id}")
1509
+ await self.warmup_pipeline(model_id)
1510
+ else:
1511
+ logger.warning(f"Failed to load model {model_id}")
1512
+
1513
+ except Exception as e:
1514
+ logger.error(f"Error loading model {model_id}: {e}")
1515
+
1516
+ logger.info(
1517
+ f"Startup loading complete. Loaded and warmed up {total_loaded}/{len(random_models)} "
1518
+ f"models using {total_size:.2f}GB/{self.max_vram - VRAM_SAFETY_MARGIN_GB:.2f}GB available VRAM"
1519
+ )
1520
+ except Exception as e:
1521
+ logger.error(f"Error initializing startup models: {e}")
1522
+ finally:
1523
+ self.is_startup_load = False
1524
+
1525
+ async def warmup_pipeline(self, model_id: str) -> None:
1526
+ """
1527
+ Warm up a pipeline by running a test inference.
1528
+
1529
+ Args:
1530
+ model_id: Model identifier
1531
+ """
1532
+ if model_id not in self.loaded_models:
1533
+ if not self.is_startup_load:
1534
+ logger.info(f"Loading model {model_id} for warm-up")
1535
+ pipeline = await self.load(model_id)
1536
+ if pipeline is None:
1537
+ logger.warning(f"Failed to load model {model_id} for warm-up")
1538
+ return
1539
+ else:
1540
+ logger.warning(f"Model {model_id} is not loaded")
1541
+ return
1542
+ else:
1543
+ pipeline = self.loaded_models[model_id]
1544
+
1545
+ logger.info(f"Warming up pipeline for model {model_id}")
1546
+
1547
+ try:
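+ # A short dummy generation forces weight transfer and kernel warm-up before real requests
+ # (assumes the pipeline is a text-prompted DiffusionPipeline, as checked below).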
1548
+ with torch.no_grad():
1549
+ if isinstance(pipeline, DiffusionPipeline) and callable(pipeline):
1550
+ _ = pipeline(
1551
+ prompt="This is a warm-up prompt",
1552
+ num_inference_steps=4,
1553
+ output_type="pil",
1554
+ )
1555
+ else:
1556
+ logger.warning(
1557
+ f"Unsupported pipeline type for warm-up: {type(pipeline)}"
1558
+ )
1559
+ except Exception as e:
1560
+ logger.error(f"Error during warm-up for model {model_id}: {str(e)}")
1561
+
1562
+ self.flush_memory()
1563
+ logger.info(f"Warm-up completed for model {model_id}")
1564
+
1565
+ def get_all_model_ids(self) -> List[str]:
1566
+ """
1567
+ Get list of all available model IDs from config.
1568
+
1569
+ Returns:
1570
+ List of model identifiers
1571
+ """
1572
+ config = get_config()
1573
+ return list(config.pipeline_defs.keys())
1574
+
1575
+ def get_enabled_models(self) -> List[str]:
1576
+ """
1577
+ Get list of models that should be warmed up.
1578
+
1579
+ Returns:
1580
+ List of model identifiers to be used for generation
1581
+ """
1582
+ config = get_config()
1583
+ return config.enabled_models
1584
+
1585
+ def unload(self, model_id: str) -> None:
1586
+ """
1587
+ Unload a model from memory and clean up associated resources.
1588
+
1589
+ Handles unloading from both GPU and CPU memory, updates memory tracking,
1590
+ and cleans up LRU cache entries.
1591
+
1592
+ Args:
1593
+ model_id: Model identifier to unload
1594
+ """
1595
+ try:
1596
+ # Unload from GPU if present
1597
+ if model_id in self.loaded_models:
1598
+ model_size = self.model_sizes.get(model_id, 0)
1599
+ pipeline = self.loaded_models.pop(model_id)
1600
+
1601
+ del pipeline
1602
+ self.vram_usage -= model_size
1603
+ self.lru_cache.remove(model_id, "gpu")
1604
+ logger.info(f"Model {model_id} unloaded from GPU")
1605
+
1606
+ # Unload from CPU if present
1607
+ if model_id in self.cpu_models:
1608
+ model_size = self.model_sizes.get(model_id, 0)
1609
+ pipeline = self.cpu_models.pop(model_id)
1610
+
1611
+ del pipeline
1612
+ self.ram_usage -= model_size
1613
+ self.lru_cache.remove(model_id, "cpu")
1614
+ logger.info(f"Model {model_id} unloaded from CPU")
1615
+
1616
+ # Clean up current model reference if it matches
1617
+ if model_id == self.current_model:
1618
+ self.loaded_model = None
1619
+ self.current_model = None
1620
+
1621
+ # Remove from model sizes tracking
1622
+ if model_id in self.model_sizes:
1623
+ del self.model_sizes[model_id]
1624
+
1625
+ # Remove from model types tracking
1626
+ if model_id in self.model_types:
1627
+ del self.model_types[model_id]
1628
+
1629
+ # Force memory cleanup
1630
+ self.flush_memory()
1631
+
1632
+ except Exception as e:
1633
+ logger.error(f"Error unloading model {model_id}: {str(e)}")
1634
+
1635
+ def is_loaded(self, model_id: str) -> bool:
1636
+ """
1637
+ Check if a model is currently loaded.
1638
+
1639
+ Args:
1640
+ model_id: Model identifier
1641
+
1642
+ Returns:
1643
+ Boolean indicating if model is loaded
1644
+ """
1645
+ return model_id == self.current_model and self.loaded_model is not None
1646
+
1647
+ def get_model(self, model_id: str) -> Optional[DiffusionPipeline]:
1648
+ """
1649
+ Get a loaded model pipeline.
1650
+
1651
+ Args:
1652
+ model_id: Model identifier
1653
+
1654
+ Returns:
1655
+ Pipeline if model is loaded, None otherwise
1656
+ """
1657
+ if not self.is_loaded(model_id):
1658
+ return None
1659
+ return self.loaded_model
1660
+
1661
+ def get_model_device(self, model_id: str) -> Optional[torch.device]:
1662
+ """
1663
+ Get the device where a model is loaded.
1664
+
1665
+ Args:
1666
+ model_id: Model identifier
1667
+
1668
+ Returns:
1669
+ Device if model is loaded, None otherwise
1670
+ """
1671
+ model = self.get_model(model_id)
1672
+ return model.device if model else None
1673
+
1674
+
1675
+ async def _get_model_size(
1676
+ self, model_config: PipelineConfig, model_id: str
1677
+ ) -> float:
1678
+ """
1679
+ Calculate the total size of a model including all its components.
1680
+
1681
+ Args:
1682
+ model_config: Model configuration dictionary from pipeline_defs
1683
+ model_id: The model identifier (key from pipeline_defs)
1684
+
1685
+ Returns:
1686
+ Size in GB
1687
+ """
1688
+
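+ # `source` takes one of the three forms handled below, e.g. "hf:org/repo", "file:/path/to/model.safetensors",
+ # or an "http(s)://..." download URL (example values are illustrative).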
1689
+ if isinstance(model_config, dict):
1690
+ source = model_config.get("source")
1691
+ else:
1692
+ source = model_config.source
1693
+
1694
+ if source.startswith(("http://", "https://")):
1695
+ # For downloaded HTTP(S) models, get size from cache
1696
+ try:
1697
+ source_obj = ModelSource(source) # Create ModelSource object
1698
+ cache_path = await self.model_downloader._get_cache_path(
1699
+ model_id, source_obj
1700
+ )
1701
+ print(f"Cache Path: {cache_path}")
1702
+ if os.path.exists(cache_path):
1703
+ return os.path.getsize(cache_path) / (1024**3)
1704
+ else:
1705
+ logger.warning(
1706
+ f"Cache file not found for {model_id}, assuming default size"
1707
+ )
1708
+ return 7.0 # Fallback size estimate when the cache file is missing
1709
+ except Exception as e:
1710
+ logger.error(f"Error getting cached model size: {e}")
1711
+ return (
1712
+ 7.0 # Fallback size estimate when the cache lookup fails
1713
+ )
1714
+ elif source.startswith("file:"):
1715
+ path = source.replace("file:", "")
1716
+ return os.path.getsize(path) / (1024**3) if os.path.exists(path) else 0
1717
+ elif source.startswith("hf:"):
1718
+ # Handle HuggingFace models as before
1719
+ repo_id = source.replace("hf:", "")
1720
+ return self._calculate_repo_size(repo_id, model_config)
1721
+ else:
1722
+ logger.error(f"Unsupported source type for size calculation: {source}")
1723
+ return 0
1724
+
1725
+ def _calculate_repo_size(self, repo_id: str, model_config: PipelineConfig) -> float:
1726
+ """
1727
+ Calculate the total size of a HuggingFace repository.
1728
+
1729
+ Args:
1730
+ repo_id: Repository identifier
1731
+ model_config: Model configuration dictionary
1732
+
1733
+ Returns:
1734
+ Total size in GB
1735
+ """
1736
+ total_size = self._get_size_for_repo(repo_id)
1737
+
1738
+ if isinstance(model_config, dict):
1739
+ components = model_config.get("components")
1740
+ else:
1741
+ components = model_config.components
1742
+
1743
+ if components:
1744
+ for key, component in components.items():
1745
+ if isinstance(component, dict) and "source" in component:
1746
+ component_size = self._calculate_component_size(
1747
+ component, repo_id, key
1748
+ )
1749
+ total_size += component_size
1750
+
1751
+ total_size_gb = total_size / (1024**3)
1752
+ logger.debug(f"Total size: {total_size_gb:.2f} GB")
1753
+ print(f"Total size: {total_size_gb:.2f} GB")
1754
+ return total_size_gb
1755
+
1756
+ def _calculate_component_size(
1757
+ self, component: Dict[str, Any], repo_id: str, key: str
1758
+ ) -> float:
1759
+ """
1760
+ Calculate the size of a model component.
1761
+
1762
+ Args:
1763
+ component: Component configuration
1764
+ repo_id: Repository identifier
1765
+ key: Component key
1766
+
1767
+ Returns:
1768
+ Component size in bytes
1769
+ """
1770
+ component_source = component["source"]
1771
+ if len(component_source.split("/")) > 2:
1772
+ component_repo = "/".join(component_source.split("/")[0:2]).replace(
1773
+ "hf:", ""
1774
+ )
1775
+ else:
1776
+ component_repo = component_source.replace("hf:", "")
1777
+
1778
+ component_name = (
1779
+ key
1780
+ if not component_source.endswith((".safetensors", ".bin", ".ckpt", ".pt"))
1781
+ else component_source.split("/")[-1]
1782
+ )
1783
+
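+ # The override replaces the component bundled with the main repo, so count the override's size and
+ # subtract the bundled component's size, which is already included in the main repo total.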
1784
+ total_size = self._get_size_for_repo(
1785
+ component_repo, component_name
1786
+ ) - self._get_size_for_repo(repo_id, key)
1787
+
1788
+ # total_size = self._get_size_for_repo(component_repo, component_name)
1789
+
1790
+ return total_size
1791
+
1792
+ def _get_size_for_repo(
1793
+ self, repo_id: str, component_name: Optional[str] = None
1794
+ ) -> int:
1795
+ """
1796
+ Get the size of a specific repository or component.
1797
+
1798
+ Args:
1799
+ repo_id: Repository identifier
1800
+ component_name: Optional component name
1801
+
1802
+ Returns:
1803
+ Size in bytes
1804
+ """
1805
+ if component_name == "scheduler":
1806
+ return 0
1807
+
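+ # Sizes are read from the local Hugging Face cache layout: <cache_dir>/models--<org>--<name>/snapshots/<commit>/[component]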
1808
+ print(f"Getting size for {repo_id} {component_name}")
1809
+ storage_folder = os.path.join(
1810
+ self.cache_dir, repo_folder_name(repo_id=repo_id, repo_type="model")
1811
+ )
1812
+
1813
+ if not os.path.exists(storage_folder):
1814
+ logger.warning(f"Storage folder for {repo_id} not found")
1815
+ return 0
1816
+
1817
+ commit_hash = self._get_commit_hash(storage_folder)
1818
+ if not commit_hash:
1819
+ return 0
1820
+
1821
+ snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
1822
+ if component_name:
1823
+ snapshot_folder = os.path.join(snapshot_folder, component_name)
1824
+
1825
+ return self._calculate_folder_size(snapshot_folder)
1826
+
1827
+ def _get_commit_hash(self, storage_folder: str) -> Optional[str]:
1828
+ """
1829
+ Get the commit hash for a repository.
1830
+
1831
+ Args:
1832
+ storage_folder: Path to the repository storage folder
1833
+
1834
+ Returns:
1835
+ Commit hash string or None if not found
1836
+ """
1837
+ refs_path = os.path.join(storage_folder, "refs", "main")
1838
+ if not os.path.exists(refs_path):
1839
+ logger.warning("No commit hash found for {}".format(storage_folder))
1840
+ return None
1841
+ try:
1842
+ with open(refs_path, "r") as f:
1843
+ return f.read().strip()
1844
+ except Exception as e:
1845
+ logger.error("Error reading commit hash: {}".format(str(e)))
1846
+ return None
1847
+
1848
+ def _calculate_folder_size(self, folder: str) -> int:
1849
+ """
1850
+ Calculate the total size of model files in a folder.
1851
+
1852
+ Args:
1853
+ folder: Path to the folder
1854
+
1855
+ Returns:
1856
+ Total size in bytes
1857
+ """
1858
+ if not os.path.isdir(folder):
1859
+ if os.path.exists(folder):
1860
+ return os.path.getsize(folder)
1861
+ else:
1862
+ return 0
1863
+
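+ # Prefer a single precision variant when summing file sizes: bf16, then fp8, then fp16, then the unsuffixed default.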
1864
+ variants = ["bf16", "fp8", "fp16", ""]
1865
+ selected_variant = next(
1866
+ (v for v in variants if self._check_variant_files(folder, v)), None
1867
+ )
1868
+
1869
+ if selected_variant is None:
1870
+ return 0
1871
+
1872
+ total_size = 0
1873
+ for root, _, files in os.walk(folder):
1874
+ # skip if in root
1875
+ if root == folder:
1876
+ continue
1877
+ for file in files:
1878
+ if self._is_valid_file(file, selected_variant):
1879
+ total_size += os.path.getsize(os.path.join(root, file))
1880
+ # break
1881
+
1882
+ return total_size
1883
+
1884
+ def _check_variant_files(self, folder: str, variant: str) -> bool:
1885
+ """
1886
+ Check if a folder contains files of a specific variant.
1887
+
1888
+ Args:
1889
+ folder: Path to the folder
1890
+ variant: Variant to check for
1891
+
1892
+ Returns:
1893
+ Boolean indicating if variant files exist
1894
+ """
1895
+ for root, _, files in os.walk(folder):
1896
+ if any(self._is_valid_file(f, variant) for f in files):
1897
+ return True
1898
+ return False
1899
+
1900
+ def _is_valid_file(self, file: str, variant: str) -> bool:
1901
+ """
1902
+ Check if a file is a valid model file of a specific variant.
1903
+
1904
+ Args:
1905
+ file: Filename to check
1906
+ variant: Variant to check for
1907
+
1908
+ Returns:
1909
+ Boolean indicating if file is valid
1910
+ """
1911
+ if variant:
1912
+ # check if the variant is in the file name
+ return variant in file
1917
+ return file.endswith((".safetensors", ".bin", ".ckpt"))
1918
+
1919
+
1920
+ async def _load_diffusers_component(
1921
+ self,
1922
+ main_model_repo: str,
1923
+ component_repo: str,
1924
+ component_name: Optional[str] = None,
1925
+ variant: Optional[str] = None,
1926
+ ) -> Any:
1927
+ """
1928
+ Load a diffusers component.
1929
+
1930
+ Args:
1931
+ main_model_repo: Repository whose model_index is used to resolve the component's class
+ component_repo: Repository to load the component weights from
+ component_name: Name of the component (subfolder), if any
+ variant: Optional weights variant (e.g. "fp16")
1933
+
1934
+ Returns:
1935
+ Loaded component
1936
+ """
1937
+ try:
1938
+ model_index = (
1939
+ await self.model_downloader.get_diffusers_multifolder_components(
1940
+ main_model_repo
1941
+ )
1942
+ )
1943
+ if model_index is None:
1944
+ raise ValueError(f"model_index does not exist for {main_model_repo}")
1945
+
1946
+ component_info = model_index.get(component_name)
1947
+ if not component_info:
1948
+ raise ValueError(f"Invalid component info for {component_name}")
1949
+
1950
+ module_path, class_name = component_info
1951
+ module = importlib.import_module(module_path)
1952
+ model_class = getattr(module, class_name)
1953
+
1954
+ dtype = torch.bfloat16 if "flux" in component_repo.lower() else torch.float16
+ load_kwargs: Dict[str, Any] = {"torch_dtype": dtype}
+ if component_name:
+ load_kwargs["subfolder"] = component_name
+ if variant:
+ load_kwargs["variant"] = variant
+ component = await asyncio.to_thread(
+ model_class.from_pretrained, component_repo, **load_kwargs
+ )
1988
+
1989
+
1990
+ return component
1991
+
1992
+ except Exception as e:
1993
+ logger.error(
1994
+ "Error loading component {} from {}: {}".format(
1995
+ component_name, component_repo, str(e)
1996
+ )
1997
+ )
1998
+ raise
1999
+
2000
+ def _load_custom_component(
2001
+ self, repo_id: str, category: str, component_name: str
2002
+ ) -> Any:
2003
+ """
2004
+ Load a custom component.
2005
+
2006
+ Args:
2007
+ repo_id: Repository identifier
2008
+ category: Component category
2009
+ component_name: Name of the component
2010
+
2011
+ Returns:
2012
+ Loaded component
2013
+ """
2014
+ try:
2015
+ file_path = self._get_component_file_path(repo_id)
2016
+ state_dict = load_state_dict_from_file(file_path)
2017
+
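+ # Architecture registry keys follow "core_extension_1.<category>_<component_name>" (lower-cased),
+ # e.g. a hypothetical "core_extension_1.vae_decoder".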
2018
+ architectures = get_architectures()
2019
+ arch_key = f"core_extension_1.{category.lower()}_{component_name.lower()}"
2020
+ architecture_class = architectures.get(arch_key)
2021
+
2022
+ if not architecture_class:
2023
+ raise ValueError(f"Architecture not found for {arch_key}")
2024
+
2025
+ architecture = architecture_class()
2026
+ architecture.load(state_dict)
2027
+ model = architecture.model
2028
+
2029
+ return model
2030
+
2031
+ except Exception as e:
2032
+ logger.error("Error loading custom component: {}".format(str(e)))
2033
+ raise
2034
+
2035
+ def _get_component_file_path(self, repo_id: str) -> str:
2036
+ """
2037
+ Get the file path for a component.
2038
+
2039
+ Args:
2040
+ repo_id: Repository identifier
2041
+
2042
+ Returns:
2043
+ Path to component file
2044
+ """
2045
+ repo_folder = os.path.dirname(repo_id.replace("hf:", ""))
2046
+ weights_name = repo_id.split("/")[-1]
2047
+
2048
+ model_folder = os.path.join(
2049
+ self.cache_dir, repo_folder_name(repo_id=repo_folder, repo_type="model")
2050
+ )
2051
+
2052
+ if not os.path.exists(model_folder):
2053
+ model_folder = os.path.join(self.cache_dir, repo_folder)
2054
+ if not os.path.exists(model_folder):
2055
+ raise FileNotFoundError(f"Cache folder for {repo_id} not found")
2056
+ return os.path.join(model_folder, weights_name)
2057
+
2058
+ commit_hash = self._get_commit_hash(model_folder) or ""
2059
+ return os.path.join(model_folder, "snapshots", commit_hash, weights_name)