diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/adapter.py
@@ -30,10 +30,10 @@ class MultiAdapter(ModelMixin):
     MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
     user-assigned weighting.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
+    or saving.
 
-    Parameters:
+    Args:
         adapters (`List[T2IAdapter]`, *optional*, defaults to None):
             A list of `T2IAdapter` model instances.
     """
@@ -77,11 +77,13 @@ class MultiAdapter(ModelMixin):
         r"""
         Args:
             xs (`torch.Tensor`):
-                (batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
-                `channel` should equal to `num_adapter` * "number of channel of image".
+                A tensor of shape (batch, channel, height, width) representing input images for multiple adapter
+                models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
+                `num_adapter` * number of channel per image.
+
             adapter_weights (`List[float]`, *optional*, defaults to None):
-                List of floats representing the weight which will be multiply to each adapter's output before adding
-                them together.
+                A list of floats representing the weights which will be multiplied by each adapter's output before
+                summing them together. If `None`, equal weights will be used for all adapters.
         """
         if adapter_weights is None:
             adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
@@ -109,24 +111,24 @@ class MultiAdapter(ModelMixin):
         variant: Optional[str] = None,
     ):
         """
-        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
         `[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
 
-        Arguments:
+        Args:
             save_directory (`str` or `os.PathLike`):
-                Directory to which to save. Will be created if it doesn't exist.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful when in distributed training like
-                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
-                the main process to avoid race conditions.
+                The directory where the model will be saved. If the directory does not exist, it will be created.
+            is_main_process (`bool`, optional, defaults=True):
+                Indicates whether current process is the main process or not. Useful for distributed training (e.g.,
+                TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only
+                for the main process to avoid race conditions.
             save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+                Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
+                `torch.save` with another method. Can also be configured using `DIFFUSERS_SAVE_MODE` environment
+                variable.
+            safe_serialization (`bool`, optional, defaults=True):
+                If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
             variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
         """
         idx = 0
         model_path_to_save = save_directory
@@ -145,19 +147,17 @@ class MultiAdapter(ModelMixin):
     @classmethod
     def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
-        Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
+        Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
 
         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
-        the model, you should first set it back in training mode with `model.train()`.
+        the model, set it back to training mode using `model.train()`.
 
-        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-        task.
+        Warnings:
+            *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
+            with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
+            from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.
 
-        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
-        weights are discarded.
-
-        Parameters:
+        Args:
             pretrained_model_path (`os.PathLike`):
                 A path to a *directory* containing model weights saved using
                 [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
@@ -175,20 +175,20 @@ class MultiAdapter(ModelMixin):
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
-                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
-                GPU and the available CPU RAM if unset.
+                A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                 setting this argument to `True` will raise an error.
             variant (`str`, *optional*):
-                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
-                ignored when using `from_flax`.
+                If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
+                be ignored when using `from_flax`.
             use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
-                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
-                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+                If `None`, the `safetensors` weights will be downloaded if available **and** if `safetensors` library is
+                installed. If `True`, the model will be forcibly loaded from `safetensors` weights. If `False`,
+                `safetensors` is not used.
         """
         idx = 0
         adapters = []
@@ -223,22 +223,22 @@ class T2IAdapter(ModelMixin, ConfigMixin):
     and
     [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
+    downloading or saving.
 
-    Parameters:
-        in_channels (`int`, *optional*, defaults to 3):
-            Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
-            image as *control image*.
+    Args:
+        in_channels (`int`, *optional*, defaults to `3`):
+            The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
+            image.
         channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
-            The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
-            also determine the number of downsample blocks in the Adapter.
-        num_res_blocks (`int`, *optional*, defaults to 2):
+            The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
+            determines the number of downsample blocks in the adapter.
+        num_res_blocks (`int`, *optional*, defaults to `2`):
             Number of ResNet blocks in each downsample block.
-        downscale_factor (`int`, *optional*, defaults to 8):
+        downscale_factor (`int`, *optional*, defaults to `8`):
             A factor that determines the total downscale factor of the Adapter.
         adapter_type (`str`, *optional*, defaults to `full_adapter`):
-            The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
+            Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
     """
 
     @register_to_config
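
Note: a minimal sketch (not part of this diff; the input size is illustrative) of constructing the adapter with the arguments documented above and running a dummy control image through it:

    import torch
    from diffusers import T2IAdapter

    adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280, 1280], num_res_blocks=2, downscale_factor=8)
    control = torch.randn(1, 3, 64, 64)   # dummy 3-channel control image
    features = adapter(control)           # one feature map per entry in `channels`
    print([f.shape for f in features])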
@@ -393,7 +393,7 @@ class AdapterBlock(nn.Module):
     An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
     `FullAdapterXL` models.
 
-    Parameters:
+    Args:
         in_channels (`int`):
             Number of channels of AdapterBlock's input.
         out_channels (`int`):
@@ -401,7 +401,7 @@ class AdapterBlock(nn.Module):
         num_res_blocks (`int`):
             Number of ResNet blocks in the AdapterBlock.
         down (`bool`, *optional*, defaults to `False`):
-            Whether to perform downsampling on AdapterBlock's input.
+            If `True`, perform downsampling on AdapterBlock's input.
     """
 
     def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -440,7 +440,7 @@ class AdapterResnetBlock(nn.Module):
     r"""
     An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
 
-    Parameters:
+    Args:
         channels (`int`):
             Number of channels of AdapterResnetBlock's input and output.
     """
@@ -518,7 +518,7 @@ class LightAdapterBlock(nn.Module):
     A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
     `LightAdapter` model.
 
-    Parameters:
+    Args:
         in_channels (`int`):
             Number of channels of LightAdapterBlock's input.
         out_channels (`int`):
@@ -526,7 +526,7 @@ class LightAdapterBlock(nn.Module):
         num_res_blocks (`int`):
             Number of LightAdapterResnetBlocks in the LightAdapterBlock.
         down (`bool`, *optional*, defaults to `False`):
-            Whether to perform downsampling on LightAdapterBlock's input.
+            If `True`, perform downsampling on LightAdapterBlock's input.
     """
 
     def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -561,7 +561,7 @@ class LightAdapterResnetBlock(nn.Module):
     A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
     architecture than `AdapterResnetBlock`.
 
-    Parameters:
+    Args:
         channels (`int`):
             Number of channels of LightAdapterResnetBlock's input and output.
     """
diffusers/models/attention.py
@@ -19,10 +19,10 @@ from torch import nn
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
 logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
     processing of `context` conditions.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        use_dual_attention: bool = False,
+    ):
         super().__init__()
 
+        self.use_dual_attention = use_dual_attention
         self.context_pre_only = context_pre_only
         context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
 
-        self.norm1 = AdaLayerNormZero(dim)
+        if use_dual_attention:
+            self.norm1 = SD35AdaLayerNormZeroX(dim)
+        else:
+            self.norm1 = AdaLayerNormZero(dim)
 
         if context_norm_type == "ada_norm_continous":
             self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ class JointTransformerBlock(nn.Module):
             raise ValueError(
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
+
         if hasattr(F, "scaled_dot_product_attention"):
             processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
             )
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -134,8 +148,25 @@ class JointTransformerBlock(nn.Module):
             context_pre_only=context_pre_only,
             bias=True,
             processor=processor,
+            qk_norm=qk_norm,
+            eps=1e-6,
         )
 
+        if use_dual_attention:
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=None,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                out_dim=dim,
+                bias=True,
+                processor=processor,
+                qk_norm=qk_norm,
+                eps=1e-6,
+            )
+        else:
+            self.attn2 = None
+
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
@@ -157,9 +188,19 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         if self.context_pre_only:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -170,13 +211,20 @@ class JointTransformerBlock(nn.Module):
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         if self._chunk_size is not None:
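
Note: a minimal sketch (not part of this diff; sizes are illustrative) of exercising the two new JointTransformerBlock arguments. `use_dual_attention=True` selects the SD35AdaLayerNormZeroX norm and the extra `attn2` branch added above, and `qk_norm` is forwarded to the Attention modules:

    import torch
    from diffusers.models.attention import JointTransformerBlock

    dim, heads, head_dim = 128, 4, 32  # chosen so heads * head_dim == dim, as in the SD3-style transformers
    block = JointTransformerBlock(
        dim=dim,
        num_attention_heads=heads,
        attention_head_dim=head_dim,
        context_pre_only=False,
        qk_norm="rms_norm",
        use_dual_attention=True,
    )
    hidden_states = torch.randn(1, 16, dim)          # image tokens
    encoder_hidden_states = torch.randn(1, 8, dim)   # text/context tokens
    temb = torch.randn(1, dim)                       # pooled conditioning embedding
    out = block(hidden_states, encoder_hidden_states, temb)  # tuple of updated context and image-token states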
@@ -972,15 +1020,32 @@ class FreeNoiseTransformerBlock(nn.Module):
         return frame_indices
 
     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
-        if weighting_scheme == "pyramid":
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
             if num_frames % 2 == 0:
                 # num_frames = 4 => [1, 2, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
                 weights = weights + weights[::-1]
             else:
                 # num_frames = 5 => [1, 2, 3, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
-                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * mid
+                weights = weights + list(range(mid, 0, -1))
         else:
             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
 
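
Note: for reference, a tiny standalone sketch (plain Python, not taken from the package) of what the frame-weighting schemes above evaluate to over a short context window:

    def frame_weights(num_frames, scheme):
        if scheme == "flat":
            return [1.0] * num_frames
        if scheme == "pyramid":
            if num_frames % 2 == 0:
                half = list(range(1, num_frames // 2 + 1))
                return half + half[::-1]
            mid = (num_frames + 1) // 2
            half = list(range(1, mid))
            return half + [mid] + half[::-1]
        if scheme == "delayed_reverse_sawtooth":
            # even-length window only, to keep the sketch short
            mid = num_frames // 2
            return [0.01] * (mid - 1) + [mid] + list(range(mid, 0, -1))
        raise ValueError(scheme)

    print(frame_weights(4, "flat"))                      # [1.0, 1.0, 1.0, 1.0]
    print(frame_weights(4, "pyramid"))                   # [1, 2, 2, 1]
    print(frame_weights(4, "delayed_reverse_sawtooth"))  # [0.01, 2, 2, 1]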
@@ -1087,8 +1152,26 @@ class FreeNoiseTransformerBlock(nn.Module):
                     accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
                     num_times_accumulated[:, frame_start:frame_end] += weights
 
-        hidden_states = torch.where(
-            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        #     hidden_states = torch.where(
+        #         num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        #     )
+        #
+        # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
+        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+        # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
+        # looked into this deeply because other memory optimizations led to more pronounced reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
         ).to(dtype)
 
         # 3. Feed-forward
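
Note: a small self-contained sketch (shapes are illustrative) showing that the per-chunk split/where/cat above is numerically equivalent to the single torch.where it replaces, while only materializing one context-length chunk of intermediates at a time:

    import torch

    context_length, num_chunks, dim = 16, 4, 8
    accumulated = torch.randn(1, context_length * num_chunks, dim)
    counts = torch.randint(0, 3, (1, context_length * num_chunks, 1)).float()

    single_pass = torch.where(counts > 0, accumulated / counts, accumulated)
    chunked = torch.cat(
        [
            torch.where(c > 0, a / c, a)
            for a, c in zip(accumulated.split(context_length, dim=1), counts.split(context_length, dim=1))
        ],
        dim=1,
    )
    assert torch.allclose(single_pass, chunked)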
@@ -1146,6 +1229,8 @@ class FeedForward(nn.Module):
             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
         elif activation_fn == "swiglu":
             act_fn = SwiGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "linear-silu":
+            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
 
         self.net = nn.ModuleList([])
         # project in
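
Note: a minimal sketch (dims are illustrative; requires a diffusers build that includes this change) of the new "linear-silu" option, which routes the FeedForward input through a Linear projection followed by SiLU instead of a GELU variant:

    import torch
    from diffusers.models.attention import FeedForward

    ff = FeedForward(dim=64, dim_out=64, activation_fn="linear-silu")
    out = ff(torch.randn(2, 10, 64))
    print(out.shape)  # torch.Size([2, 10, 64])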
diffusers/models/attention_flax.py
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-
             hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim: