diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
diffusers/models/controlnet_xs.py
@@ -29,6 +29,7 @@ from .attention_processor import (
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from .controlnet import ControlNetConditioningEmbedding
 from .embeddings import TimestepEmbedding, Timesteps
@@ -114,6 +115,7 @@ def get_down_block_adapter(
     cross_attention_dim: Optional[int] = 1024,
     add_downsample: bool = True,
     upcast_attention: Optional[bool] = False,
+    use_linear_projection: Optional[bool] = True,
 ):
     num_layers = 2  # only support sd + sdxl
 
@@ -152,7 +154,7 @@ def get_down_block_adapter(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                 )
@@ -200,6 +202,7 @@ def get_mid_block_adapter(
     num_attention_heads: Optional[int] = 1,
     cross_attention_dim: Optional[int] = 1024,
     upcast_attention: bool = False,
+    use_linear_projection: bool = True,
 ):
     # Before the midblock application, information is concatted from base to control.
     # Concat doesn't require change in number of channels
@@ -214,7 +217,7 @@ def get_mid_block_adapter(
         resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
         cross_attention_dim=cross_attention_dim,
         num_attention_heads=num_attention_heads,
-        use_linear_projection=True,
+        use_linear_projection=use_linear_projection,
         upcast_attention=upcast_attention,
     )
 
@@ -282,7 +285,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
         upcast_attention (`bool`, defaults to `True`):
             Whether the attention computation should always be upcasted.
         max_norm_num_groups (`int`, defaults to 32):
-            Maximum number of groups in group normal. The actual number will the the largest divisor of the respective
+            Maximum number of groups in group normal. The actual number will be the largest divisor of the respective
             channels, that is <= max_norm_num_groups.
     """
 
@@ -308,6 +311,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         upcast_attention: bool = True,
         max_norm_num_groups: int = 32,
+        use_linear_projection: bool = True,
     ):
         super().__init__()
 
@@ -381,6 +385,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -393,6 +398,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
             num_attention_heads=num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # up
@@ -489,6 +495,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
             transformer_layers_per_block=unet.config.transformer_layers_per_block,
             upcast_attention=unet.config.upcast_attention,
             max_norm_num_groups=unet.config.norm_num_groups,
+            use_linear_projection=unet.config.use_linear_projection,
         )
 
         # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
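The `from_unet` change above is the user-visible end of the `use_linear_projection` plumbing in this file: an adapter created from an SD 1.5-style UNet (which uses conv projections) now inherits that setting instead of the previously hard-coded `True`. A minimal sketch of the new behavior, assuming diffusers 0.30.x; the checkpoint ID and `size_ratio` value are illustrative:

# Sketch: the adapter now mirrors the base UNet's projection setting.
from diffusers import ControlNetXSAdapter, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"  # illustrative ID
)
adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)

# Before 0.30 this could mismatch: the adapter always used linear projections,
# even when the base UNet (as in SD 1.5) uses conv projections.
assert adapter.config.use_linear_projection == unet.config.use_linear_projection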
@@ -538,6 +545,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         upcast_attention: bool = True,
+        use_linear_projection: bool = True,
         time_cond_proj_dim: Optional[int] = None,
         projection_class_embeddings_input_dim: Optional[int] = None,
         # additional controlnet configs
@@ -595,7 +603,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             time_embed_dim,
             cond_proj_dim=time_cond_proj_dim,
         )
-        self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+        if ctrl_learn_time_embedding:
+            self.ctrl_time_embedding = TimestepEmbedding(
+                in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+            )
+        else:
+            self.ctrl_time_embedding = None
 
         if addition_embed_type is None:
             self.base_add_time_proj = None
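This rewrite stops allocating a control-side `TimestepEmbedding` that is never used when `ctrl_learn_time_embedding=False` (the control branch then falls back to the base embedding). A hedged sketch of the observable effect, assuming the model's default config is constructible as-is:

# Sketch only: direct construction is unusual (the model is normally assembled
# via UNetControlNetXSModel.from_unet) and allocates a full-size network.
from diffusers import UNetControlNetXSModel

model = UNetControlNetXSModel(ctrl_learn_time_embedding=False)
assert model.ctrl_time_embedding is None  # previously: allocated but unused

model = UNetControlNetXSModel(ctrl_learn_time_embedding=True)
assert model.ctrl_time_embedding is not None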
@@ -632,6 +645,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -647,6 +661,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # # Create up blocks
@@ -690,6 +705,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                     add_upsample=not is_final_block,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -754,6 +770,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             "addition_embed_type",
             "addition_time_embed_dim",
             "upcast_attention",
+            "use_linear_projection",
             "time_cond_proj_dim",
             "projection_class_embeddings_input_dim",
         ]
@@ -864,7 +881,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
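The `return_deprecated_lora` argument disappears here because 0.30 removes the legacy LoRA attention-processor path; LoRA handling now lives in the PEFT-backed loaders (see the new diffusers/loaders/lora_base.py and diffusers/loaders/lora_pipeline.py in the file list above). The `attn_processors` mapping this helper builds is unchanged and can be inspected as before; a small sketch with an illustrative checkpoint ID:

# Sketch: the processor map's keys and values are unaffected by the
# get_processor() signature cleanup.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"  # illustrative ID
)
for name, processor in unet.attn_processors.items():
    print(name, type(processor).__name__)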
@@ -985,6 +1002,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
@@ -1219,6 +1238,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
         cross_attention_dim: Optional[int] = 1024,
         add_downsample: bool = True,
         upcast_attention: Optional[bool] = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         base_resnets = []
@@ -1270,7 +1290,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
                     in_channels=base_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1282,7 +1302,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                 )
@@ -1342,6 +1362,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
             cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+            use_linear_projection = base_downblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
@@ -1349,6 +1370,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             ctrl_num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_downsample = base_downblock.downsamplers is not None
 
         # create model
@@ -1367,6 +1389,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             cross_attention_dim=cross_attention_dim,
             add_downsample=add_downsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # # load weights
@@ -1527,6 +1550,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
         ctrl_num_attention_heads: Optional[int] = 1,
         cross_attention_dim: Optional[int] = 1024,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
 
@@ -1541,7 +1565,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             resnet_groups=norm_num_groups,
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=base_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
 
@@ -1556,7 +1580,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             ),
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=ctrl_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
 
@@ -1590,6 +1614,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
         ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
         cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
         upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+        use_linear_projection = base_midblock.attentions[0].use_linear_projection
 
         # create model
         model = cls(
@@ -1603,6 +1628,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             ctrl_num_attention_heads=ctrl_num_attention_heads,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # load weights
@@ -1677,6 +1703,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
         cross_attention_dim: int = 1024,
         add_upsample: bool = True,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         resnets = []
@@ -1714,7 +1741,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
                     in_channels=out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1753,12 +1780,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             num_attention_heads = get_first_cross_attention(base_upblock).heads
             cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+            use_linear_projection = base_upblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_upsample = base_upblock.upsamplers is not None
 
         # create model
@@ -1776,6 +1805,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             cross_attention_dim=cross_attention_dim,
             add_upsample=add_upsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # load weights
diffusers/models/downsampling.py
@@ -285,6 +285,74 @@ class KDownsample2D(nn.Module):
         return F.conv2d(inputs, weight, stride=2)
 
 
+class CogVideoXDownsample3D(nn.Module):
+    # Todo: Wait for paper relase.
+    r"""
+    A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
+
+    Args:
+        in_channels (`int`):
+            Number of channels in the input image.
+        out_channels (`int`):
+            Number of channels produced by the convolution.
+        kernel_size (`int`, defaults to `3`):
+            Size of the convolving kernel.
+        stride (`int`, defaults to `2`):
+            Stride of the convolution.
+        padding (`int`, defaults to `0`):
+            Padding added to all four sides of the input.
+        compress_time (`bool`, defaults to `False`):
+            Whether or not to compress the time dimension.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 2,
+        padding: int = 0,
+        compress_time: bool = False,
+    ):
+        super().__init__()
+
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.compress_time = compress_time
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.compress_time:
+            batch_size, channels, frames, height, width = x.shape
+
+            # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
+            x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
+
+            if x.shape[-1] % 2 == 1:
+                x_first, x_rest = x[..., 0], x[..., 1:]
+                if x_rest.shape[-1] > 0:
+                    # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
+                    x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+                x = torch.cat([x_first[..., None], x_rest], dim=-1)
+                # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
+                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+            else:
+                # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
+                x = F.avg_pool1d(x, kernel_size=2, stride=2)
+                # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
+                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+        # Pad the tensor
+        pad = (0, 1, 0, 1)
+        x = F.pad(x, pad, mode="constant", value=0)
+        batch_size, channels, frames, height, width = x.shape
+        # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
+        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
+        x = self.conv(x)
+        # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
+        x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+        return x
+
+
 def downsample_2d(
     hidden_states: torch.Tensor,
     kernel: Optional[torch.Tensor] = None,
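To make the `compress_time` branch of the new `CogVideoXDownsample3D` concrete, here is a hedged shape check; the channel, frame, and resolution values are illustrative, and the module is internal API imported by its 0.30 path:

import torch

from diffusers.models.downsampling import CogVideoXDownsample3D

down = CogVideoXDownsample3D(in_channels=16, out_channels=16, compress_time=True)

x = torch.randn(1, 16, 9, 32, 32)  # (batch, channels, frames, height, width)
y = down(x)

# Odd frame counts keep the first frame and avg-pool the rest: 9 -> 1 + 8/2 = 5.
# The asymmetric (0, 1, 0, 1) pad plus the stride-2, kernel-3 conv halves the
# spatial dims: floor((32 + 1 - 3) / 2) + 1 = 16.
print(y.shape)  # torch.Size([1, 16, 5, 16, 16])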