diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py
@@ -14,13 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import itertools
 import json
 import os
 import re
 from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -31,6 +32,8 @@ from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
@@ -43,6 +46,8 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_version,
     logging,
 )
@@ -54,7 +59,9 @@ from ..utils.hub_utils import (
 from .model_loading_utils import (
     _determine_device_map,
     _fetch_index_file,
+    _fetch_index_file_legacy,
     _load_state_dict_into_model,
+    _merge_sharded_checkpoints,
     load_model_dict_into_meta,
     load_state_dict,
 )
@@ -93,24 +100,20 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
 
 def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     try:
-        params = tuple(parameter.parameters())
-        if len(params) > 0:
-            return params[0].dtype
-
-        buffers = tuple(parameter.buffers())
-        if len(buffers) > 0:
-            return buffers[0].dtype
-
+        return next(parameter.parameters()).dtype
     except StopIteration:
-        # For torch.nn.DataParallel compatibility in PyTorch 1.5
+        try:
+            return next(parameter.buffers()).dtype
+        except StopIteration:
+            # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
-        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
-            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-            return tuples
+            def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
 
-        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
-        first_tuple = next(gen)
-        return first_tuple[1].dtype
+            gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
 
 
 class ModelMixin(torch.nn.Module, PushToHubMixin):
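The rewritten helper resolves a model's dtype lazily: the first parameter wins, then the first buffer, then (for `torch.nn.DataParallel` on old PyTorch) raw tensor attributes. A standalone sketch of the first two branches, using only stock PyTorch, shows the buffer fallback in action:

```python
import torch
from torch import nn


def get_parameter_dtype(module: nn.Module) -> torch.dtype:
    # Same cascade as the new implementation: parameters first, then buffers.
    try:
        return next(module.parameters()).dtype
    except StopIteration:
        return next(module.buffers()).dtype


class BufferOnly(nn.Module):
    def __init__(self):
        super().__init__()
        # No parameters at all; only a registered buffer.
        self.register_buffer("running_mean", torch.zeros(4, dtype=torch.float16))


print(get_parameter_dtype(nn.Linear(2, 2)))  # torch.float32
print(get_parameter_dtype(BufferOnly()))     # torch.float16
```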
@@ -128,6 +131,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
+    _keep_in_fp32_modules = None
 
     def __init__(self):
         super().__init__()
@@ -311,13 +315,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+        if hf_quantizer is not None:
+            quantization_serializable = (
+                hf_quantizer is not None
+                and isinstance(hf_quantizer, DiffusersQuantizer)
+                and hf_quantizer.is_serializable
+            )
+            if not quantization_serializable:
+                raise ValueError(
+                    f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                    " the logger on the traceback to understand the reason why the quantized model is not serializable."
+                )
+
         weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
         weights_name = _add_variant(weights_name, variant)
-        weight_name_split = weights_name.split(".")
-        if len(weight_name_split) in [2, 3]:
-            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
-        else:
-            raise ValueError(f"Invalid {weights_name} provided.")
+        weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
 
         os.makedirs(save_directory, exist_ok=True)
 
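The shard-name pattern is now built with two plain `str.replace` calls instead of dot-splitting, which also handles variant-suffixed names like `*.fp16.bin` cleanly. A quick illustration of the resulting patterns (a standalone re-implementation, not an import from diffusers):

```python
def shard_pattern(weights_name: str) -> str:
    # Mirrors the new save_pretrained logic: inject a "{suffix}" placeholder
    # before the extension for both supported serialization formats.
    return weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")


print(shard_pattern("diffusion_pytorch_model.safetensors"))
# diffusion_pytorch_model{suffix}.safetensors
print(shard_pattern("diffusion_pytorch_model.fp16.bin"))
# diffusion_pytorch_model.fp16{suffix}.bin
```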
@@ -407,6 +422,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 create_pr=create_pr,
             )
 
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that support
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
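The new `dequantize()` is a thin delegate to the attached quantizer, and it refuses to run on a model that was never quantized. A minimal check of that guard on a freshly initialized model:

```python
from diffusers import UNet2DModel

# A freshly initialized model carries no `hf_quantizer`, so the guard trips.
model = UNet2DModel()
try:
    model.dequantize()
except ValueError as err:
    print(err)  # "You need to first quantize your model in order to dequantize it"
```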
@@ -529,6 +556,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         allow_pickle = False
         if use_safetensors is None:
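This is the entry point for the release's headline quantization support: `from_pretrained` now accepts a `quantization_config` (backed by the new `diffusers/quantizers/` package in the file list above). A minimal 4-bit loading sketch, assuming a CUDA machine with `bitsandbytes` installed; the checkpoint ID is illustrative:

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```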
@@ -623,26 +651,85 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             user_agent=user_agent,
             **kwargs,
         )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)
+
+        # determine initial quantization config.
+        #######################################
+        pre_quantized = "quantization_config" in config and config["quantization_config"] is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config["quantization_config"], quantization_config
+                )
+            else:
+                config["quantization_config"] = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(
+                config["quantization_config"], pre_quantized=pre_quantized
+            )
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            if device_map is not None:
+                raise NotImplementedError(
+                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                )
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more mem efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
+        else:
+            keep_in_fp32_modules = []
+        #######################################
 
         # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
         index_file = None
         is_local = os.path.isdir(pretrained_model_name_or_path)
-        index_file = _fetch_index_file(
-            is_local=is_local,
-            pretrained_model_name_or_path=pretrained_model_name_or_path,
-            subfolder=subfolder or "",
-            use_safetensors=use_safetensors,
-            cache_dir=cache_dir,
-            variant=variant,
-            force_download=force_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            token=token,
-            revision=revision,
-            user_agent=user_agent,
-            commit_hash=commit_hash,
-        )
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
         if index_file is not None and index_file.is_file():
             is_sharded = True
 
@@ -684,6 +771,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 revision=revision,
                 subfolder=subfolder or "",
             )
+            if hf_quantizer is not None:
+                model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+                logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
+                is_sharded = False
 
         elif use_safetensors and not is_sharded:
             try:
@@ -729,13 +820,30 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             with accelerate.init_empty_weights():
                 model = cls.from_config(config, **unused_kwargs)
 
+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
             # if device_map is None, load the state dict and move the params from meta device to the cpu
             if device_map is None and not is_sharded:
-                param_device = "cpu"
+                # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
+                # It would error out during the `validate_environment()` call above in the absence of cuda.
+                is_quant_method_bnb = (
+                    getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+                )
+                if hf_quantizer is None:
+                    param_device = "cpu"
+                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
+                elif is_quant_method_bnb:
+                    param_device = torch.cuda.current_device()
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
+
                 # move the params from meta device to cpu
                 missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if hf_quantizer is not None:
+                    missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
                 if len(missing_keys) > 0:
                     raise ValueError(
                         f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
@@ -750,6 +858,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     device=param_device,
                     dtype=torch_dtype,
                     model_name_or_path=pretrained_model_name_or_path,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
                 )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -765,7 +875,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
             force_hook = True
-            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+            device_map = _determine_device_map(
+                model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+            )
             if device_map is None and is_sharded:
                 # we load the parameters on the cpu
                 device_map = {"": "cpu"}
@@ -843,14 +955,25 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             "error_msgs": error_msgs,
         }
 
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             raise ValueError(
                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
-        elif torch_dtype is not None:
+        # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
+        # completely lose the effectivity of `use_keep_in_fp32_modules`.
+        elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
             model = model.to(torch_dtype)
 
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        if hf_quantizer is not None:
+            # We also make sure to purge `_pre_quantization_dtype` when we serialize
+            # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype)
+        else:
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
@@ -859,6 +982,76 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model
 
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().cuda(*args, **kwargs)
+
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().to(*args, **kwargs)
+
+    # Taken from `transformers`.
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    # Taken from `transformers`.
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
     @classmethod
     def _load_pretrained_model(
         cls,
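The net effect of these overrides: device moves stay legal for 4-bit models on bitsandbytes >= 0.43.2, while any dtype cast is rejected. A hedged sketch, reusing the illustrative SD3 checkpoint from the loading example above:

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

transformer.to("cuda")  # allowed: a pure device move for a 4-bit model (needs bitsandbytes >= 0.43.2)

for bad_call in (lambda: transformer.to(torch.float16), lambda: transformer.half()):
    try:
        bad_call()
    except ValueError as err:
        print(err)  # dtype casts on a quantized model are rejected outright
```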
@@ -1041,19 +1234,63 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         859520964
         ```
         """
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+                )
 
         if exclude_embeddings:
             embedding_param_names = [
-                f"{name}.weight"
-                for name, module_type in self.named_modules()
-                if isinstance(module_type, torch.nn.Embedding)
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
             ]
-            non_embedding_parameters = [
+            total_parameters = [
                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
             ]
-            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
-            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    if hasattr(param, "element_size"):
+                        num_bytes = param.element_size()
+                    elif hasattr(param, "quant_storage"):
+                        num_bytes = param.quant_storage.itemsize
+                    else:
+                        num_bytes = 1
+                    total_numel.append(param.numel() * 2 * num_bytes)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
 
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []
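`get_memory_footprint()` is a plain byte count over parameters (plus buffers by default), so it is easy to sanity-check on a small fp32 model:

```python
from diffusers import UNet2DModel

model = UNet2DModel()  # small default config, fp32
footprint = model.get_memory_footprint()
# fp32 stores 4 bytes per parameter element; buffers only add to that total.
assert footprint >= model.num_parameters() * 4
print(f"{footprint / 2**20:.1f} MiB")
```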
diffusers/models/normalization.py
@@ -97,6 +97,40 @@ class FP32LayerNorm(nn.LayerNorm):
         ).to(origin_dtype)
 
 
+class SD35AdaLayerNormZeroX(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+            9, dim=1
+        )
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
+
 class AdaLayerNormZero(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
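`SD35AdaLayerNormZeroX` projects one conditioning embedding into nine modulation tensors, serving two attention streams plus the MLP. A shape check, assuming the class is importable from `diffusers.models.normalization` as added in this diff:

```python
import torch
from diffusers.models.normalization import SD35AdaLayerNormZeroX

dim = 64
norm = SD35AdaLayerNormZeroX(embedding_dim=dim)
hidden_states = torch.randn(2, 16, dim)  # (batch, tokens, dim)
emb = torch.randn(2, dim)                # pooled conditioning embedding

out = norm(hidden_states, emb)
# Two modulated activation streams plus five gate/shift/scale tensors.
print([tuple(t.shape) for t in out])
# [(2, 16, 64), (2, 64), (2, 64), (2, 64), (2, 64), (2, 16, 64), (2, 64)]
```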
@@ -355,6 +389,51 @@ class LuminaLayerNormContinuous(nn.Module):
         return x
 
 
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, dim: int):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+        self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            c_shift_msa,
+            c_scale_msa,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+        normed_x = self.norm_x(x)
+        normed_context = self.norm_c(context)
+        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
+
+
 class CogVideoXLayerNormZero(nn.Module):
     def __init__(
         self,
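The CogView3Plus variant applies the same scheme to two token streams at once, chunking a single projection into twelve modulation tensors (six per stream). Same import assumption as above:

```python
import torch
from diffusers.models.normalization import CogView3PlusAdaLayerNormZeroTextImage

norm = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=128, dim=64)
x = torch.randn(2, 256, 64)       # image tokens
context = torch.randn(2, 77, 64)  # text tokens
emb = torch.randn(2, 128)         # conditioning embedding

outputs = norm(x, context, emb)
print(len(outputs))  # 10: the modulated x and context plus four gates/shifts/scales per stream
print(outputs[0].shape, outputs[5].shape)  # torch.Size([2, 256, 64]) torch.Size([2, 77, 64])
```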
diffusers/models/transformers/__init__.py
@@ -14,6 +14,7 @@ if is_torch_available():
     from .stable_audio_transformer import StableAudioDiTModel
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
diffusers/models/transformers/auraflow_transformer_2d.py
@@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
     """
 
+    _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
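Declaring `_no_split_modules` is what allows accelerate-backed device maps to place whole transformer blocks per device instead of refusing to shard the model. A hedged sketch, assuming an accelerate install, multiple CUDA devices, and that model-level `device_map="auto"` is supported for this class:

```python
from diffusers import AuraFlowTransformer2DModel

# "auto" lets accelerate distribute whole blocks; the modules named in
# `_no_split_modules` are guaranteed never to be split across devices.
transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow", subfolder="transformer", device_map="auto"
)
```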
diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -19,7 +19,8 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -152,7 +153,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
 
@@ -411,8 +412,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_frames, channels, height, width = hidden_states.shape
 
         # 1. Time embedding
@@ -481,6 +498,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
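With `PeftAdapterMixin` wired in, the CogVideoX transformer can carry LoRA layers and scale them per call via `attention_kwargs`. A hedged sketch; the LoRA repo ID is a placeholder, and this assumes the 0.31.0 CogVideoX pipeline forwards `attention_kwargs` to the transformer:

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("your-org/your-cogvideox-lora")  # placeholder repo ID

video = pipe(
    prompt="a golden retriever surfing a wave",
    attention_kwargs={"scale": 0.8},  # popped in forward() and routed to scale_lora_layers()
).frames[0]
```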
diffusers/models/transformers/pixart_transformer_2d.py
@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """