diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -14,13 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import itertools
 import json
 import os
 import re
 from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -31,6 +32,8 @@ from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
@@ -43,6 +46,8 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_version,
     logging,
 )
@@ -54,7 +59,9 @@ from ..utils.hub_utils import (
 from .model_loading_utils import (
     _determine_device_map,
     _fetch_index_file,
+    _fetch_index_file_legacy,
     _load_state_dict_into_model,
+    _merge_sharded_checkpoints,
     load_model_dict_into_meta,
     load_state_dict,
 )
@@ -93,24 +100,20 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
 
 def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     try:
-        params = tuple(parameter.parameters())
-        if len(params) > 0:
-            return params[0].dtype
-
-        buffers = tuple(parameter.buffers())
-        if len(buffers) > 0:
-            return buffers[0].dtype
-
+        return next(parameter.parameters()).dtype
     except StopIteration:
-        # For torch.nn.DataParallel compatibility in PyTorch 1.5
+        try:
+            return next(parameter.buffers()).dtype
+        except StopIteration:
+            # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
-        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
-            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-            return tuples
+            def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
 
-        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
-        first_tuple = next(gen)
-        return first_tuple[1].dtype
+            gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
 
 
 class ModelMixin(torch.nn.Module, PushToHubMixin):
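Note: the helper now resolves the dtype by trying parameters first, then buffers, then raw tensor attributes (the PyTorch 1.5 `DataParallel` path). A minimal sketch of the same fallback order, using a hypothetical buffer-only module:

```python
import torch
from torch import nn


class BufferOnly(nn.Module):
    """Hypothetical module with no parameters, only a registered buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1, dtype=torch.float16))


def get_dtype(module: nn.Module) -> torch.dtype:
    # Same fallback order as the rewritten helper: parameters first, then buffers.
    try:
        return next(module.parameters()).dtype
    except StopIteration:
        return next(module.buffers()).dtype


print(get_dtype(BufferOnly()))  # torch.float16
```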
@@ -128,6 +131,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
+    _keep_in_fp32_modules = None
 
     def __init__(self):
         super().__init__()
@@ -311,13 +315,24 @@
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+        if hf_quantizer is not None:
+            quantization_serializable = (
+                hf_quantizer is not None
+                and isinstance(hf_quantizer, DiffusersQuantizer)
+                and hf_quantizer.is_serializable
+            )
+            if not quantization_serializable:
+                raise ValueError(
+                    f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                    " the logger on the traceback to understand the reason why the quantized model is not serializable."
+                )
+
         weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
         weights_name = _add_variant(weights_name, variant)
-        weight_name_split = weights_name.split(".")
-        if len(weight_name_split) in [2, 3]:
-            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
-        else:
-            raise ValueError(f"Invalid {weights_name} provided.")
+        weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
 
         os.makedirs(save_directory, exist_ok=True)
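The split-based pattern logic above is replaced by two plain `str.replace` calls. A standalone sketch of what the pattern expands to for the default safetensors weights name (the shard suffix is illustrative):

```python
# Sketch of the shard-name pattern built in save_pretrained.
weights_name = "diffusion_pytorch_model.safetensors"
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
    ".safetensors", "{suffix}.safetensors"
)

print(weights_name_pattern.format(suffix=""))
# diffusion_pytorch_model.safetensors
print(weights_name_pattern.format(suffix="-00001-of-00002"))
# diffusion_pytorch_model-00001-of-00002.safetensors
```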
@@ -407,6 +422,18 @@
                 create_pr=create_pr,
             )
 
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that support
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
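The new `dequantize()` raises when no `hf_quantizer` is attached and otherwise delegates to it. A hedged usage sketch (the checkpoint id is illustrative; assumes a CUDA GPU, `bitsandbytes`, and a quantization method that supports dequantization):

```python
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Hedged sketch: the checkpoint id is illustrative.
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model = model.dequantize()  # back to full-precision weights; errors if the model was never quantized
```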
@@ -529,6 +556,7 @@
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -623,26 +651,85 @@
             user_agent=user_agent,
             **kwargs,
         )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)
+
+        # determine initial quantization config.
+        #######################################
+        pre_quantized = "quantization_config" in config and config["quantization_config"] is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config["quantization_config"], quantization_config
+                )
+            else:
+                config["quantization_config"] = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(
+                config["quantization_config"], pre_quantized=pre_quantized
+            )
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            if device_map is not None:
+                raise NotImplementedError(
+                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                )
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more mem efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
+        else:
+            keep_in_fp32_modules = []
+        #######################################
 
         # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
         index_file = None
         is_local = os.path.isdir(pretrained_model_name_or_path)
-        index_file = _fetch_index_file(
-            is_local=is_local,
-            pretrained_model_name_or_path=pretrained_model_name_or_path,
-            subfolder=subfolder or "",
-            use_safetensors=use_safetensors,
-            cache_dir=cache_dir,
-            variant=variant,
-            force_download=force_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            token=token,
-            revision=revision,
-            user_agent=user_agent,
-            commit_hash=commit_hash,
-        )
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
         if index_file is not None and index_file.is_file():
             is_sharded = True
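Taken together, these changes let `from_pretrained` accept a `quantization_config`, build a `DiffusersAutoQuantizer` from it (or from a pre-quantized checkpoint's config), force `low_cpu_mem_usage=True`, and reject an explicit `device_map`. A hedged end-to-end sketch with 4-bit NF4 bitsandbytes (checkpoint id illustrative; requires CUDA and a recent `bitsandbytes`):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Hedged sketch: the checkpoint id is illustrative; requires CUDA and a recent bitsandbytes.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
)
# Note: passing `device_map` alongside a quantization config raises NotImplementedError here;
# the device map is inferred automatically for quantized models.
```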
@@ -684,6 +771,10 @@
                 revision=revision,
                 subfolder=subfolder or "",
             )
+            if hf_quantizer is not None:
+                model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+                logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
+                is_sharded = False
 
         elif use_safetensors and not is_sharded:
             try:
@@ -729,13 +820,30 @@
             with accelerate.init_empty_weights():
                 model = cls.from_config(config, **unused_kwargs)
 
+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
             # if device_map is None, load the state dict and move the params from meta device to the cpu
             if device_map is None and not is_sharded:
-                param_device = "cpu"
+                # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
+                # It would error out during the `validate_environment()` call above in the absence of cuda.
+                is_quant_method_bnb = (
+                    getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+                )
+                if hf_quantizer is None:
+                    param_device = "cpu"
+                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
+                elif is_quant_method_bnb:
+                    param_device = torch.cuda.current_device()
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
+
                 # move the params from meta device to cpu
                 missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if hf_quantizer is not None:
+                    missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
                 if len(missing_keys) > 0:
                     raise ValueError(
                         f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
@@ -750,6 +858,8 @@
                     device=param_device,
                     dtype=torch_dtype,
                     model_name_or_path=pretrained_model_name_or_path,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
                 )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -765,7 +875,9 @@
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
             force_hook = True
-            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+            device_map = _determine_device_map(
+                model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+            )
             if device_map is None and is_sharded:
                 # we load the parameters on the cpu
                 device_map = {"": "cpu"}
@@ -843,14 +955,25 @@
             "error_msgs": error_msgs,
         }
 
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             raise ValueError(
                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
             )
-        elif torch_dtype is not None:
+        # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
+        # completely lose the effectivity of `use_keep_in_fp32_modules`.
+        elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
             model = model.to(torch_dtype)
 
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        if hf_quantizer is not None:
+            # We also make sure to purge `_pre_quantization_dtype` when we serialize
+            # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype)
+        else:
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
@@ -859,6 +982,76 @@
 
         return model
 
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().cuda(*args, **kwargs)
+
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().to(*args, **kwargs)
+
+    # Taken from `transformers`.
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    # Taken from `transformers`.
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
     @classmethod
     def _load_pretrained_model(
         cls,
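The new `cuda()`/`to()`/`half()`/`float()` overrides block dtype casts on bitsandbytes models while still allowing plain device moves on recent bitsandbytes versions. A hedged sketch of the expected behaviour, continuing the 4-bit `transformer` from the sketch above:

```python
import torch

# Continuing the hedged 4-bit sketch above: `transformer` is a bnb-quantized model.
transformer.to("cuda")  # plain device moves are allowed on recent bitsandbytes versions

try:
    transformer.to(torch.float16)  # casting a bitsandbytes model to a new dtype raises
except ValueError as err:
    print(err)

try:
    transformer.half()  # likewise refused for any quantized model
except ValueError as err:
    print(err)
```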
@@ -1041,19 +1234,63 @@
         859520964
         ```
         """
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+                )
 
         if exclude_embeddings:
             embedding_param_names = [
-                f"{name}.weight"
-                for name, module_type in self.named_modules()
-                if isinstance(module_type, torch.nn.Embedding)
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
             ]
-            non_embedding_parameters = [
+            total_parameters = [
                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
             ]
-            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
-            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    if hasattr(param, "element_size"):
+                        num_bytes = param.element_size()
+                    elif hasattr(param, "quant_storage"):
+                        num_bytes = param.quant_storage.itemsize
+                    else:
+                        num_bytes = 1
+                    total_numel.append(param.numel() * 2 * num_bytes)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
 
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []
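`num_parameters()` now corrects for `bnb.nn.Params4bit` packing (two 4-bit weights per stored byte), and the new `get_memory_footprint()` sums parameter and, optionally, buffer bytes. A hedged sketch on a small randomly initialised model (configuration values are arbitrary; no quantization needed):

```python
import torch
from diffusers import UNet2DConditionModel

# Hedged sketch: a deliberately tiny UNet just to exercise the new helpers.
unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
).to(torch.float16)

print(unet.num_parameters(exclude_embeddings=True))
print(unet.get_memory_footprint(return_buffers=True), "bytes")
```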
diffusers/models/normalization.py

@@ -97,6 +97,40 @@ class FP32LayerNorm(nn.LayerNorm):
         ).to(origin_dtype)
 
 
+class SD35AdaLayerNormZeroX(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+            9, dim=1
+        )
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
+
 class AdaLayerNormZero(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
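The new `SD35AdaLayerNormZeroX` projects the conditioning embedding into nine modulation tensors, covering the two attention branches plus the MLP of the SD3.5 dual-attention blocks. A hedged shape check (sizes arbitrary):

```python
import torch
from diffusers.models.normalization import SD35AdaLayerNormZeroX

# Hedged shape check with arbitrary sizes: batch 2, sequence 16, embedding dim 64.
norm = SD35AdaLayerNormZeroX(embedding_dim=64)
hidden_states = torch.randn(2, 16, 64)
emb = torch.randn(2, 64)

outputs = norm(hidden_states, emb=emb)
print(len(outputs))      # 7: modulated states, the gates/shift/scale, and the second attention branch
print(outputs[0].shape)  # torch.Size([2, 16, 64])
print(outputs[5].shape)  # torch.Size([2, 16, 64])  (norm_hidden_states2)
```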
@@ -355,6 +389,51 @@ class LuminaLayerNormContinuous(nn.Module):
         return x
 
 
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, dim: int):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+        self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            c_shift_msa,
+            c_scale_msa,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+        normed_x = self.norm_x(x)
+        normed_context = self.norm_c(context)
+        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
+
+
 class CogVideoXLayerNormZero(nn.Module):
     def __init__(
         self,
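`CogView3PlusAdaLayerNormZeroTextImage` does the same for the CogView3+ joint blocks, emitting six modulation tensors for the image stream and six for the text stream. A hedged shape check (sizes arbitrary):

```python
import torch
from diffusers.models.normalization import CogView3PlusAdaLayerNormZeroTextImage

# Hedged shape check with arbitrary sizes: conditioning dim 64, hidden dim 32.
norm = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=64, dim=32)
x = torch.randn(2, 16, 32)       # image tokens
context = torch.randn(2, 8, 32)  # text tokens
emb = torch.randn(2, 64)

outputs = norm(x, context, emb=emb)
print(len(outputs))      # 10
print(outputs[0].shape)  # torch.Size([2, 16, 32])  (modulated image tokens)
print(outputs[5].shape)  # torch.Size([2, 8, 32])   (modulated text tokens)
```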
diffusers/models/transformers/__init__.py

@@ -14,6 +14,7 @@ if is_torch_available():
     from .stable_audio_transformer import StableAudioDiTModel
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
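A minimal import check for the newly registered transformer class:

```python
# The new CogView3+ transformer is importable once torch is available.
from diffusers.models.transformers import CogView3PlusTransformer2DModel

print(CogView3PlusTransformer2DModel.__name__)
```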
diffusers/models/transformers/auraflow_transformer_2d.py

@@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
     """
 
+    _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
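Declaring `_no_split_modules` makes the AuraFlow transformer compatible with accelerate-style `device_map` loading, which keeps each listed block on a single device. A hedged sketch (checkpoint id illustrative; requires `accelerate`):

```python
from diffusers import AuraFlowTransformer2DModel

# Hedged sketch: with _no_split_modules declared, device_map="auto" can place whole
# transformer blocks across the available devices without splitting them internally.
transformer = AuraFlowTransformer2DModel.from_pretrained(
    "fal/AuraFlow", subfolder="transformer", device_map="auto"
)
```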