diffusers 0.29.2-py3-none-any.whl → 0.30.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/pixart_transformer_2d.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import torch
 from torch import nn
@@ -19,6 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
+from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -186,6 +187,106 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
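The hunk above gives PixArtTransformer2DModel the same attn_processors / fused-QKV surface that UNet2DConditionModel already exposes (the methods are literal "Copied from" blocks). A minimal sketch of exercising the new API; the checkpoint id and fp16 dtype are illustrative assumptions, not part of this diff:

# Sketch only: the checkpoint id below is an assumption chosen for illustration;
# any PixArt transformer checkpoint compatible with diffusers should work the same way.
import torch
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel

transformer = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",  # assumed checkpoint, for illustration only
    subfolder="transformer",
    torch_dtype=torch.float16,
)

# New property: one attention processor per Attention module, keyed by weight name.
print(len(transformer.attn_processors), "attention processors")

# Experimental fused-QKV path added in this release; unfuse restores the originals.
transformer.fuse_qkv_projections()
transformer.unfuse_qkv_projections()

fuse_qkv_projections() switches every Attention module to FusedAttnProcessor2_0 after fusing the projection weights, and unfuse_qkv_projections() puts back the processors captured in original_attn_processors.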
diffusers/models/transformers/prior_transformer.py
@@ -179,7 +179,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef

         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()

             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
diffusers/models/transformers/stable_audio_transformer.py (new file)
@@ -0,0 +1,458 @@
+# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.attention import FeedForward
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    StableAudioAttnProcessor2_0,
+)
+from ...models.modeling_utils import ModelMixin
+from ...models.transformers.transformer_2d import Transformer2DModelOutput
+from ...utils import is_torch_version, logging
+from ...utils.torch_utils import maybe_allow_in_graph
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class StableAudioGaussianFourierProjection(nn.Module):
+    """Gaussian Fourier embeddings for noise levels."""
+
+    # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__
+    def __init__(
+        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+    ):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+        self.log = log
+        self.flip_sin_to_cos = flip_sin_to_cos
+
+        if set_W_to_weight:
+            # to delete later
+            del self.weight
+            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+            self.weight = self.W
+            del self.W
+
+    def forward(self, x):
+        if self.log:
+            x = torch.log(x)
+
+        x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :]
+
+        if self.flip_sin_to_cos:
+            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+        else:
+            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+        return out
+
+
+@maybe_allow_in_graph
+class StableAudioDiTBlock(nn.Module):
+    r"""
+    Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip
+    connection and QKNorm
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for the query states.
+        num_key_value_attention_heads (`int`): The number of heads to use for the key and value states.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        num_key_value_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        upcast_attention: bool = False,
+        norm_eps: float = 1e-5,
+        ff_inner_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=False,
+            upcast_attention=upcast_attention,
+            out_bias=False,
+            processor=StableAudioAttnProcessor2_0(),
+        )
+
+        # 2. Cross-Attn
+        self.norm2 = nn.LayerNorm(dim, norm_eps, True)
+
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            kv_heads=num_key_value_attention_heads,
+            dropout=dropout,
+            bias=False,
+            upcast_attention=upcast_attention,
+            out_bias=False,
+            processor=StableAudioAttnProcessor2_0(),
+        )  # is self-attn if encoder_hidden_states is none
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(dim, norm_eps, True)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn="swiglu",
+            final_dropout=False,
+            inner_dim=ff_inner_dim,
+            bias=True,
+        )
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        rotary_embedding: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 0. Self-Attention
+        norm_hidden_states = self.norm1(hidden_states)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            attention_mask=attention_mask,
+            rotary_emb=rotary_embedding,
+        )
+
+        hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention
+        norm_hidden_states = self.norm2(hidden_states)
+
+        attn_output = self.attn2(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=encoder_attention_mask,
+        )
+        hidden_states = attn_output + hidden_states
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+        ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = ff_output + hidden_states
+
+        return hidden_states
+
+
+class StableAudioDiTModel(ModelMixin, ConfigMixin):
+    """
+    The Diffusion Transformer model introduced in Stable Audio.
+
+    Reference: https://github.com/Stability-AI/stable-audio-tools
+
+    Parameters:
+        sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample.
+        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
+        num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
+        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states.
+        num_key_value_attention_heads (`int`, *optional*, defaults to 12):
+            The number of heads to use for the key and value states.
+        out_channels (`int`, defaults to 64): Number of output channels.
+        cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection.
+        time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection.
+        global_states_input_dim ( `int`, *optional*, defaults to 1536):
+            Input dimension of the global hidden states projection.
+        cross_attention_input_dim ( `int`, *optional*, defaults to 768):
+            Input dimension of the cross-attention projection
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 1024,
+        in_channels: int = 64,
+        num_layers: int = 24,
+        attention_head_dim: int = 64,
+        num_attention_heads: int = 24,
+        num_key_value_attention_heads: int = 12,
+        out_channels: int = 64,
+        cross_attention_dim: int = 768,
+        time_proj_dim: int = 256,
+        global_states_input_dim: int = 1536,
+        cross_attention_input_dim: int = 768,
+    ):
+        super().__init__()
+        self.sample_size = sample_size
+        self.out_channels = out_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.time_proj = StableAudioGaussianFourierProjection(
+            embedding_size=time_proj_dim // 2,
+            flip_sin_to_cos=True,
+            log=False,
+            set_W_to_weight=False,
+        )
+
+        self.timestep_proj = nn.Sequential(
+            nn.Linear(time_proj_dim, self.inner_dim, bias=True),
+            nn.SiLU(),
+            nn.Linear(self.inner_dim, self.inner_dim, bias=True),
+        )
+
+        self.global_proj = nn.Sequential(
+            nn.Linear(global_states_input_dim, self.inner_dim, bias=False),
+            nn.SiLU(),
+            nn.Linear(self.inner_dim, self.inner_dim, bias=False),
+        )
+
+        self.cross_attention_proj = nn.Sequential(
+            nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False),
+            nn.SiLU(),
+            nn.Linear(cross_attention_dim, cross_attention_dim, bias=False),
+        )
+
+        self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
+        self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                StableAudioDiTBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    num_key_value_attention_heads=num_key_value_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+        self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
+        self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(StableAudioAttnProcessor2_0())
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        timestep: torch.LongTensor = None,
+        encoder_hidden_states: torch.FloatTensor = None,
+        global_hidden_states: torch.FloatTensor = None,
+        rotary_embedding: torch.FloatTensor = None,
+        return_dict: bool = True,
+        attention_mask: Optional[torch.LongTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        """
+        The [`StableAudioDiTModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
+                Input `hidden_states`.
+            timestep ( `torch.LongTensor`):
+                Used to indicate denoising step.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
+                Global embeddings that will be prepended to the hidden states.
+            rotary_embedding (`torch.Tensor`):
+                The rotary embeddings to apply on query and key tensors during attention calculation.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
+                Mask to avoid performing attention on padding token indices, formed by concatenating the attention
+                masks
+                for the two text encoders together. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
+                Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating
+                the attention masks
+                for the two text encoders together. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
+        global_hidden_states = self.global_proj(global_hidden_states)
+        time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype)))
+
+        global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)
+
+        hidden_states = self.preprocess_conv(hidden_states) + hidden_states
+        # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim)
+        hidden_states = hidden_states.transpose(1, 2)
+
+        hidden_states = self.proj_in(hidden_states)
+
+        # prepend global states to hidden states
+        hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
+        if attention_mask is not None:
+            prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool)
+            attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
+
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    attention_mask,
+                    cross_attention_hidden_states,
+                    encoder_attention_mask,
+                    rotary_embedding,
+                    **ckpt_kwargs,
+                )
+
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=cross_attention_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    rotary_embedding=rotary_embedding,
+                )
+
+        hidden_states = self.proj_out(hidden_states)
+
+        # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length)
+        # remove prepend length that has been added by global hidden states
+        hidden_states = hidden_states.transpose(1, 2)[:, :, 1:]
+        hidden_states = self.postprocess_conv(hidden_states) + hidden_states
+
+        if not return_dict:
+            return (hidden_states,)
+
+        return Transformer2DModelOutput(sample=hidden_states)
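The new file above defines StableAudioDiTModel, the diffusion transformer behind the Stable Audio pipeline added in this release. A minimal construction sketch, assuming a deliberately scaled-down config so the model builds quickly (real checkpoints use the defaults documented in the class docstring):

# Sketch only: the tiny config values below are assumptions chosen for a quick smoke test.
from diffusers.models.transformers.stable_audio_transformer import StableAudioDiTModel

model = StableAudioDiTModel(
    sample_size=64,
    in_channels=64,
    num_layers=2,
    attention_head_dim=32,
    num_attention_heads=4,
    num_key_value_attention_heads=2,  # fewer KV heads than query heads (GQA-style attention)
    out_channels=64,
    cross_attention_dim=128,
    time_proj_dim=64,
    global_states_input_dim=256,
    cross_attention_input_dim=128,
)

# The shared attn_processors API from the hunks above is available here as well.
print(len(model.attn_processors), "attention processors")
model.set_default_attn_processor()  # resets everything to StableAudioAttnProcessor2_0
print(sum(p.numel() for p in model.parameters()), "parameters")

In the forward pass the projected global conditioning (plus the Fourier timestep embedding) is prepended to the audio sequence before the transformer blocks, and that extra position is stripped again after proj_out, which is why the output path ends with hidden_states.transpose(1, 2)[:, :, 1:].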