diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff compares the contents of two publicly available versions of this package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -44,21 +44,22 @@ from ..configuration_utils import ConfigMixin
44
44
  from ..models import AutoencoderKL
45
45
  from ..models.attention_processor import FusedAttnProcessor2_0
46
46
  from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
47
+ from ..quantizers.bitsandbytes.utils import _check_bnb_status
47
48
  from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
48
49
  from ..utils import (
49
50
  CONFIG_NAME,
50
51
  DEPRECATED_REVISION_ARGS,
51
52
  BaseOutput,
52
53
  PushToHubMixin,
53
- deprecate,
54
54
  is_accelerate_available,
55
55
  is_accelerate_version,
56
56
  is_torch_npu_available,
57
57
  is_torch_version,
58
+ is_transformers_version,
58
59
  logging,
59
60
  numpy_to_pil,
60
61
  )
61
- from ..utils.hub_utils import load_or_create_model_card, populate_model_card
62
+ from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
62
63
  from ..utils.torch_utils import is_compiled_module
63
64
 
64
65
 
@@ -72,11 +73,16 @@ from .pipeline_loading_utils import (
72
73
  CUSTOM_PIPELINE_FILE_NAME,
73
74
  LOADABLE_CLASSES,
74
75
  _fetch_class_library_tuple,
76
+ _get_custom_components_and_folders,
75
77
  _get_custom_pipeline_class,
76
78
  _get_final_device_map,
79
+ _get_ignore_patterns,
77
80
  _get_pipeline_class,
81
+ _identify_model_variants,
82
+ _maybe_raise_warning_for_inpainting,
83
+ _resolve_custom_pipeline_and_cls,
78
84
  _unwrap_model,
79
- is_safetensors_compatible,
85
+ _update_init_kwargs_with_connected_pipeline,
80
86
  load_sub_model,
81
87
  maybe_raise_or_warn,
82
88
  variant_compatible_siblings,
@@ -185,6 +191,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
185
191
  save_directory: Union[str, os.PathLike],
186
192
  safe_serialization: bool = True,
187
193
  variant: Optional[str] = None,
194
+ max_shard_size: Optional[Union[int, str]] = None,
188
195
  push_to_hub: bool = False,
189
196
  **kwargs,
190
197
  ):
@@ -200,6 +207,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
200
207
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
201
208
  variant (`str`, *optional*):
202
209
  If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
210
+ max_shard_size (`int` or `str`, defaults to `None`):
211
+ The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
212
+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
213
+ If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
214
+ period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
215
+ This is to establish a common default size for this argument across different libraries in the Hugging
216
+ Face ecosystem (`transformers`, and `accelerate`, for example).
203
217
  push_to_hub (`bool`, *optional*, defaults to `False`):
204
218
  Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
205
219
  repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -274,12 +288,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
274
288
  save_method_signature = inspect.signature(save_method)
275
289
  save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
276
290
  save_method_accept_variant = "variant" in save_method_signature.parameters
291
+ save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
277
292
 
278
293
  save_kwargs = {}
279
294
  if save_method_accept_safe:
280
295
  save_kwargs["safe_serialization"] = safe_serialization
281
296
  if save_method_accept_variant:
282
297
  save_kwargs["variant"] = variant
298
+ if save_method_accept_max_shard_size and max_shard_size is not None:
299
+ # max_shard_size is expected to not be None in ModelMixin
300
+ save_kwargs["max_shard_size"] = max_shard_size
283
301
 
284
302
  save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
285
303
 
@@ -416,18 +434,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
416
434
 
417
435
  is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
418
436
  for module in modules:
419
- is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
437
+ _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
420
438
 
421
- if is_loaded_in_8bit and dtype is not None:
439
+ if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
422
440
  logger.warning(
423
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
441
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
424
442
  )
425
443
 
426
- if is_loaded_in_8bit and device is not None:
444
+ if is_loaded_in_8bit_bnb and device is not None:
427
445
  logger.warning(
428
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
446
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
429
447
  )
430
- else:
448
+
449
+ # This can happen for `transformer` models. CPU placement was added in
450
+ # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
451
+ if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
452
+ module.to(device=device)
453
+ elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
431
454
  module.to(device, dtype)
432
455
 
433
456
  if (
@@ -622,6 +645,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
622
645
  >>> pipeline.scheduler = scheduler
623
646
  ```
624
647
  """
648
+ # Copy the kwargs to re-use during loading connected pipeline.
649
+ kwargs_copied = kwargs.copy()
650
+
625
651
  cache_dir = kwargs.pop("cache_dir", None)
626
652
  force_download = kwargs.pop("force_download", False)
627
653
  proxies = kwargs.pop("proxies", None)
@@ -716,39 +742,43 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
716
742
  else:
717
743
  cached_folder = pretrained_model_name_or_path
718
744
 
745
+ # The variant filenames can have the legacy sharding checkpoint format that we check and throw
746
+ # a warning if detected.
747
+ if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
748
+ warn_msg = (
749
+ f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
750
+ "Please check your files carefully:\n\n"
751
+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
752
+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
753
+ "If you find any files in the deprecated format:\n"
754
+ "1. Remove all existing checkpoint files for this variant.\n"
755
+ "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
756
+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
757
+ )
758
+ logger.warning(warn_msg)
759
+
719
760
  config_dict = cls.load_config(cached_folder)
720
761
 
721
762
  # pop out "_ignore_files" as it is only needed for download
722
763
  config_dict.pop("_ignore_files", None)
723
764
 
724
765
  # 2. Define which model components should load variants
725
- # We retrieve the information by matching whether variant
726
- # model checkpoints exist in the subfolders
727
- model_variants = {}
728
- if variant is not None:
729
- for folder in os.listdir(cached_folder):
730
- folder_path = os.path.join(cached_folder, folder)
731
- is_folder = os.path.isdir(folder_path) and folder in config_dict
732
- variant_exists = is_folder and any(
733
- p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
734
- )
735
- if variant_exists:
736
- model_variants[folder] = variant
766
+ # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
767
+ # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
768
+ # with variant being `"fp16"`.
769
+ model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
770
+ if len(model_variants) == 0 and variant is not None:
771
+ error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
772
+ raise ValueError(error_message)
737
773
 
738
774
  # 3. Load the pipeline class, if using custom module then load it from the hub
739
775
  # if we load from explicit class, let's use it
740
- custom_class_name = None
741
- if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
742
- custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
743
- elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
744
- os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
745
- ):
746
- custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
747
- custom_class_name = config_dict["_class_name"][1]
748
-
776
+ custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
777
+ folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
778
+ )
749
779
  pipeline_class = _get_pipeline_class(
750
780
  cls,
751
- config_dict,
781
+ config=config_dict,
752
782
  load_connected_pipeline=load_connected_pipeline,
753
783
  custom_pipeline=custom_pipeline,
754
784
  class_name=custom_class_name,
@@ -760,23 +790,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
760
790
  raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
761
791
 
762
792
  # DEPRECATED: To be removed in 1.0.0
763
- if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
764
- version.parse(config_dict["_diffusers_version"]).base_version
765
- ) <= version.parse("0.5.1"):
766
- from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
767
-
768
- pipeline_class = StableDiffusionInpaintPipelineLegacy
769
-
770
- deprecation_message = (
771
- "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
772
- f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
773
- " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
774
- " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
775
- f" checkpoint {pretrained_model_name_or_path} to the format of"
776
- " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
777
- " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
778
- )
779
- deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
793
+ # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
794
+ # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
795
+ _maybe_raise_warning_for_inpainting(
796
+ pipeline_class=pipeline_class,
797
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
798
+ config=config_dict,
799
+ )
780
800
 
781
801
  # 4. Define expected modules given pipeline signature
782
802
  # and define non-None initialized modules (=`init_kwargs`)
@@ -787,7 +807,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
787
807
  expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
788
808
  passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
789
809
  passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
790
-
791
810
  init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
792
811
 
793
812
  # define init kwargs and make sure that optional component modules are filtered out
@@ -847,6 +866,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
847
866
  # 7. Load each module in the pipeline
848
867
  current_device_map = None
849
868
  for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
869
+ # 7.1 device_map shenanigans
850
870
  if final_device_map is not None and len(final_device_map) > 0:
851
871
  component_device = final_device_map.get(name, None)
852
872
  if component_device is not None:
@@ -854,15 +874,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
854
874
  else:
855
875
  current_device_map = None
856
876
 
857
- # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
877
+ # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
858
878
  class_name = class_name[4:] if class_name.startswith("Flax") else class_name
859
879
 
860
- # 7.2 Define all importable classes
880
+ # 7.3 Define all importable classes
861
881
  is_pipeline_module = hasattr(pipelines, library_name)
862
882
  importable_classes = ALL_IMPORTABLE_CLASSES
863
883
  loaded_sub_model = None
864
884
 
865
- # 7.3 Use passed sub model or load class_name from library_name
885
+ # 7.4 Use passed sub model or load class_name from library_name
866
886
  if name in passed_class_obj:
867
887
  # if the model is in a pipeline module, then we load it from the pipeline
868
888
  # check that passed_class_obj has correct parent class
@@ -893,6 +913,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
893
913
  variant=variant,
894
914
  low_cpu_mem_usage=low_cpu_mem_usage,
895
915
  cached_folder=cached_folder,
916
+ use_safetensors=use_safetensors,
896
917
  )
897
918
  logger.info(
898
919
  f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -900,56 +921,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
900
921
 
901
922
  init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
902
923
 
924
+ # 8. Handle connected pipelines.
903
925
  if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
904
- modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
905
- connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
906
- load_kwargs = {
907
- "cache_dir": cache_dir,
908
- "force_download": force_download,
909
- "proxies": proxies,
910
- "local_files_only": local_files_only,
911
- "token": token,
912
- "revision": revision,
913
- "torch_dtype": torch_dtype,
914
- "custom_pipeline": custom_pipeline,
915
- "custom_revision": custom_revision,
916
- "provider": provider,
917
- "sess_options": sess_options,
918
- "device_map": device_map,
919
- "max_memory": max_memory,
920
- "offload_folder": offload_folder,
921
- "offload_state_dict": offload_state_dict,
922
- "low_cpu_mem_usage": low_cpu_mem_usage,
923
- "variant": variant,
924
- "use_safetensors": use_safetensors,
925
- }
926
-
927
- def get_connected_passed_kwargs(prefix):
928
- connected_passed_class_obj = {
929
- k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
930
- }
931
- connected_passed_pipe_kwargs = {
932
- k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
933
- }
934
-
935
- connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
936
- return connected_passed_kwargs
937
-
938
- connected_pipes = {
939
- prefix: DiffusionPipeline.from_pretrained(
940
- repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
941
- )
942
- for prefix, repo_id in connected_pipes.items()
943
- if repo_id is not None
944
- }
945
-
946
- for prefix, connected_pipe in connected_pipes.items():
947
- # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
948
- init_kwargs.update(
949
- {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
950
- )
926
+ init_kwargs = _update_init_kwargs_with_connected_pipeline(
927
+ init_kwargs=init_kwargs,
928
+ passed_pipe_kwargs=passed_pipe_kwargs,
929
+ passed_class_objs=passed_class_obj,
930
+ folder=cached_folder,
931
+ **kwargs_copied,
932
+ )
951
933
 
952
- # 8. Potentially add passed objects if expected
934
+ # 9. Potentially add passed objects if expected
953
935
  missing_modules = set(expected_modules) - set(init_kwargs.keys())
954
936
  passed_modules = list(passed_class_obj.keys())
955
937
  optional_modules = pipeline_class._optional_components
@@ -1065,9 +1047,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1065
1047
  hook = None
1066
1048
  for model_str in self.model_cpu_offload_seq.split("->"):
1067
1049
  model = all_model_components.pop(model_str, None)
1050
+
1068
1051
  if not isinstance(model, torch.nn.Module):
1069
1052
  continue
1070
1053
 
1054
+ # This is because the model would already be placed on a CUDA device.
1055
+ _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
1056
+ if is_loaded_in_8bit_bnb:
1057
+ logger.info(
1058
+ f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
1059
+ )
1060
+ continue
1061
+
1071
1062
  _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
1072
1063
  self._all_hooks.append(hook)
1073
1064
 
@@ -1295,6 +1286,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1295
1286
  model_info_call_error = e # save error to reraise it if model is not cached locally
1296
1287
 
1297
1288
  if not local_files_only:
1289
+ filenames = {sibling.rfilename for sibling in info.siblings}
1290
+ if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
1291
+ warn_msg = (
1292
+ f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
1293
+ "Please check your files carefully:\n\n"
1294
+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
1295
+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
1296
+ "If you find any files in the deprecated format:\n"
1297
+ "1. Remove all existing checkpoint files for this variant.\n"
1298
+ "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
1299
+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
1300
+ )
1301
+ logger.warning(warn_msg)
1302
+
1303
+ model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
1304
+
1298
1305
  config_file = hf_hub_download(
1299
1306
  pretrained_model_name,
1300
1307
  cls.config_name,
@@ -1308,52 +1315,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1308
1315
  config_dict = cls._dict_from_json_file(config_file)
1309
1316
  ignore_filenames = config_dict.pop("_ignore_files", [])
1310
1317
 
1311
- # retrieve all folder_names that contain relevant files
1312
- folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
1313
-
1314
- filenames = {sibling.rfilename for sibling in info.siblings}
1315
- model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
1316
-
1317
- diffusers_module = importlib.import_module(__name__.split(".")[0])
1318
- pipelines = getattr(diffusers_module, "pipelines")
1319
-
1320
- # optionally create a custom component <> custom file mapping
1321
- custom_components = {}
1322
- for component in folder_names:
1323
- module_candidate = config_dict[component][0]
1324
-
1325
- if module_candidate is None or not isinstance(module_candidate, str):
1326
- continue
1327
-
1328
- # We compute candidate file path on the Hub. Do not use `os.path.join`.
1329
- candidate_file = f"{component}/{module_candidate}.py"
1330
-
1331
- if candidate_file in filenames:
1332
- custom_components[component] = module_candidate
1333
- elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
1334
- raise ValueError(
1335
- f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
1336
- )
1337
-
1338
- if len(variant_filenames) == 0 and variant is not None:
1339
- deprecation_message = (
1340
- f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
1341
- f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
1342
- "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
1343
- "modeling files is deprecated."
1344
- )
1345
- deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
1346
-
1347
1318
  # remove ignored filenames
1348
1319
  model_filenames = set(model_filenames) - set(ignore_filenames)
1349
1320
  variant_filenames = set(variant_filenames) - set(ignore_filenames)
1350
1321
 
1351
- # if the whole pipeline is cached we don't have to ping the Hub
1352
1322
  if revision in DEPRECATED_REVISION_ARGS and version.parse(
1353
1323
  version.parse(__version__).base_version
1354
1324
  ) >= version.parse("0.22.0"):
1355
1325
  warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
1356
1326
 
1327
+ custom_components, folder_names = _get_custom_components_and_folders(
1328
+ pretrained_model_name, config_dict, filenames, variant_filenames, variant
1329
+ )
1357
1330
  model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
1358
1331
 
1359
1332
  custom_class_name = None
@@ -1413,49 +1386,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1413
1386
  expected_components, _ = cls._get_signature_keys(pipeline_class)
1414
1387
  passed_components = [k for k in expected_components if k in kwargs]
1415
1388
 
1416
- if (
1417
- use_safetensors
1418
- and not allow_pickle
1419
- and not is_safetensors_compatible(
1420
- model_filenames, variant=variant, passed_components=passed_components
1421
- )
1422
- ):
1423
- raise EnvironmentError(
1424
- f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
1425
- )
1426
- if from_flax:
1427
- ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
1428
- elif use_safetensors and is_safetensors_compatible(
1429
- model_filenames, variant=variant, passed_components=passed_components
1430
- ):
1431
- ignore_patterns = ["*.bin", "*.msgpack"]
1432
-
1433
- use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1434
- if not use_onnx:
1435
- ignore_patterns += ["*.onnx", "*.pb"]
1436
-
1437
- safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
1438
- safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
1439
- if (
1440
- len(safetensors_variant_filenames) > 0
1441
- and safetensors_model_filenames != safetensors_variant_filenames
1442
- ):
1443
- logger.warning(
1444
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1445
- )
1446
- else:
1447
- ignore_patterns = ["*.safetensors", "*.msgpack"]
1448
-
1449
- use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1450
- if not use_onnx:
1451
- ignore_patterns += ["*.onnx", "*.pb"]
1452
-
1453
- bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
1454
- bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
1455
- if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
1456
- logger.warning(
1457
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1458
- )
1389
+ # retrieve all patterns that should not be downloaded and error out when needed
1390
+ ignore_patterns = _get_ignore_patterns(
1391
+ passed_components,
1392
+ model_folder_names,
1393
+ model_filenames,
1394
+ variant_filenames,
1395
+ use_safetensors,
1396
+ from_flax,
1397
+ allow_pickle,
1398
+ use_onnx,
1399
+ pipeline_class._is_onnx,
1400
+ variant,
1401
+ )
1459
1402
 
1460
1403
  # Don't download any objects that are passed
1461
1404
  allow_patterns = [
@@ -178,7 +178,7 @@ def retrieve_timesteps(
178
178
  sigmas: Optional[List[float]] = None,
179
179
  **kwargs,
180
180
  ):
181
- """
181
+ r"""
182
182
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
183
183
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
184
184
 
@@ -122,7 +122,7 @@ def retrieve_timesteps(
122
122
  sigmas: Optional[List[float]] = None,
123
123
  **kwargs,
124
124
  ):
125
- """
125
+ r"""
126
126
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
127
127
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
128
128
 
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
281
281
  def num_timesteps(self):
282
282
  return self._num_timesteps
283
283
 
284
+ def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
285
+ s = torch.tensor([0.008])
286
+ clamp_range = [0, 1]
287
+ min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
288
+ var = alphas_cumprod[t]
289
+ var = var.clamp(*clamp_range)
290
+ s, min_var = s.to(var.device), min_var.to(var.device)
291
+ ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
292
+ return ratio
293
+
284
294
  @torch.no_grad()
285
295
  @replace_example_docstring(EXAMPLE_DOC_STRING)
286
296
  def __call__(
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
434
444
  batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
435
445
  )
436
446
 
447
+ if isinstance(self.scheduler, DDPMWuerstchenScheduler):
448
+ timesteps = timesteps[:-1]
449
+ else:
450
+ if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
451
+ self.scheduler.config.clip_sample = False # disable sample clipping
452
+ logger.warning(" set `clip_sample` to be False")
453
+
437
454
  # 6. Run denoising loop
438
- self._num_timesteps = len(timesteps[:-1])
439
- for i, t in enumerate(self.progress_bar(timesteps[:-1])):
440
- timestep_ratio = t.expand(latents.size(0)).to(dtype)
455
+ if hasattr(self.scheduler, "betas"):
456
+ alphas = 1.0 - self.scheduler.betas
457
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
458
+ else:
459
+ alphas_cumprod = []
460
+
461
+ self._num_timesteps = len(timesteps)
462
+ for i, t in enumerate(self.progress_bar(timesteps)):
463
+ if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
464
+ if len(alphas_cumprod) > 0:
465
+ timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
466
+ timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
467
+ else:
468
+ timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
469
+ else:
470
+ timestep_ratio = t.expand(latents.size(0)).to(dtype)
441
471
 
442
472
  # 7. Denoise latents
443
473
  predicted_latents = self.decoder(
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
454
484
  predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
455
485
 
456
486
  # 9. Renoise latents to next timestep
487
+ if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
488
+ timestep_ratio = t
457
489
  latents = self.scheduler.step(
458
490
  model_output=predicted_latents,
459
491
  timestep=timestep_ratio,
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
353
353
  return self._num_timesteps
354
354
 
355
355
  def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
356
- s = torch.tensor([0.003])
356
+ s = torch.tensor([0.008])
357
357
  clamp_range = [0, 1]
358
358
  min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
359
359
  var = alphas_cumprod[t]
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
557
557
  if isinstance(self.scheduler, DDPMWuerstchenScheduler):
558
558
  timesteps = timesteps[:-1]
559
559
  else:
560
- if self.scheduler.config.clip_sample:
560
+ if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
561
561
  self.scheduler.config.clip_sample = False # disable sample clipping
562
562
  logger.warning(" set `clip_sample` to be False")
563
563
  # 6. Run denoising loop