diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -158,7 +158,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
158
158
  c_embed = self.cond_mapper(c)
159
159
  r_embed = self.gen_r_embedding(r)
160
160
 
161
- if self.training and self.gradient_checkpointing:
161
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
162
162
 
163
163
  def create_custom_forward(module):
164
164
  def custom_forward(*inputs):
@@ -0,0 +1,16 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .auto import DiffusersAutoQuantizer
16
+ from .base import DiffusersQuantizer
@@ -0,0 +1,139 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Adapted from
16
+ https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
17
+ """
18
+
19
+ import warnings
20
+ from typing import Dict, Optional, Union
21
+
22
+ from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
23
+ from .gguf import GGUFQuantizer
24
+ from .quantization_config import (
25
+ BitsAndBytesConfig,
26
+ GGUFQuantizationConfig,
27
+ QuantizationConfigMixin,
28
+ QuantizationMethod,
29
+ TorchAoConfig,
30
+ )
31
+ from .torchao import TorchAoHfQuantizer
32
+
33
+
34
+ AUTO_QUANTIZER_MAPPING = {
35
+ "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
36
+ "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
37
+ "gguf": GGUFQuantizer,
38
+ "torchao": TorchAoHfQuantizer,
39
+ }
40
+
41
+ AUTO_QUANTIZATION_CONFIG_MAPPING = {
42
+ "bitsandbytes_4bit": BitsAndBytesConfig,
43
+ "bitsandbytes_8bit": BitsAndBytesConfig,
44
+ "gguf": GGUFQuantizationConfig,
45
+ "torchao": TorchAoConfig,
46
+ }
47
+
48
+
49
+ class DiffusersAutoQuantizer:
50
+ """
51
+ The auto diffusers quantizer class that takes care of automatically instantiating the correct
52
+ `DiffusersQuantizer` given the `QuantizationConfig`.
53
+ """
54
+
55
+ @classmethod
56
+ def from_dict(cls, quantization_config_dict: Dict):
57
+ quant_method = quantization_config_dict.get("quant_method", None)
58
+ # We need special care for bnb models to make sure everything is backward compatible.
59
+ if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
60
+ suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
61
+ quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
62
+ elif quant_method is None:
63
+ raise ValueError(
64
+ "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
65
+ )
66
+
67
+ if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
68
+ raise ValueError(
69
+ f"Unknown quantization type, got {quant_method} - supported types are:"
70
+ f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
71
+ )
72
+
73
+ target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
74
+ return target_cls.from_dict(quantization_config_dict)
75
+
76
+ @classmethod
77
+ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
78
+ # Convert it to a QuantizationConfig if the q_config is a dict
79
+ if isinstance(quantization_config, dict):
80
+ quantization_config = cls.from_dict(quantization_config)
81
+
82
+ quant_method = quantization_config.quant_method
83
+
84
+ # Again, we need special care for bnb as we have a single quantization config
85
+ # class for both 4-bit and 8-bit quantization
86
+ if quant_method == QuantizationMethod.BITS_AND_BYTES:
87
+ if quantization_config.load_in_8bit:
88
+ quant_method += "_8bit"
89
+ else:
90
+ quant_method += "_4bit"
91
+
92
+ if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
93
+ raise ValueError(
94
+ f"Unknown quantization type, got {quant_method} - supported types are:"
95
+ f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
96
+ )
97
+
98
+ target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
99
+ return target_cls(quantization_config, **kwargs)
100
+
101
+ @classmethod
102
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
103
+ model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
104
+ if getattr(model_config, "quantization_config", None) is None:
105
+ raise ValueError(
106
+ f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
107
+ )
108
+ quantization_config_dict = model_config.quantization_config
109
+ quantization_config = cls.from_dict(quantization_config_dict)
110
+ # Update with potential kwargs that are passed through from_pretrained.
111
+ quantization_config.update(kwargs)
112
+
113
+ return cls.from_config(quantization_config)
114
+
115
+ @classmethod
116
+ def merge_quantization_configs(
117
+ cls,
118
+ quantization_config: Union[dict, QuantizationConfigMixin],
119
+ quantization_config_from_args: Optional[QuantizationConfigMixin],
120
+ ):
121
+ """
122
+ handles situations where both quantization_config from args and quantization_config from model config are
123
+ present.
124
+ """
125
+ if quantization_config_from_args is not None:
126
+ warning_msg = (
127
+ "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
128
+ " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
129
+ )
130
+ else:
131
+ warning_msg = ""
132
+
133
+ if isinstance(quantization_config, dict):
134
+ quantization_config = cls.from_dict(quantization_config)
135
+
136
+ if warning_msg != "":
137
+ warnings.warn(warning_msg)
138
+
139
+ return quantization_config
@@ -0,0 +1,233 @@
1
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Adapted from
17
+ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py
18
+ """
19
+
20
+ from abc import ABC, abstractmethod
21
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
22
+
23
+ from ..utils import is_torch_available
24
+ from .quantization_config import QuantizationConfigMixin
25
+
26
+
27
+ if TYPE_CHECKING:
28
+ from ..models.modeling_utils import ModelMixin
29
+
30
+ if is_torch_available():
31
+ import torch
32
+
33
+
34
+ class DiffusersQuantizer(ABC):
35
+ """
36
+ Abstract class of the HuggingFace quantizer. Supports for now quantizing HF diffusers models for inference and/or
37
+ quantization. This class is used only for diffusers.models.modeling_utils.ModelMixin.from_pretrained and cannot be
38
+ easily used outside the scope of that method yet.
39
+
40
+ Attributes
41
+ quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`):
42
+ The quantization config that defines the quantization parameters of your model that you want to quantize.
43
+ modules_to_not_convert (`List[str]`, *optional*):
44
+ The list of module names to not convert when quantizing the model.
45
+ required_packages (`List[str]`, *optional*):
46
+ The list of required pip packages to install prior to using the quantizer
47
+ requires_calibration (`bool`):
48
+ Whether the quantization method requires to calibrate the model before using it.
49
+ """
50
+
51
+ requires_calibration = False
52
+ required_packages = None
53
+
54
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
55
+ self.quantization_config = quantization_config
56
+
57
+ # -- Handle extra kwargs below --
58
+ self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
59
+ self.pre_quantized = kwargs.pop("pre_quantized", True)
60
+
61
+ if not self.pre_quantized and self.requires_calibration:
62
+ raise ValueError(
63
+ f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized."
64
+ f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to "
65
+ f"pass `pre_quantized=True` while knowing what you are doing."
66
+ )
67
+
68
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
69
+ """
70
+ Some quantization methods require explicitly setting the dtype of the model to a target dtype. You need to
71
+ override this method in case you want to make sure that behavior is preserved
72
+
73
+ Args:
74
+ torch_dtype (`torch.dtype`):
75
+ The input dtype that is passed in `from_pretrained`
76
+ """
77
+ return torch_dtype
78
+
79
+ def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
80
+ """
81
+ Override this method if you want to override the existing device map with a new one. E.g. for
82
+ bitsandbytes, since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to
83
+ `"auto"`.
84
+
85
+ Args:
86
+ device_map (`Union[dict, str]`, *optional*):
87
+ The device_map that is passed through the `from_pretrained` method.
88
+ """
89
+ return device_map
90
+
91
+ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
92
+ """
93
+ Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute the
94
+ device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to `torch.int8`
95
+ and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`.
96
+
97
+ Args:
98
+ torch_dtype (`torch.dtype`, *optional*):
99
+ The torch_dtype that is used to compute the device_map.
100
+ """
101
+ return torch_dtype
102
+
103
+ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
104
+ """
105
+ Override this method if you want to adjust the `missing_keys`.
106
+
107
+ Args:
108
+ missing_keys (`List[str]`, *optional*):
109
+ The list of missing keys in the checkpoint compared to the state dict of the model
110
+ """
111
+ return missing_keys
112
+
113
+ def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
114
+ """
115
+ returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
116
+ passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in
117
+ `_process_model_before_weight_loading`. `diffusers` models don't have any `modules_to_not_convert` attributes
118
+ yet but this can change soon in the future.
119
+
120
+ Args:
121
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
122
+ The model to quantize
123
+ torch_dtype (`torch.dtype`):
124
+ The dtype passed in `from_pretrained` method.
125
+ """
126
+
127
+ return {
128
+ name: torch_dtype
129
+ for name, _ in model.named_parameters()
130
+ if any(m in name for m in self.modules_to_not_convert)
131
+ }
132
+
133
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
134
+ """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
135
+ return max_memory
136
+
137
+ def check_if_quantized_param(
138
+ self,
139
+ model: "ModelMixin",
140
+ param_value: "torch.Tensor",
141
+ param_name: str,
142
+ state_dict: Dict[str, Any],
143
+ **kwargs,
144
+ ) -> bool:
145
+ """
146
+ checks if a loaded state_dict component is part of quantized param + some validation; only defined for
147
+ quantization methods that require creating new parameters for quantization.
148
+ """
149
+ return False
150
+
151
+ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
152
+ """
153
+ takes needed components from state_dict and creates quantized param.
154
+ """
155
+ return
156
+
157
+ def check_quantized_param_shape(self, *args, **kwargs):
158
+ """
159
+ checks if the quantized param has expected shape.
160
+ """
161
+ return True
162
+
163
+ def validate_environment(self, *args, **kwargs):
164
+ """
165
+ This method is used to check for potential conflicts with arguments that are passed in
166
+ `from_pretrained`. You need to define it for all future quantizers that are integrated with diffusers. If no
167
+ explicit checks are needed, simply return nothing.
168
+ """
169
+ return
170
+
171
+ def preprocess_model(self, model: "ModelMixin", **kwargs):
172
+ """
173
+ Setting model attributes and/or converting model before weights loading. At this point the model should be
174
+ initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace
175
+ modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.
176
+
177
+ Args:
178
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
179
+ The model to quantize
180
+ kwargs (`dict`, *optional*):
181
+ The keyword arguments that are passed along `_process_model_before_weight_loading`.
182
+ """
183
+ model.is_quantized = True
184
+ model.quantization_method = self.quantization_config.quant_method
185
+ return self._process_model_before_weight_loading(model, **kwargs)
186
+
187
+ def postprocess_model(self, model: "ModelMixin", **kwargs):
188
+ """
189
+ Post-process the model post weights loading. Make sure to override the abstract method
190
+ `_process_model_after_weight_loading`.
191
+
192
+ Args:
193
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
194
+ The model to quantize
195
+ kwargs (`dict`, *optional*):
196
+ The keyword arguments that are passed along `_process_model_after_weight_loading`.
197
+ """
198
+ return self._process_model_after_weight_loading(model, **kwargs)
199
+
200
+ def dequantize(self, model):
201
+ """
202
+ Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance. Note
203
+ not all quantization schemes support this.
204
+ """
205
+ model = self._dequantize(model)
206
+
207
+ # Delete quantizer and quantization config
208
+ del model.hf_quantizer
209
+
210
+ return model
211
+
212
+ def _dequantize(self, model):
213
+ raise NotImplementedError(
214
+ f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
215
+ )
216
+
217
+ @abstractmethod
218
+ def _process_model_before_weight_loading(self, model, **kwargs):
219
+ ...
220
+
221
+ @abstractmethod
222
+ def _process_model_after_weight_loading(self, model, **kwargs):
223
+ ...
224
+
225
+ @property
226
+ @abstractmethod
227
+ def is_serializable(self):
228
+ ...
229
+
230
+ @property
231
+ @abstractmethod
232
+ def is_trainable(self):
233
+ ...
@@ -0,0 +1,2 @@
1
+ from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
2
+ from .utils import dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear