diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in that registry.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0

diffusers/pipelines/pag/pipeline_pag_sd_xl.py

@@ -88,9 +88,21 @@ EXAMPLE_DOC_STRING = """

  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
  """
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
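
For context, the rescale this docstring describes normalizes the guided prediction back to the per-sample scale of the text-only prediction and then blends by `guidance_rescale`. A self-contained sketch of the Section 3.4 formula from the cited paper (the function name and the default of 0.7 are illustrative; the real function body continues past the context lines shown above):

import torch

def rescale_noise_cfg_sketch(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.7) -> torch.Tensor:
    # Per-sample standard deviations over all non-batch dimensions, as in the two context lines above.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided prediction to the text prediction's scale, then interpolate with the original:
    # guidance_rescale=0.0 leaves the input unchanged, 1.0 applies the full rescale.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg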
@@ -110,7 +122,7 @@ def retrieve_timesteps(
  sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
- """
+ r"""
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -1237,8 +1249,8 @@ class StableDiffusionXLPAGPipeline(

  # perform guidance
  if self.do_perturbed_attention_guidance:
- noise_pred = self._apply_perturbed_attention_guidance(
- noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+ noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
  )

  elif self.do_classifier_free_guidance:

diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py

@@ -92,9 +92,21 @@ EXAMPLE_DOC_STRING = """

  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
  """
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

@@ -128,7 +140,7 @@ def retrieve_timesteps(
  sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
- """
+ r"""
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -648,14 +660,16 @@ class StableDiffusionXLPAGImg2ImgPipeline(
  if denoising_start is None:
  init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
  t_start = max(num_inference_steps - init_timestep, 0)
- else:
- t_start = 0

- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)

- # Strength is irrelevant if we directly request a timestep to start at;
- # that is, strength is determined by the denoising_start instead.
- if denoising_start is not None:
+ return timesteps, num_inference_steps - t_start
+
+ else:
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
  discrete_timestep_cutoff = int(
  round(
  self.scheduler.config.num_train_timesteps

@@ -663,7 +677,7 @@ class StableDiffusionXLPAGImg2ImgPipeline(
  )
  )

- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
  if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
  # if the scheduler is a 2nd order scheduler we might have to do +1
  # because `num_inference_steps` might be even given that every timestep

@@ -674,11 +688,12 @@ class StableDiffusionXLPAGImg2ImgPipeline(
  num_inference_steps = num_inference_steps + 1

  # because t_n+1 >= t_n, we slice the timesteps starting from the end
- timesteps = timesteps[-num_inference_steps:]
+ t_start = len(self.scheduler.timesteps) - num_inference_steps
+ timesteps = self.scheduler.timesteps[t_start:]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start)
  return timesteps, num_inference_steps

- return timesteps, num_inference_steps - t_start
-
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
  def prepare_latents(
  self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
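
The new `set_begin_index` calls above let schedulers that keep an internal step index know where the shortened timestep list actually starts, so img2img-style runs with `strength < 1` stay in sync with the scheduler's bookkeeping. A minimal sketch of the same pattern outside the pipeline (the scheduler choice and the step counts are illustrative, not taken from this diff):

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
num_inference_steps, strength = 30, 0.6
scheduler.set_timesteps(num_inference_steps)

# Same arithmetic as the hunk above: skip the first (1 - strength) fraction of the steps.
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = scheduler.timesteps[t_start * scheduler.order :]

# Schedulers that expose it are told the true starting index; older ones are left untouched.
if hasattr(scheduler, "set_begin_index"):
    scheduler.set_begin_index(t_start * scheduler.order)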
@@ -1434,8 +1449,8 @@ class StableDiffusionXLPAGImg2ImgPipeline(

  # perform guidance
  if self.do_perturbed_attention_guidance:
- noise_pred = self._apply_perturbed_attention_guidance(
- noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+ noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
  )
  elif self.do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py

@@ -105,9 +105,21 @@ EXAMPLE_DOC_STRING = """

  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
  """
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

@@ -141,7 +153,7 @@ def retrieve_timesteps(
  sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
- """
+ r"""
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -897,14 +909,16 @@ class StableDiffusionXLPAGInpaintPipeline(
  if denoising_start is None:
  init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
  t_start = max(num_inference_steps - init_timestep, 0)
- else:
- t_start = 0

- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)

- # Strength is irrelevant if we directly request a timestep to start at;
- # that is, strength is determined by the denoising_start instead.
- if denoising_start is not None:
+ return timesteps, num_inference_steps - t_start
+
+ else:
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
  discrete_timestep_cutoff = int(
  round(
  self.scheduler.config.num_train_timesteps

@@ -912,7 +926,7 @@ class StableDiffusionXLPAGInpaintPipeline(
  )
  )

- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
  if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
  # if the scheduler is a 2nd order scheduler we might have to do +1
  # because `num_inference_steps` might be even given that every timestep

@@ -923,11 +937,12 @@ class StableDiffusionXLPAGInpaintPipeline(
  num_inference_steps = num_inference_steps + 1

  # because t_n+1 >= t_n, we slice the timesteps starting from the end
- timesteps = timesteps[-num_inference_steps:]
+ t_start = len(self.scheduler.timesteps) - num_inference_steps
+ timesteps = self.scheduler.timesteps[t_start:]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start)
  return timesteps, num_inference_steps

- return timesteps, num_inference_steps - t_start
-
  # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
  def _get_add_time_ids(
  self,

@@ -1471,6 +1486,14 @@ class StableDiffusionXLPAGInpaintPipeline(
  generator,
  self.do_classifier_free_guidance,
  )
+ if self.do_perturbed_attention_guidance:
+ if self.do_classifier_free_guidance:
+ mask, _ = mask.chunk(2)
+ masked_image_latents, _ = masked_image_latents.chunk(2)
+ mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance)
+ masked_image_latents = self._prepare_perturbed_attention_guidance(
+ masked_image_latents, masked_image_latents, self.do_classifier_free_guidance
+ )

  # 8. Check that sizes of mask, masked image and latents match
  if num_channels_unet == 9:

@@ -1638,8 +1661,8 @@ class StableDiffusionXLPAGInpaintPipeline(

  # perform guidance
  if self.do_perturbed_attention_guidance:
- noise_pred = self._apply_perturbed_attention_guidance(
- noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
+ noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
+ noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
  )
  elif self.do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

@@ -1659,10 +1682,10 @@ class StableDiffusionXLPAGInpaintPipeline(

  if num_channels_unet == 4:
  init_latents_proper = image_latents
- if self.do_classifier_free_guidance:
- init_mask, _ = mask.chunk(2)
+ if self.do_perturbed_attention_guidance:
+ init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2)
  else:
- init_mask = mask
+ init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask

  if i < len(timesteps) - 1:
  noise_timestep = timesteps[i + 1]

diffusers/pipelines/pia/pipeline_pia.py

@@ -824,6 +824,8 @@ class PIAPipeline(
  if self.do_classifier_free_guidance:
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

+ prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
  image_embeds = self.prepare_ip_adapter_image_embeds(
  ip_adapter_image,
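
The added `repeat_interleave` duplicates each prompt embedding once per frame, presumably so the text conditioning lines up with a per-frame batch layout further down the pipeline. A small shape-only sketch of what that call does (the sizes are illustrative):

import torch

batch_size, seq_len, dim, num_frames = 2, 77, 768, 16
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# Each batch entry is repeated num_frames times along dim 0: (2, 77, 768) -> (32, 77, 768).
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
assert prompt_embeds.shape == (batch_size * num_frames, seq_len, dim)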

diffusers/pipelines/pipeline_flax_utils.py

@@ -180,7 +180,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):

  if push_to_hub:
  commit_message = kwargs.pop("commit_message", None)
- private = kwargs.pop("private", False)
+ private = kwargs.pop("private", None)
  create_pr = kwargs.pop("create_pr", False)
  token = kwargs.pop("token", None)
  repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])

diffusers/pipelines/pipeline_loading_utils.py

@@ -22,7 +22,7 @@ from pathlib import Path
  from typing import Any, Dict, List, Optional, Union

  import torch
- from huggingface_hub import model_info
+ from huggingface_hub import ModelCard, model_info
  from huggingface_hub.utils import validate_hf_hub_args
  from packaging import version

@@ -33,6 +33,7 @@ from ..utils import (
  ONNX_WEIGHTS_NAME,
  SAFETENSORS_WEIGHTS_NAME,
  WEIGHTS_NAME,
+ deprecate,
  get_class_from_dynamic_module,
  is_accelerate_available,
  is_peft_available,

@@ -89,49 +90,50 @@ for library in LOADABLE_CLASSES:
  ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


- def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
+ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
  """
  Checking for safetensors compatibility:
- - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
- files to know which safetensors files are needed.
- - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
+ - The model is safetensors compatible only if there is a safetensors file for each model component present in
+ filenames.

  Converting default pytorch serialized filenames to safetensors serialized filenames:
  - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
  - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
  extension is replaced with ".safetensors"
  """
- pt_filenames = []
-
- sf_filenames = set()
-
  passed_components = passed_components or []
+ if folder_names is not None:
+ filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

+ # extract all components of the pipeline and their associated files
+ components = {}
  for filename in filenames:
- _, extension = os.path.splitext(filename)
+ if not len(filename.split("/")) == 2:
+ continue

- if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
+ component, component_filename = filename.split("/")
+ if component in passed_components:
  continue

- if extension == ".bin":
- pt_filenames.append(os.path.normpath(filename))
- elif extension == ".safetensors":
- sf_filenames.add(os.path.normpath(filename))
+ components.setdefault(component, [])
+ components[component].append(component_filename)

- for filename in pt_filenames:
- # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
- path, filename = os.path.split(filename)
- filename, extension = os.path.splitext(filename)
+ # If there are no component folders check the main directory for safetensors files
+ if not components:
+ return any(".safetensors" in filename for filename in filenames)

- if filename.startswith("pytorch_model"):
- filename = filename.replace("pytorch_model", "model")
- else:
- filename = filename
+ # iterate over all files of a component
+ # check if safetensor files exist for that component
+ # if variant is provided check if the variant of the safetensors exists
+ for component, component_filenames in components.items():
+ matches = []
+ for component_filename in component_filenames:
+ filename, extension = os.path.splitext(component_filename)

- expected_sf_filename = os.path.normpath(os.path.join(path, filename))
- expected_sf_filename = f"{expected_sf_filename}.safetensors"
- if expected_sf_filename not in sf_filenames:
- logger.warning(f"{expected_sf_filename} not found")
+ match_exists = extension == ".safetensors"
+ matches.append(match_exists)
+
+ if not any(matches):
  return False

  return True
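
A quick sketch of how the reworked check behaves on a hypothetical pipeline file listing, following the per-component logic above:

from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible

filenames = [
    "model_index.json",  # top-level files without a component folder are not considered here
    "unet/diffusion_pytorch_model.safetensors",
    "text_encoder/model.safetensors",
    "vae/diffusion_pytorch_model.bin",  # this component ships no safetensors file
]

# The vae component has no .safetensors file, so the listing as a whole is rejected.
assert not is_safetensors_compatible(filenames)
# Components supplied by the caller are skipped, so the remaining ones all pass.
assert is_safetensors_compatible(filenames, passed_components=["vae"])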
@@ -196,10 +198,31 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
  variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
  return variant_filename

- for f in non_variant_filenames:
- variant_filename = convert_to_variant(f)
- if variant_filename not in usable_filenames:
- usable_filenames.add(f)
+ def find_component(filename):
+ if not len(filename.split("/")) == 2:
+ return
+ component = filename.split("/")[0]
+ return component
+
+ def has_sharded_variant(component, variant, variant_filenames):
+ # If component exists check for sharded variant index filename
+ # If component doesn't exist check main dir for sharded variant index filename
+ component = component + "/" if component else ""
+ variant_index_re = re.compile(
+ rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+ )
+ return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
+
+ for filename in non_variant_filenames:
+ if convert_to_variant(filename) in variant_filenames:
+ continue
+
+ component = find_component(filename)
+ # If a sharded variant exists skip adding to allowed patterns
+ if has_sharded_variant(component, variant, variant_filenames):
+ continue
+
+ usable_filenames.add(filename)

  return usable_filenames, variant_filenames
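
The new sharded-variant guard boils down to a regex over index filenames. A minimal sketch with assumed values for `weight_prefixes` and `weight_suffixs` (the real lists are defined earlier in the module and are not shown in this hunk):

import re

weight_prefixes = ["diffusion_pytorch_model", "model"]  # assumption for illustration
weight_suffixs = ["safetensors", "bin"]                 # assumption for illustration

variant = "fp16"
component = "unet"
prefix = f"{component}/" if component else ""
variant_index_re = re.compile(
    rf"{prefix}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)

# A sharded fp16 checkpoint ships an index file like this, so the non-variant weights are skipped.
assert variant_index_re.match("unet/diffusion_pytorch_model.safetensors.index.fp16.json") is not None
assert variant_index_re.match("unet/diffusion_pytorch_model.safetensors") is None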
 
@@ -603,6 +626,7 @@ def load_sub_model(
  variant: str,
  low_cpu_mem_usage: bool,
  cached_folder: Union[str, os.PathLike],
+ use_safetensors: bool,
  ):
  """Helper method to load the module `name` from `library_name` and `class_name`"""

@@ -672,6 +696,7 @@ def load_sub_model(
  loading_kwargs["offload_folder"] = offload_folder
  loading_kwargs["offload_state_dict"] = offload_state_dict
  loading_kwargs["variant"] = model_variants.pop(name, None)
+ loading_kwargs["use_safetensors"] = use_safetensors

  if from_flax:
  loading_kwargs["from_flax"] = True

@@ -749,3 +774,197 @@ def _fetch_class_library_tuple(module):
  class_name = not_compiled_module.__class__.__name__

  return (library, class_name)
+
+
+ def _identify_model_variants(folder: str, variant: str, config: dict) -> dict:
+ model_variants = {}
+ if variant is not None:
+ for sub_folder in os.listdir(folder):
+ folder_path = os.path.join(folder, sub_folder)
+ is_folder = os.path.isdir(folder_path) and sub_folder in config
+ variant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
+ if variant_exists:
+ model_variants[sub_folder] = variant
+ return model_variants
+
+
+ def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline):
+ custom_class_name = None
+ if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")):
+ custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py")
+ elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile(
+ os.path.join(folder, f"{config['_class_name'][0]}.py")
+ ):
+ custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py")
+ custom_class_name = config["_class_name"][1]
+
+ return custom_pipeline, custom_class_name
+
+
+ def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict):
+ if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
+ version.parse(config["_diffusers_version"]).base_version
+ ) <= version.parse("0.5.1"):
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
+
+ pipeline_class = StableDiffusionInpaintPipelineLegacy
+
+ deprecation_message = (
+ "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
+ f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
+ " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
+ " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+ f" checkpoint {pretrained_model_name_or_path} to the format of"
+ " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+ " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
+ )
+ deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
+
+
+ def _update_init_kwargs_with_connected_pipeline(
+ init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs
+ ) -> dict:
+ from .pipeline_utils import DiffusionPipeline
+
+ modelcard = ModelCard.load(os.path.join(folder, "README.md"))
+ connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
+
+ # We don't scheduler argument to match the existing logic:
+ # https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14
+ pipeline_loading_kwargs_cp = pipeline_loading_kwargs.copy()
+ if pipeline_loading_kwargs_cp is not None and len(pipeline_loading_kwargs_cp) >= 1:
+ for k in pipeline_loading_kwargs:
+ if "scheduler" in k:
+ _ = pipeline_loading_kwargs_cp.pop(k)
+
+ def get_connected_passed_kwargs(prefix):
+ connected_passed_class_obj = {
+ k.replace(f"{prefix}_", ""): w for k, w in passed_class_objs.items() if k.split("_")[0] == prefix
+ }
+ connected_passed_pipe_kwargs = {
+ k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
+ }
+
+ connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
+ return connected_passed_kwargs
+
+ connected_pipes = {
+ prefix: DiffusionPipeline.from_pretrained(
+ repo_id, **pipeline_loading_kwargs_cp, **get_connected_passed_kwargs(prefix)
+ )
+ for prefix, repo_id in connected_pipes.items()
+ if repo_id is not None
+ }
+
+ for prefix, connected_pipe in connected_pipes.items():
+ # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
+ init_kwargs.update(
+ {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
+ )
+
+ return init_kwargs
+
+
+ def _get_custom_components_and_folders(
+ pretrained_model_name: str,
+ config_dict: Dict[str, Any],
+ filenames: Optional[List[str]] = None,
+ variant_filenames: Optional[List[str]] = None,
+ variant: Optional[str] = None,
+ ):
+ config_dict = config_dict.copy()
+
+ # retrieve all folder_names that contain relevant files
+ folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
+
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
+ pipelines = getattr(diffusers_module, "pipelines")
+
+ # optionally create a custom component <> custom file mapping
+ custom_components = {}
+ for component in folder_names:
+ module_candidate = config_dict[component][0]
+
+ if module_candidate is None or not isinstance(module_candidate, str):
+ continue
+
+ # We compute candidate file path on the Hub. Do not use `os.path.join`.
+ candidate_file = f"{component}/{module_candidate}.py"
+
+ if candidate_file in filenames:
+ custom_components[component] = module_candidate
+ elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
+ raise ValueError(
+ f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
+ )
+
+ if len(variant_filenames) == 0 and variant is not None:
+ error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
+ raise ValueError(error_message)
+
+ return custom_components, folder_names
+
+
+ def _get_ignore_patterns(
+ passed_components,
+ model_folder_names: List[str],
+ model_filenames: List[str],
+ variant_filenames: List[str],
+ use_safetensors: bool,
+ from_flax: bool,
+ allow_pickle: bool,
+ use_onnx: bool,
+ is_onnx: bool,
+ variant: Optional[str] = None,
+ ) -> List[str]:
+ if (
+ use_safetensors
+ and not allow_pickle
+ and not is_safetensors_compatible(
+ model_filenames, passed_components=passed_components, folder_names=model_folder_names
+ )
+ ):
+ raise EnvironmentError(
+ f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+ )
+
+ if from_flax:
+ ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
+
+ elif use_safetensors and is_safetensors_compatible(
+ model_filenames, passed_components=passed_components, folder_names=model_folder_names
+ ):
+ ignore_patterns = ["*.bin", "*.msgpack"]
+
+ use_onnx = use_onnx if use_onnx is not None else is_onnx
+ if not use_onnx:
+ ignore_patterns += ["*.onnx", "*.pb"]
+
+ safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
+ safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
+ if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
+ logger.warning(
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+ f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+ f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
+ f"expected, please check your folder structure."
+ )
+
+ else:
+ ignore_patterns = ["*.safetensors", "*.msgpack"]
+
+ use_onnx = use_onnx if use_onnx is not None else is_onnx
+ if not use_onnx:
+ ignore_patterns += ["*.onnx", "*.pb"]
+
+ bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
+ bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
+ if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
+ logger.warning(
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+ f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+ f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
+ f"your folder structure."
+ )
+
+ return ignore_patterns
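
The returned glob patterns are what ultimately keep unneeded serialization formats from being downloaded. A rough sketch of the kind of Hub call such patterns feed into (the `snapshot_download` wiring and the repo id are assumptions for illustration, not code from this diff):

from huggingface_hub import snapshot_download

# e.g. the safetensors branch above, with ONNX artifacts also excluded
ignore_patterns = ["*.bin", "*.msgpack", "*.onnx", "*.pb"]

local_dir = snapshot_download(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # illustrative repo id
    ignore_patterns=ignore_patterns,                # skip the formats we will not load
)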