diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py
@@ -14,13 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import itertools
 import json
 import os
 import re
 from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
@@ -31,6 +32,8 @@ from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
@@ -43,6 +46,8 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_version,
     logging,
 )
@@ -54,7 +59,9 @@ from ..utils.hub_utils import (
 from .model_loading_utils import (
     _determine_device_map,
     _fetch_index_file,
+    _fetch_index_file_legacy,
     _load_state_dict_into_model,
+    _merge_sharded_checkpoints,
     load_model_dict_into_meta,
     load_state_dict,
 )
@@ -92,25 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
 
 
 def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
-    try:
-        params = tuple(parameter.parameters())
-        if len(params) > 0:
-            return params[0].dtype
-
-        buffers = tuple(parameter.buffers())
-        if len(buffers) > 0:
-            return buffers[0].dtype
-
-    except StopIteration:
-        # For torch.nn.DataParallel compatibility in PyTorch 1.5
-
-        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
-            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-            return tuples
-
-        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
-        first_tuple = next(gen)
-        return first_tuple[1].dtype
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    last_dtype = None
+    for param in parameter.parameters():
+        last_dtype = param.dtype
+        if param.is_floating_point():
+            return param.dtype
+
+    for buffer in parameter.buffers():
+        last_dtype = buffer.dtype
+        if buffer.is_floating_point():
+            return buffer.dtype
+
+    if last_dtype is not None:
+        # if no floating dtype was found return whatever the first dtype is
+        return last_dtype
+
+    # For nn.DataParallel compatibility in PyTorch > 1.5
+    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+        return tuples
+
+    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+    last_tuple = None
+    for tuple in gen:
+        last_tuple = tuple
+        if tuple[1].is_floating_point():
+            return tuple[1].dtype
+
+    if last_tuple is not None:
+        # fallback to the last dtype
+        return last_tuple[1].dtype
 
 
 class ModelMixin(torch.nn.Module, PushToHubMixin):
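The rewritten helper now prefers floating-point dtypes, so models whose leading parameters are integer-typed (e.g. quantized weights) still report a usable compute dtype. A minimal sketch of the new resolution order (`first_floating_dtype` is a hypothetical stand-in, not a diffusers API):

```python
import torch
from torch import nn

def first_floating_dtype(module: nn.Module) -> torch.dtype:
    # Mirrors the new order: the first floating-point parameter or buffer
    # wins; otherwise the last dtype seen is returned as a fallback.
    last_dtype = None
    for tensor in list(module.parameters()) + list(module.buffers()):
        last_dtype = tensor.dtype
        if tensor.is_floating_point():
            return tensor.dtype
    return last_dtype

module = nn.Linear(2, 2).half()
module.register_buffer("step", torch.zeros(1, dtype=torch.int64))
print(first_floating_dtype(module))  # torch.float16, despite the int64 buffer
```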
@@ -128,6 +149,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
+    _keep_in_fp32_modules = None
 
     def __init__(self):
         super().__init__()
@@ -204,6 +226,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_npu_flash_attention(False)
 
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_xla_flash_attention method
+        # gets the message
+        def fn_recursive_set_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_xla_flash_attention"):
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+
+            for child in module.children():
+                fn_recursive_set_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_flash_attention(module)
+
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+        r"""
+        Enable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(True, partition_spec)
+
+    def disable_xla_flash_attention(self):
+        r"""
+        Disable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
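The three new methods recursively forward the flag to every submodule that implements `set_use_xla_flash_attention`. A hedged usage sketch (assumes a TPU host with `torch_xla` installed; the checkpoint id is illustrative):

```python
from diffusers import UNet2DConditionModel

# Any ModelMixin subclass gains the toggles; only attention processors
# that implement set_use_xla_flash_attention actually react to them.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_xla_flash_attention()   # route attention through the torch_xla kernel
# ... run inference under torch_xla ...
unet.disable_xla_flash_attention()  # restore the default processors
```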
@@ -311,19 +362,30 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+        if hf_quantizer is not None:
+            quantization_serializable = (
+                hf_quantizer is not None
+                and isinstance(hf_quantizer, DiffusersQuantizer)
+                and hf_quantizer.is_serializable
+            )
+            if not quantization_serializable:
+                raise ValueError(
+                    f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                    " the logger on the traceback to understand the reason why the quantized model is not serializable."
+                )
+
         weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
         weights_name = _add_variant(weights_name, variant)
-        weight_name_split = weights_name.split(".")
-        if len(weight_name_split) in [2, 3]:
-            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
-        else:
-            raise ValueError(f"Invalid {weights_name} provided.")
+        weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
 
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
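The replace-based derivation tolerates any number of dots in the weights name and keeps the shard suffix next to the file extension. A worked example with an `fp16` variant of `SAFETENSORS_WEIGHTS_NAME`:

```python
# Worked example of the new shard-name pattern from save_pretrained above.
weights_name = "diffusion_pytorch_model.fp16.safetensors"  # variant="fp16"
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
    ".safetensors", "{suffix}.safetensors"
)
print(weights_name_pattern)
# diffusion_pytorch_model.fp16{suffix}.safetensors
print(weights_name_pattern.format(suffix="-00001-of-00002"))
# diffusion_pytorch_model.fp16-00001-of-00002.safetensors
```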
@@ -407,6 +469,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 create_pr=create_pr,
             )
 
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that support
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
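A hedged sketch of the dequantization round trip (requires `bitsandbytes` and a CUDA device; the repo id is illustrative):

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)
# dequantize() delegates to the attached hf_quantizer and raises a
# ValueError if the model was never quantized.
model = model.dequantize()
```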
@@ -529,6 +603,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -623,26 +698,87 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             user_agent=user_agent,
             **kwargs,
         )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)
+
+        # determine initial quantization config.
+        #######################################
+        pre_quantized = "quantization_config" in config and config["quantization_config"] is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config["quantization_config"], quantization_config
+                )
+            else:
+                config["quantization_config"] = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(
+                config["quantization_config"], pre_quantized=pre_quantized
+            )
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
+                raise NotImplementedError(
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                )
+
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more mem efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
+        else:
+            keep_in_fp32_modules = []
+        #######################################
 
         # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
         index_file = None
         is_local = os.path.isdir(pretrained_model_name_or_path)
-        index_file = _fetch_index_file(
-            is_local=is_local,
-            pretrained_model_name_or_path=pretrained_model_name_or_path,
-            subfolder=subfolder or "",
-            use_safetensors=use_safetensors,
-            cache_dir=cache_dir,
-            variant=variant,
-            force_download=force_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            token=token,
-            revision=revision,
-            user_agent=user_agent,
-            commit_hash=commit_hash,
-        )
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
        if index_file is not None and index_file.is_file():
            is_sharded = True
 
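Taken together, this block resolves a quantizer either from a pre-quantized checkpoint config or from the new `quantization_config` kwarg. A hedged loading sketch (requires `bitsandbytes`; the repo id is illustrative and gated):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
# Leave device_map unset for bitsandbytes models: the branch above raises
# NotImplementedError when one is supplied.
```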
@@ -684,6 +820,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     revision=revision,
                     subfolder=subfolder or "",
                 )
+                if hf_quantizer is not None and is_bnb_quantization_method:
+                    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+                    logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
+                    is_sharded = False
 
             elif use_safetensors and not is_sharded:
                 try:
@@ -729,13 +869,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             with accelerate.init_empty_weights():
                 model = cls.from_config(config, **unused_kwargs)
 
+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
             # if device_map is None, load the state dict and move the params from meta device to the cpu
             if device_map is None and not is_sharded:
-                param_device = "cpu"
+                # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
+                # It would error out during the `validate_environment()` call above in the absence of cuda.
+                if hf_quantizer is None:
+                    param_device = "cpu"
+                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
+                else:
+                    param_device = torch.device(torch.cuda.current_device())
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
+
                 # move the params from meta device to cpu
                 missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if hf_quantizer is not None:
+                    missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
                 if len(missing_keys) > 0:
                     raise ValueError(
                         f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
@@ -750,6 +904,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     device=param_device,
                     dtype=torch_dtype,
                     model_name_or_path=pretrained_model_name_or_path,
+                    hf_quantizer=hf_quantizer,
+                    keep_in_fp32_modules=keep_in_fp32_modules,
                 )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -765,7 +921,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
             force_hook = True
-            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+            device_map = _determine_device_map(
+                model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+            )
             if device_map is None and is_sharded:
                 # we load the parameters on the cpu
                 device_map = {"": "cpu"}
@@ -843,14 +1001,25 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 "error_msgs": error_msgs,
             }
 
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             raise ValueError(
                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
             )
-        elif torch_dtype is not None:
+        # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
+        # completely lose the effectivity of `use_keep_in_fp32_modules`.
+        elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
             model = model.to(torch_dtype)
 
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        if hf_quantizer is not None:
+            # We also make sure to purge `_pre_quantization_dtype` when we serialize
+            # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype)
+        else:
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
@@ -859,6 +1028,76 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model
 
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().cuda(*args, **kwargs)
+
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        if getattr(self, "is_quantized", False):
+            if dtype_present_in_args:
+                raise ValueError(
+                    "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
+                    "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
+                )
+
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().to(*args, **kwargs)
+
+    # Taken from `transformers`.
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    # Taken from `transformers`.
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
     @classmethod
     def _load_pretrained_model(
         cls,
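The overrides only block dtype casts; device moves remain possible on recent bitsandbytes. A behavior sketch reusing the hypothetical 4-bit `transformer` from the loading example above:

```python
import torch

transformer.to("cuda")  # device move: allowed with bitsandbytes >= 0.43.2
for bad_call in (
    lambda: transformer.to(torch.float32),  # dtype cast on a quantized model
    lambda: transformer.half(),
    lambda: transformer.float(),
):
    try:
        bad_call()
    except ValueError as err:
        print(f"rejected: {err}")
```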
@@ -1041,19 +1280,63 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         859520964
         ```
         """
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+                )
 
         if exclude_embeddings:
             embedding_param_names = [
-                f"{name}.weight"
-                for name, module_type in self.named_modules()
-                if isinstance(module_type, torch.nn.Embedding)
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
             ]
-            non_embedding_parameters = [
+            total_parameters = [
                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
             ]
-            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
-            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    if hasattr(param, "element_size"):
+                        num_bytes = param.element_size()
+                    elif hasattr(param, "quant_storage"):
+                        num_bytes = param.quant_storage.itemsize
+                    else:
+                        num_bytes = 1
+                    total_numel.append(param.numel() * 2 * num_bytes)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
 
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []
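Usage sketch for the two accounting helpers, again assuming the 4-bit `transformer` from the earlier example:

```python
# Logical parameter count: bnb Params4bit tensors pack two 4-bit values
# per stored byte, so the helper rescales them by the storage element size.
print(transformer.num_parameters())

# Raw bytes held by parameters, optionally including buffers.
print(transformer.get_memory_footprint())
print(transformer.get_memory_footprint(return_buffers=False))
```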