diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/quantizers/bitsandbytes/utils.py
@@ -0,0 +1,306 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py
+"""
+
+import inspect
+from inspect import signature
+from typing import Tuple
+
+from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from ..quantization_config import QuantizationMethod
+
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+
+if is_bitsandbytes_available():
+    import bitsandbytes as bnb
+
+if is_accelerate_available():
+    import accelerate
+    from accelerate import init_empty_weights
+    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+
+def _replace_with_bnb_linear(
+    model,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    quantization_config=None,
+    has_been_replaced=False,
+):
+    """
+    Private method that wraps the recursion for module replacement.
+
+    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+    """
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+            # Check if the current key is not in the `modules_to_not_convert`
+            current_key_name_str = ".".join(current_key_name)
+            if not any(
+                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+            ):
+                with init_empty_weights():
+                    in_features = module.in_features
+                    out_features = module.out_features
+
+                    if quantization_config.quantization_method() == "llm_int8":
+                        model._modules[name] = bnb.nn.Linear8bitLt(
+                            in_features,
+                            out_features,
+                            module.bias is not None,
+                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+                            threshold=quantization_config.llm_int8_threshold,
+                        )
+                        has_been_replaced = True
+                    else:
+                        if (
+                            quantization_config.llm_int8_skip_modules is not None
+                            and name in quantization_config.llm_int8_skip_modules
+                        ):
+                            pass
+                        else:
+                            extra_kwargs = (
+                                {"quant_storage": quantization_config.bnb_4bit_quant_storage}
+                                if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
+                                else {}
+                            )
+                            model._modules[name] = bnb.nn.Linear4bit(
+                                in_features,
+                                out_features,
+                                module.bias is not None,
+                                quantization_config.bnb_4bit_compute_dtype,
+                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+                                quant_type=quantization_config.bnb_4bit_quant_type,
+                                **extra_kwargs,
+                            )
+                            has_been_replaced = True
+                    # Store the module class in case we need to transpose the weight later
+                    model._modules[name].source_cls = type(module)
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _replace_with_bnb_linear(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
+    return model, has_been_replaced
+
+
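The skip test above treats each entry in `modules_to_not_convert` as either an exact match for the dotted recursion path or a dotted prefix substring of it; bare leaf names are caught separately by the `name not in modules_to_not_convert` check. A small standalone sketch of that expression, using hypothetical module names:

def is_skipped(path, skip_list):
    # Mirrors the `any(...)` expression in _replace_with_bnb_linear.
    return any((key + "." in path) or (key == path) for key in skip_list)

print(is_skipped("proj_out", ["proj_out"]))           # True: exact match
print(is_skipped("proj_out.lora_A", ["proj_out"]))    # True: "proj_out." prefix
print(is_skipped("blocks.0.proj_out", ["proj_out"]))  # False: no "proj_out." substring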
+def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    """
+    Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bitLt` or
+    `bnb.nn.Linear4bit` using the `bitsandbytes` library.
+
+    References:
+        * `bnb.nn.Linear8bitLt`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at
+          Scale](https://arxiv.org/abs/2208.07339)
+        * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        modules_to_not_convert (`List[str]`, *optional*, defaults to `[]`):
+            Names of the modules not to convert to `Linear8bitLt`. In practice, the `modules_to_not_convert` are kept
+            in full precision for numerical stability reasons.
+        current_key_name (`List[str]`, *optional*):
+            An array to track the current key of the recursion. This is used to check whether the current key (or part
+            of it) is in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
+            `disk`).
+        quantization_config (`diffusers.quantizers.quantization_config.BitsAndBytesConfig`):
+            The quantization config object that holds the quantization settings (precision, skip modules, compute
+            dtype, and so on) used to compress the model.
+    """
+    model, has_been_replaced = _replace_with_bnb_linear(
+        model, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
+
+    return model
+
+
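`replace_with_bnb_linear` is invoked internally when a model is loaded with a quantization config; callers normally go through `from_pretrained`. A minimal sketch of that path, with the SD3 repo `stabilityai/stable-diffusion-3-medium-diffusers` chosen purely for illustration:

import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# 4-bit NF4 quantization: eligible nn.Linear layers become bnb.nn.Linear4bit.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
)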
+# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+    """
+    Helper function to dequantize 4bit or 8bit bnb weights.
+
+    If the weight is not a bnb quantized weight, it will be returned as is.
+    """
+    if not isinstance(weight, torch.nn.Parameter):
+        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
+
+    cls_name = weight.__class__.__name__
+    if cls_name not in ("Params4bit", "Int8Params"):
+        return weight
+
+    if cls_name == "Params4bit":
+        output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
+        logger.warning_once(
+            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+        )
+        return output_tensor
+
+    if state.SCB is None:
+        state.SCB = weight.SCB
+
+    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
+    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
+    im, Sim = bnb.functional.transform(im, "col32")
+    if state.CxB is None:
+        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
+    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
+    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+
+
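A quick sketch of calling the helper directly; `layer` stands for a hypothetical `bnb.nn.Linear4bit` module. `state=None` is fine for 4-bit weights because `Params4bit` carries its own `quant_state`, whereas 8-bit `Int8Params` needs `module.state` passed in:

# Returns a plain floating-point tensor; non-bnb weights pass through unchanged.
fp_weight = dequantize_bnb_weight(layer.weight, state=None)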
+def _create_accelerate_new_hook(old_hook):
+    r"""
+    Creates a new hook based on the old hook. Use it only if you know what you are doing! This method is a copy of:
+    https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
+    some changes
+    """
+    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+    old_hook_attr = old_hook.__dict__
+    filtered_old_hook_attr = {}
+    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+    for k in old_hook_attr.keys():
+        if k in old_hook_init_signature.parameters:
+            filtered_old_hook_attr[k] = old_hook_attr[k]
+    new_hook = old_hook_cls(**filtered_old_hook_attr)
+    return new_hook
+
+
+def _dequantize_and_replace(
+    model,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    quantization_config=None,
+    has_been_replaced=False,
+):
+    """
+    Converts a quantized model into its dequantized original version. The newly converted model will have some
+    performance drop compared to the original model before quantization - use it only for specific use cases such as
+    QLoRA adapter merging.
+
+    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+    """
+    quant_method = quantization_config.quantization_method()
+
+    target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
+
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if isinstance(module, target_cls) and name not in modules_to_not_convert:
+            # Check if the current key is not in the `modules_to_not_convert`
+            current_key_name_str = ".".join(current_key_name)
+
+            if not any(
+                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+            ):
+                bias = getattr(module, "bias", None)
+
+                device = module.weight.device
+                with init_empty_weights():
+                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
+
+                if quant_method == "llm_int8":
+                    state = module.state
+                else:
+                    state = None
+
+                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+
+                if bias is not None:
+                    new_module.bias = bias
+
+                # Create a new hook and attach it in case we use accelerate
+                if hasattr(module, "_hf_hook"):
+                    old_hook = module._hf_hook
+                    new_hook = _create_accelerate_new_hook(old_hook)
+
+                    remove_hook_from_module(module)
+                    add_hook_to_module(new_module, new_hook)
+
+                new_module.to(device)
+                model._modules[name] = new_module
+                has_been_replaced = True
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _dequantize_and_replace(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
+    return model, has_been_replaced
+
+
+def dequantize_and_replace(
+    model,
+    modules_to_not_convert=None,
+    quantization_config=None,
+):
+    model, has_been_replaced = _dequantize_and_replace(
+        model,
+        modules_to_not_convert=modules_to_not_convert,
+        quantization_config=quantization_config,
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+        )
+
+    return model
+
+
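A hedged usage sketch: assuming `model` was quantized with the `quant_config` from the 4-bit example above, this restores plain `nn.Linear` modules, e.g. before merging LoRA adapters:

# The config tells the helper which bnb class (Linear4bit here) to look for.
model = dequantize_and_replace(model, quantization_config=quant_config)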
+def _check_bnb_status(module) -> Tuple[bool, bool, bool]:
+    is_loaded_in_4bit_bnb = (
+        hasattr(module, "is_loaded_in_4bit")
+        and module.is_loaded_in_4bit
+        and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+    )
+    is_loaded_in_8bit_bnb = (
+        hasattr(module, "is_loaded_in_8bit")
+        and module.is_loaded_in_8bit
+        and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+    )
+    return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb
diffusers/quantizers/gguf/__init__.py
@@ -0,0 +1 @@
+from .gguf_quantizer import GGUFQuantizer
diffusers/quantizers/gguf/gguf_quantizer.py
@@ -0,0 +1,159 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+
+from ...utils import (
+    get_module_from_name,
+    is_accelerate_available,
+    is_accelerate_version,
+    is_gguf_available,
+    is_gguf_version,
+    is_torch_available,
+    logging,
+)
+
+
+if is_torch_available() and is_gguf_available():
+    import torch
+
+    from .utils import (
+        GGML_QUANT_SIZES,
+        GGUFParameter,
+        _dequantize_gguf_and_restore_linear,
+        _quant_shape_from_byte_shape,
+        _replace_with_gguf_linear,
+    )
+
+
+logger = logging.get_logger(__name__)
+
+
+class GGUFQuantizer(DiffusersQuantizer):
+    use_keep_in_fp32_modules = True
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+        self.compute_dtype = quantization_config.compute_dtype
+        self.pre_quantized = quantization_config.pre_quantized
+        self.modules_to_not_convert = quantization_config.modules_to_not_convert
+
+        if not isinstance(self.modules_to_not_convert, list):
+            self.modules_to_not_convert = [self.modules_to_not_convert]
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
+            raise ImportError(
+                "Loading GGUF parameters requires `accelerate` installed in your environment: `pip install 'accelerate>=0.26.0'`"
+            )
+        if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
+            raise ImportError(
+                "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
+            )
+
+    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
+    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+        # need more space for buffers that are created during quantization
+        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+        return max_memory
+
+    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        if target_dtype != torch.uint8:
+            logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
+        return torch.uint8
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            torch_dtype = self.compute_dtype
+        return torch_dtype
+
+    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
+        loaded_param_shape = loaded_param.shape
+        current_param_shape = current_param.shape
+        quant_type = loaded_param.quant_type
+
+        block_size, type_size = GGML_QUANT_SIZES[quant_type]
+
+        inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
+        if inferred_shape != current_param_shape:
+            raise ValueError(
+                f"{param_name} has an expected quantized shape of: {inferred_shape}, but received shape: {loaded_param_shape}"
+            )
+
+        return True
+
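The shape check relies on GGML block packing: every `block_size` weights are stored in `type_size` bytes, so the byte-shaped last dimension maps back through `last_dim // type_size * block_size`. A standalone sketch using Q4_0's sizes (32 weights per 18-byte block, the values `GGML_QUANT_SIZES` holds for that type):

def quant_shape_from_byte_shape(byte_shape, type_size=18, block_size=32):
    # Same arithmetic as _quant_shape_from_byte_shape, specialized to Q4_0.
    *leading, last = byte_shape
    return (*leading, last // type_size * block_size)

print(quant_shape_from_byte_shape((4096, 2304)))  # (4096, 4096): 2304 // 18 * 32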
+    def check_if_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: Union["GGUFParameter", "torch.Tensor"],
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ) -> bool:
+        if isinstance(param_value, GGUFParameter):
+            return True
+
+        return False
+
+    def create_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: Union["GGUFParameter", "torch.Tensor"],
+        param_name: str,
+        target_device: "torch.device",
+        state_dict: Optional[Dict[str, Any]] = None,
+        unexpected_keys: Optional[List[str]] = None,
+    ):
+        module, tensor_name = get_module_from_name(model, param_name)
+        if tensor_name not in module._parameters and tensor_name not in module._buffers:
+            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+        if tensor_name in module._parameters:
+            module._parameters[tensor_name] = param_value.to(target_device)
+        if tensor_name in module._buffers:
+            module._buffers[tensor_name] = param_value.to(target_device)
+
+    def _process_model_before_weight_loading(
+        self,
+        model: "ModelMixin",
+        device_map,
+        keep_in_fp32_modules: List[str] = [],
+        **kwargs,
+    ):
+        state_dict = kwargs.get("state_dict", None)
+
+        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+        _replace_with_gguf_linear(
+            model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
+        )
+
+    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+        return model
+
+    @property
+    def is_serializable(self):
+        return False
+
+    @property
+    def is_trainable(self) -> bool:
+        return False
+
+    def _dequantize(self, model):
+        is_model_on_cpu = model.device.type == "cpu"
+        if is_model_on_cpu:
+            logger.info(
+                "Model was found to be on CPU (this can happen as a result of `enable_model_cpu_offload()`). Moving it to GPU for dequantization; it will be moved back to CPU afterwards to preserve the previous device."
+            )
+            model.to(torch.cuda.current_device())
+
+        model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
+        if is_model_on_cpu:
+            model.to("cpu")
+        return model
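For context, this quantizer is exercised through `from_single_file` with a `GGUFQuantizationConfig`. A minimal sketch, using a community FLUX GGUF checkpoint as the example source:

import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    # Weights stay packed in GGUF blocks; they are dequantized on the fly
    # in compute_dtype during the forward pass.
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)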