diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings_flax.py

@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
     """Returns the positional encoding (same as Tensor2Tensor).
 
     Args:
-        timesteps: a 1-D Tensor of N indices, one per batch element.
-            These may be fractional.
-        embedding_dim: The number of output channels.
-        min_timescale: The smallest time unit (should probably be 0.0).
-        max_timescale: The largest time unit.
+        timesteps (`jnp.ndarray` of shape `(N,)`):
+            A 1-D array of N indices, one per batch element. These may be fractional.
+        embedding_dim (`int`):
+            The number of output channels.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Shift applied to the frequency scaling of the embeddings.
+        min_timescale (`float`, *optional*, defaults to `1`):
+            The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
+        max_timescale (`float`, *optional*, defaults to `1.0e4`):
+            The largest time unit used in the sinusoidal calculation.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the order of sinusoidal components to cosine first.
+        scale (`float`, *optional*, defaults to `1.0`):
+            A scaling factor applied to the positional embeddings.
+
     Returns:
         a Tensor of timing signals [N, num_channels]
     """
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
 
     Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
-        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
-            Parameters `dtype`
+            Time step embedding dimension.
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            The data type for the embedding parameters.
     """
 
     time_embed_dim: int = 32
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
 
     Args:
         dim (`int`, *optional*, defaults to `32`):
-            Time step embedding dimension
+            Time step embedding dimension.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+            Whether to flip the sinusoidal function from sine to cosine.
+        freq_shift (`float`, *optional*, defaults to `1`):
+            Frequency shift applied to the sinusoidal embeddings.
     """
 
     dim: int = 32
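
The parameters documented above map onto a call like the following minimal sketch (assuming the public module path `diffusers.models.embeddings_flax`; illustrative only, not part of this diff):

import jax.numpy as jnp
from diffusers.models.embeddings_flax import get_sinusoidal_embeddings

# One (possibly fractional) timestep index per batch element.
timesteps = jnp.array([0.0, 10.0, 500.0])
emb = get_sinusoidal_embeddings(
    timesteps,
    embedding_dim=32,       # number of output channels
    freq_shift=1,           # shift applied to the frequency scaling
    flip_sin_to_cos=False,  # keep sine-first ordering
    scale=1.0,              # scaling factor applied to the embeddings
)
assert emb.shape == (3, 32)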
diffusers/models/model_loading_utils.py

@@ -17,6 +17,7 @@
 import importlib
 import inspect
 import os
+from array import array
 from collections import OrderedDict
 from pathlib import Path
 from typing import List, Optional, Union
@@ -26,12 +27,16 @@ import torch
 from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
+    GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
     WEIGHTS_INDEX_NAME,
     _add_variant,
     _get_model_file,
+    deprecate,
     is_accelerate_available,
+    is_gguf_available,
+    is_torch_available,
     is_torch_version,
     logging,
 )
@@ -53,11 +58,36 @@ if is_accelerate_available():
 
 
 # Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+def _determine_device_map(
+    model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
+):
     if isinstance(device_map, str):
+        special_dtypes = {}
+        if hf_quantizer is not None:
+            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
+        special_dtypes.update(
+            {
+                name: torch.float32
+                for name, _ in model.named_parameters()
+                if any(m in name for m in keep_in_fp32_modules)
+            }
+        )
+
+        target_dtype = torch_dtype
+        if hf_quantizer is not None:
+            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
         no_split_modules = model._get_no_split_modules(device_map)
         device_map_kwargs = {"no_split_module_classes": no_split_modules}
 
+        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+            device_map_kwargs["special_dtypes"] = special_dtypes
+        elif len(special_dtypes) > 0:
+            logger.warning(
+                "This model has some weights that should be kept in higher precision, you need to upgrade "
+                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+            )
+
         if device_map != "sequential":
             max_memory = get_balanced_memory(
                 model,
@@ -69,8 +99,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
         else:
             max_memory = get_max_memory(max_memory)
 
+        if hf_quantizer is not None:
+            max_memory = hf_quantizer.adjust_max_memory(max_memory)
+
         device_map_kwargs["max_memory"] = max_memory
-        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(device_map=device_map)
 
     return device_map
 
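
The `special_dtypes` mapping assembled above is ultimately consumed by accelerate's `infer_auto_device_map`, so parameters named in `keep_in_fp32_modules` are planned in float32 even when the rest of the model targets float16. A minimal standalone sketch of that hand-off, assuming an accelerate version recent enough to accept `special_dtypes` (illustrative, not part of this diff):

import torch
from accelerate import infer_auto_device_map

# Toy model: pin the LayerNorm (submodule "1") to float32 while planning the rest in float16.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8), torch.nn.Linear(8, 8))
special_dtypes = {name: torch.float32 for name, _ in model.named_parameters() if name.startswith("1.")}
device_map = infer_auto_device_map(
    model,
    max_memory={"cpu": "1GB"},
    dtype=torch.float16,
    special_dtypes=special_dtypes,
)
print(device_map)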
@@ -99,10 +135,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
+    # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
+    # when refactoring the _merge_sharded_checkpoints() method later.
+    if isinstance(checkpoint_file, dict):
+        return checkpoint_file
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
             return safetensors.torch.load_file(checkpoint_file, device="cpu")
+        elif file_extension == GGUF_FILE_EXTENSION:
+            return load_gguf_checkpoint(checkpoint_file)
         else:
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
             return torch.load(
@@ -136,29 +178,69 @@ def load_model_dict_into_meta(
     device: Optional[Union[str, torch.device]] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     model_name_or_path: Optional[str] = None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
 ) -> List[str]:
-    device = device or torch.device("cpu")
+    if device is not None and not isinstance(device, (str, torch.device)):
+        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
+    if hf_quantizer is None:
+        device = device or torch.device("cpu")
     dtype = dtype or torch.float32
+    is_quantized = hf_quantizer is not None
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
-    unexpected_keys = []
     empty_state_dict = model.state_dict()
+    unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
+
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
-            unexpected_keys.append(param_name)
             continue
 
+        set_module_kwargs = {}
+        # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
+        # in int/uint/bool and not cast them.
+        # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
+        if torch.is_floating_point(param):
+            if (
+                keep_in_fp32_modules is not None
+                and any(
+                    module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+                )
+                and dtype == torch.float16
+            ):
+                param = param.to(torch.float32)
+                if accepts_dtype:
+                    set_module_kwargs["dtype"] = torch.float32
+            else:
+                param = param.to(dtype)
+                if accepts_dtype:
+                    set_module_kwargs["dtype"] = dtype
+
+        # bnb params are flattened.
+        # gguf quants have a different shape based on the type of quantization applied
         if empty_state_dict[param_name].shape != param.shape:
-            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-            raise ValueError(
-                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-            )
-
-        if accepts_dtype:
-            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+            if (
+                is_quantized
+                and hf_quantizer.pre_quantized
+                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+            ):
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
+            else:
+                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+                raise ValueError(
+                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+                )
+
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
         else:
-            set_module_tensor_to_device(model, param_name, device, value=param)
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
 
 
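
The casting rule added above only touches floating-point tensors: integer and bool buffers keep their dtype, and a parameter whose dotted name contains a module listed in `keep_in_fp32_modules` is pinned to float32 when the target dtype is float16. A standalone toy restating that condition (illustrative only, not diffusers code):

import torch

def resolve_param_dtype(param_name, param, dtype, keep_in_fp32_modules):
    # Non-floating tensors (int/uint/bool buffers) are never cast.
    if not torch.is_floating_point(param):
        return param.dtype
    # Modules listed in keep_in_fp32_modules stay in float32 when loading in float16.
    if (
        keep_in_fp32_modules
        and dtype == torch.float16
        and any(m in param_name.split(".") for m in keep_in_fp32_modules)
    ):
        return torch.float32
    return dtype

p = torch.randn(4, dtype=torch.float16)
assert resolve_param_dtype("encoder.norm_out.weight", p, torch.float16, ["norm_out"]) == torch.float32
assert resolve_param_dtype("encoder.attn.to_q.weight", p, torch.float16, ["norm_out"]) == torch.float16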
@@ -228,3 +310,171 @@ def _fetch_index_file(
             index_file = None
 
     return index_file
+
+
+# Adapted from
+# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
+def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
+    weight_map = sharded_metadata.get("weight_map", None)
+    if weight_map is None:
+        raise KeyError("'weight_map' key not found in the shard index file.")
+
+    # Collect all unique safetensors files from weight_map
+    files_to_load = set(weight_map.values())
+    is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
+    merged_state_dict = {}
+
+    # Load tensors from each unique file
+    for file_name in files_to_load:
+        part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
+        if not os.path.exists(part_file_path):
+            raise FileNotFoundError(f"Part file {file_name} not found.")
+
+        if is_safetensors:
+            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+                for tensor_key in f.keys():
+                    if tensor_key in weight_map:
+                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+        else:
+            merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
+
+    return merged_state_dict
+
+
+def _fetch_index_file_legacy(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
+        ).as_posix()
+        splits = index_file.split(".")
+        split_index = -3 if ".cache" in index_file else -2
+        splits = splits[:-split_index] + [variant] + splits[-split_index:]
+        index_file = ".".join(splits)
+        if os.path.exists(index_file):
+            deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
+            deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
+            index_file = Path(index_file)
+        else:
+            index_file = None
+    else:
+        if variant is not None:
+            index_file_in_repo = Path(
+                subfolder or "",
+                SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
+            ).as_posix()
+            splits = index_file_in_repo.split(".")
+            split_index = -2
+            splits = splits[:-split_index] + [variant] + splits[-split_index:]
+            index_file_in_repo = ".".join(splits)
+            try:
+                index_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=index_file_in_repo,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=None,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                )
+                index_file = Path(index_file)
+                deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
+                deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
+            except (EntryNotFoundError, EnvironmentError):
+                index_file = None
+
+    return index_file
+
+
+def _gguf_parse_value(_value, data_type):
+    if not isinstance(data_type, list):
+        data_type = [data_type]
+    if len(data_type) == 1:
+        data_type = data_type[0]
+        array_data_type = None
+    else:
+        if data_type[0] != 9:
+            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
+        data_type, array_data_type = data_type
+
+    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
+        _value = int(_value[0])
+    elif data_type in [6, 12]:
+        _value = float(_value[0])
+    elif data_type in [7]:
+        _value = bool(_value[0])
+    elif data_type in [8]:
+        _value = array("B", list(_value)).tobytes().decode()
+    elif data_type in [9]:
+        _value = _gguf_parse_value(_value, array_data_type)
+    return _value
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
+    """
+    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
+    attributes.
+
+    Args:
+        gguf_checkpoint_path (`str`):
+            The path the to GGUF file to load
+        return_tensors (`bool`, defaults to `True`):
+            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
+            metadata in memory.
+    """
+
+    if is_gguf_available() and is_torch_available():
+        import gguf
+        from gguf import GGUFReader
+
+        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
+    else:
+        logger.error(
+            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+        )
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+    reader = GGUFReader(gguf_checkpoint_path)
+
+    parsed_parameters = {}
+    for tensor in reader.tensors:
+        name = tensor.name
+        quant_type = tensor.tensor_type
+
+        # if the tensor is a torch supported dtype do not use GGUFParameter
+        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
+        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
+            _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
+            raise ValueError(
+                (
+                    f"{name} has a quantization type: {str(quant_type)} which is unsupported."
+                    "\n\nCurrently the following quantization types are supported: \n\n"
+                    f"{_supported_quants_str}"
+                    "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
+                )
+            )
+
+        weights = torch.from_numpy(tensor.data.copy())
+        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
+
+    return parsed_parameters
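
A minimal sketch of the new GGUF helper in use; the checkpoint filename is illustrative, and the module path assumes these hunks live in `diffusers.models.model_loading_utils`:

from diffusers.models.model_loading_utils import load_gguf_checkpoint

# Quantized tensors come back wrapped as GGUFParameter (carrying their quant_type),
# while F16/F32 tensors are returned as plain torch tensors.
state_dict = load_gguf_checkpoint("flux1-dev-Q4_K_S.gguf")  # illustrative local path
for name, tensor in list(state_dict.items())[:3]:
    print(name, tuple(tensor.shape), getattr(tensor, "quant_type", None))

The new `diffusers/quantizers/gguf/` package listed above supplies `GGUFParameter` and the supported quantization types that this loader checks against.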
diffusers/models/modeling_flax_utils.py

@@ -530,7 +530,7 @@ class FlaxModelMixin(PushToHubMixin):
 
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])