diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/pipeline_utils.py
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import fnmatch
 import importlib
 import inspect
@@ -44,39 +45,44 @@ from ..configuration_utils import ConfigMixin
 from ..models import AutoencoderKL
 from ..models.attention_processor import FusedAttnProcessor2_0
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
+from ..quantizers.bitsandbytes.utils import _check_bnb_status
 from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from ..utils import (
     CONFIG_NAME,
     DEPRECATED_REVISION_ARGS,
     BaseOutput,
     PushToHubMixin,
-    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_torch_npu_available,
     is_torch_version,
+    is_transformers_version,
     logging,
     numpy_to_pil,
 )
-from ..utils.hub_utils import load_or_create_model_card, populate_model_card
+from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import is_compiled_module


 if is_torch_npu_available():
     import torch_npu  # noqa: F401

-
 from .pipeline_loading_utils import (
     ALL_IMPORTABLE_CLASSES,
     CONNECTED_PIPES_KEYS,
     CUSTOM_PIPELINE_FILE_NAME,
     LOADABLE_CLASSES,
     _fetch_class_library_tuple,
+    _get_custom_components_and_folders,
     _get_custom_pipeline_class,
     _get_final_device_map,
+    _get_ignore_patterns,
     _get_pipeline_class,
+    _identify_model_variants,
+    _maybe_raise_warning_for_inpainting,
+    _resolve_custom_pipeline_and_cls,
     _unwrap_model,
-    is_safetensors_compatible,
+    _update_init_kwargs_with_connected_pipeline,
     load_sub_model,
     maybe_raise_or_warn,
     variant_compatible_siblings,
@@ -185,6 +191,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         save_directory: Union[str, os.PathLike],
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -200,6 +207,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `None`):
+                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
+                If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
+                period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
+                This is to establish a common default size for this argument across different libraries in the Hugging
+                Face ecosystem (`transformers`, and `accelerate`, for example).
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -215,7 +229,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -274,12 +288,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             save_method_signature = inspect.signature(save_method)
             save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
             save_method_accept_variant = "variant" in save_method_signature.parameters
+            save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters

             save_kwargs = {}
             if save_method_accept_safe:
                 save_kwargs["safe_serialization"] = safe_serialization
             if save_method_accept_variant:
                 save_kwargs["variant"] = variant
+            if save_method_accept_max_shard_size and max_shard_size is not None:
+                # max_shard_size is expected to not be None in ModelMixin
+                save_kwargs["max_shard_size"] = max_shard_size

             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

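The new `max_shard_size` plumbing only forwards the argument to components whose save method accepts it. A minimal usage sketch (repo id, output path, and shard size are illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
# Each component checkpoint is split into shards of at most ~2 GB; components
# whose save method does not accept `max_shard_size` are saved unsharded as before.
pipe.save_pretrained("./sd21-sharded", max_shard_size="2GB")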
@@ -370,6 +388,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            )

         device = device or device_arg
+        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
         def module_is_sequentially_offloaded(module):
@@ -392,10 +411,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
-        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
-            raise ValueError(
-                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
-            )
+        if device and torch.device(device).type == "cuda":
+            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                )
+            # PR: https://github.com/huggingface/accelerate/pull/3223/
+            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+                raise ValueError(
+                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+                )

         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
@@ -416,18 +441,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
         for module in modules:
-            is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+            _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)

-            if is_loaded_in_8bit and dtype is not None:
+            if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
                 logger.warning(
-                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
+                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
                 )

-            if is_loaded_in_8bit and device is not None:
+            if is_loaded_in_8bit_bnb and device is not None:
                 logger.warning(
-                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
                 )
-            else:
+
+            # This can happen for `transformer` models. CPU placement was added in
+            # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
+            if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+                module.to(device=device)
+            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
                 module.to(device, dtype)

         if (
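The branch above runs when a pipeline holds bitsandbytes-quantized components. A sketch of that path, assuming the quantization API this release ships (`BitsAndBytesConfig` and the `quantization_config` argument) together with accelerate >= 1.1.0 and transformers > 4.44.0:

import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

nf4 = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4,
)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
# 4bit modules are moved to the accelerator; 8bit modules stay where they are
# and only emit a warning, per the `.to()` logic in the hunk above.
pipe.to("cuda")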
@@ -622,6 +652,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        >>> pipeline.scheduler = scheduler
        ```
        """
+        # Copy the kwargs to re-use during loading connected pipeline.
+        kwargs_copied = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -716,39 +749,43 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         else:
             cached_folder = pretrained_model_name_or_path

+        # The variant filenames can have the legacy sharding checkpoint format that we check and throw
+        # a warning if detected.
+        if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
+            warn_msg = (
+                f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
+                "Please check your files carefully:\n\n"
+                "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
+                "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
+                "If you find any files in the deprecated format:\n"
+                "1. Remove all existing checkpoint files for this variant.\n"
+                "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
+                "This will ensure you're using the most up-to-date and compatible checkpoint format."
+            )
+            logger.warning(warn_msg)
+
         config_dict = cls.load_config(cached_folder)

         # pop out "_ignore_files" as it is only needed for download
         config_dict.pop("_ignore_files", None)

         # 2. Define which model components should load variants
-        # We retrieve the information by matching whether variant
-        # model checkpoints exist in the subfolders
-        model_variants = {}
-        if variant is not None:
-            for folder in os.listdir(cached_folder):
-                folder_path = os.path.join(cached_folder, folder)
-                is_folder = os.path.isdir(folder_path) and folder in config_dict
-                variant_exists = is_folder and any(
-                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
-                )
-                if variant_exists:
-                    model_variants[folder] = variant
+        # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
+        # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
+        # with variant being `"fp16"`.
+        model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
+        if len(model_variants) == 0 and variant is not None:
+            error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
+            raise ValueError(error_message)

         # 3. Load the pipeline class, if using custom module then load it from the hub
         # if we load from explicit class, let's use it
-        custom_class_name = None
-        if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
-            custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
-        elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
-            os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
-        ):
-            custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
-            custom_class_name = config_dict["_class_name"][1]
-
+        custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
+            folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
+        )
         pipeline_class = _get_pipeline_class(
             cls,
-            config_dict,
+            config=config_dict,
             load_connected_pipeline=load_connected_pipeline,
             custom_pipeline=custom_pipeline,
             class_name=custom_class_name,
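Behavior change worth noting: requesting a variant with no matching files used to fall back to the default weights behind a deprecation warning; with `_identify_model_variants` it now raises. A sketch (hypothetical repo id):

from diffusers import DiffusionPipeline

try:
    pipe = DiffusionPipeline.from_pretrained("some-org/repo-without-fp16-weights", variant="fp16")
except ValueError as err:
    # "You are trying to load the model files of the `variant=fp16`, but no such
    # modeling files are available."
    print(err)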
@@ -760,23 +797,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")

         # DEPRECATED: To be removed in 1.0.0
-        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
-            version.parse(config_dict["_diffusers_version"]).base_version
-        ) <= version.parse("0.5.1"):
-            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
-
-            pipeline_class = StableDiffusionInpaintPipelineLegacy
-
-            deprecation_message = (
-                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
-                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
-                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
-                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
-                f" checkpoint {pretrained_model_name_or_path} to the format of"
-                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
-                " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
-            )
-            deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
+        # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
+        # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
+        _maybe_raise_warning_for_inpainting(
+            pipeline_class=pipeline_class,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            config=config_dict,
+        )

         # 4. Define expected modules given pipeline signature
         # and define non-None initialized modules (=`init_kwargs`)
@@ -785,9 +812,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # in this case they are already instantiated in `kwargs`
         # extract them here
         expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+        expected_types = pipeline_class._get_signature_types()
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
         passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
-
         init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

         # define init kwargs and make sure that optional component modules are filtered out
@@ -808,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

         init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

+        for key in init_dict.keys():
+            if key not in passed_class_obj:
+                continue
+            if "scheduler" in key:
+                continue
+
+            class_obj = passed_class_obj[key]
+            _expected_class_types = []
+            for expected_type in expected_types[key]:
+                if isinstance(expected_type, enum.EnumMeta):
+                    _expected_class_types.extend(expected_type.__members__.keys())
+                else:
+                    _expected_class_types.append(expected_type.__name__)
+
+            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
+            if not _is_valid_type:
+                logger.warning(
+                    f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
+                )
+
         # Special case: safety_checker must be loaded separately when using `from_flax`
         if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
             raise NotImplementedError(
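The loop above cross-checks user-passed components against the types annotated on the pipeline's `__init__` and logs a warning on mismatch (schedulers are exempt because they are freely interchangeable). A sketch of what now gets validated; the repo ids are examples only:

from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# Matching class, so no warning. Passing e.g. a CLIPVisionModel here would log:
#   Expected types for text_encoder: ['CLIPTextModel'], got CLIPVisionModel.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    text_encoder=text_encoder,
)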
@@ -847,6 +894,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # 7. Load each module in the pipeline
         current_device_map = None
         for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+            # 7.1 device_map shenanigans
             if final_device_map is not None and len(final_device_map) > 0:
                 component_device = final_device_map.get(name, None)
                 if component_device is not None:
@@ -854,15 +902,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             else:
                 current_device_map = None

-            # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+            # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
             class_name = class_name[4:] if class_name.startswith("Flax") else class_name

-            # 7.2 Define all importable classes
+            # 7.3 Define all importable classes
             is_pipeline_module = hasattr(pipelines, library_name)
             importable_classes = ALL_IMPORTABLE_CLASSES
             loaded_sub_model = None

-            # 7.3 Use passed sub model or load class_name from library_name
+            # 7.4 Use passed sub model or load class_name from library_name
             if name in passed_class_obj:
                 # if the model is in a pipeline module, then we load it from the pipeline
                 # check that passed_class_obj has correct parent class
@@ -893,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     variant=variant,
                     low_cpu_mem_usage=low_cpu_mem_usage,
                     cached_folder=cached_folder,
+                    use_safetensors=use_safetensors,
                 )
                 logger.info(
                     f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -900,56 +949,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

+        # 8. Handle connected pipelines.
         if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
-            modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
-            connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
-            load_kwargs = {
-                "cache_dir": cache_dir,
-                "force_download": force_download,
-                "proxies": proxies,
-                "local_files_only": local_files_only,
-                "token": token,
-                "revision": revision,
-                "torch_dtype": torch_dtype,
-                "custom_pipeline": custom_pipeline,
-                "custom_revision": custom_revision,
-                "provider": provider,
-                "sess_options": sess_options,
-                "device_map": device_map,
-                "max_memory": max_memory,
-                "offload_folder": offload_folder,
-                "offload_state_dict": offload_state_dict,
-                "low_cpu_mem_usage": low_cpu_mem_usage,
-                "variant": variant,
-                "use_safetensors": use_safetensors,
-            }
-
-            def get_connected_passed_kwargs(prefix):
-                connected_passed_class_obj = {
-                    k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
-                }
-                connected_passed_pipe_kwargs = {
-                    k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
-                }
-
-                connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
-                return connected_passed_kwargs
-
-            connected_pipes = {
-                prefix: DiffusionPipeline.from_pretrained(
-                    repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
-                )
-                for prefix, repo_id in connected_pipes.items()
-                if repo_id is not None
-            }
-
-            for prefix, connected_pipe in connected_pipes.items():
-                # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
-                init_kwargs.update(
-                    {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
-                )
+            init_kwargs = _update_init_kwargs_with_connected_pipeline(
+                init_kwargs=init_kwargs,
+                passed_pipe_kwargs=passed_pipe_kwargs,
+                passed_class_objs=passed_class_obj,
+                folder=cached_folder,
+                **kwargs_copied,
+            )

-        # 8. Potentially add passed objects if expected
+        # 9. Potentially add passed objects if expected
         missing_modules = set(expected_modules) - set(init_kwargs.keys())
         passed_modules = list(passed_class_obj.keys())
         optional_modules = pipeline_class._optional_components
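Connected-pipeline handling collapses into a single helper; note that it receives `**kwargs_copied`, the untouched copy taken at the top of `from_pretrained`, so the connected pipeline downloads with the same options as the main one. Calling code is unchanged, e.g. (hypothetical repo id):

from diffusers import DiffusionPipeline

# Loads the decoder plus the prior pipeline declared in its model card; the
# prior's components are merged in under a "prior_" prefix, e.g. "prior_text_encoder".
pipe = DiffusionPipeline.from_pretrained(
    "some-org/decoder-with-connected-prior",
    load_connected_pipeline=True,
)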
@@ -1065,9 +1075,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         hook = None
         for model_str in self.model_cpu_offload_seq.split("->"):
             model = all_model_components.pop(model_str, None)
+
             if not isinstance(model, torch.nn.Module):
                 continue

+            # This is because the model would already be placed on a CUDA device.
+            _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
+            if is_loaded_in_8bit_bnb:
+                logger.info(
+                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
+                )
+                continue
+
             _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
             self._all_hooks.append(hook)

@@ -1295,6 +1314,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            model_info_call_error = e  # save error to reraise it if model is not cached locally

        if not local_files_only:
+            filenames = {sibling.rfilename for sibling in info.siblings}
+            if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
+                warn_msg = (
+                    f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
+                    "Please check your files carefully:\n\n"
+                    "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
+                    "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
+                    "If you find any files in the deprecated format:\n"
+                    "1. Remove all existing checkpoint files for this variant.\n"
+                    "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
+                    "This will ensure you're using the most up-to-date and compatible checkpoint format."
+                )
+                logger.warning(warn_msg)
+
+            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
+
             config_file = hf_hub_download(
                 pretrained_model_name,
                 cls.config_name,
@@ -1308,52 +1343,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             config_dict = cls._dict_from_json_file(config_file)
             ignore_filenames = config_dict.pop("_ignore_files", [])

-            # retrieve all folder_names that contain relevant files
-            folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
-
-            filenames = {sibling.rfilename for sibling in info.siblings}
-            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
-
-            diffusers_module = importlib.import_module(__name__.split(".")[0])
-            pipelines = getattr(diffusers_module, "pipelines")
-
-            # optionally create a custom component <> custom file mapping
-            custom_components = {}
-            for component in folder_names:
-                module_candidate = config_dict[component][0]
-
-                if module_candidate is None or not isinstance(module_candidate, str):
-                    continue
-
-                # We compute candidate file path on the Hub. Do not use `os.path.join`.
-                candidate_file = f"{component}/{module_candidate}.py"
-
-                if candidate_file in filenames:
-                    custom_components[component] = module_candidate
-                elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
-                    raise ValueError(
-                        f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
-                    )
-
-            if len(variant_filenames) == 0 and variant is not None:
-                deprecation_message = (
-                    f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
-                    f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
-                    "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
-                    "modeling files is deprecated."
-                )
-                deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
-
             # remove ignored filenames
             model_filenames = set(model_filenames) - set(ignore_filenames)
             variant_filenames = set(variant_filenames) - set(ignore_filenames)

-            # if the whole pipeline is cached we don't have to ping the Hub
             if revision in DEPRECATED_REVISION_ARGS and version.parse(
                 version.parse(__version__).base_version
             ) >= version.parse("0.22.0"):
                 warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)

+            custom_components, folder_names = _get_custom_components_and_folders(
+                pretrained_model_name, config_dict, filenames, variant_filenames, variant
+            )
             model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

             custom_class_name = None
@@ -1413,49 +1414,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            expected_components, _ = cls._get_signature_keys(pipeline_class)
            passed_components = [k for k in expected_components if k in kwargs]

-            if (
-                use_safetensors
-                and not allow_pickle
-                and not is_safetensors_compatible(
-                    model_filenames, variant=variant, passed_components=passed_components
-                )
-            ):
-                raise EnvironmentError(
-                    f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
-                )
-            if from_flax:
-                ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(
-                model_filenames, variant=variant, passed_components=passed_components
-            ):
-                ignore_patterns = ["*.bin", "*.msgpack"]
-
-                use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
-                if not use_onnx:
-                    ignore_patterns += ["*.onnx", "*.pb"]
-
-                safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
-                safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
-                if (
-                    len(safetensors_variant_filenames) > 0
-                    and safetensors_model_filenames != safetensors_variant_filenames
-                ):
-                    logger.warning(
-                        f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
-                    )
-            else:
-                ignore_patterns = ["*.safetensors", "*.msgpack"]
-
-                use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
-                if not use_onnx:
-                    ignore_patterns += ["*.onnx", "*.pb"]
-
-                bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
-                bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
-                if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
-                    logger.warning(
-                        f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
-                    )
+            # retrieve all patterns that should not be downloaded and error out when needed
+            ignore_patterns = _get_ignore_patterns(
+                passed_components,
+                model_folder_names,
+                model_filenames,
+                variant_filenames,
+                use_safetensors,
+                from_flax,
+                allow_pickle,
+                use_onnx,
+                pipeline_class._is_onnx,
+                variant,
+            )

            # Don't download any objects that are passed
            allow_patterns = [
@@ -1609,6 +1580,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        """
        return numpy_to_pil(images)

+    @torch.compiler.disable
    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
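`torch.compiler.disable` (available in PyTorch >= 2.1) marks `progress_bar` as opaque to TorchDynamo, so a sampling loop that runs under `torch.compile` no longer graph-breaks on tqdm bookkeeping. A sketch of the surrounding pattern, with an illustrative model choice:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
# Compiling the denoiser is the common pattern; the decorated progress_bar
# stays eager either way, so Dynamo never attempts to trace through tqdm.
pipe.unet = torch.compile(pipe.unet)
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]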
diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -178,7 +178,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -338,13 +338,6 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         if device is None:
             device = self._execution_device

-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
         # See Section 3.1. of the paper.
         max_length = max_sequence_length

@@ -389,12 +382,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
+        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
             uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -421,10 +414,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
         else:
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
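The fix expands attention masks with the same mps-friendly `repeat(1, n)` + `view` pattern already used for the embeddings, so each duplicated mask row stays adjacent to its source prompt (the old `view` + `repeat(n, 1)` order tiled the whole batch instead, misaligning masks with their duplicated embeddings). A standalone illustration of the pattern:

import torch

bs_embed, seq_len, n = 2, 120, 3  # illustrative sizes
mask = torch.arange(bs_embed).unsqueeze(1).expand(bs_embed, seq_len)
mask = mask.repeat(1, n)                 # (bs, seq * n): each row tiled n times
mask = mask.view(bs_embed * n, seq_len)  # (bs * n, seq): row i duplicated n consecutive times
assert (mask[:n] == 0).all() and (mask[n:] == 1).all()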