diffusers-0.29.2-py3-none-any.whl → diffusers-0.30.1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
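
The headline changes are the new model families: Flux, CogVideoX, AuraFlow, Kolors, Latte, Lumina, Stable Audio, and a large family of PAG (perturbed-attention guidance) pipelines all arrive as new modules, while the monolithic `diffusers/loaders/lora.py` is replaced by `lora_base.py` and `lora_pipeline.py`. As a rough sketch of one of the new entry points, the snippet below exercises the newly added Flux pipeline; the checkpoint id, step count, and guidance value are assumptions about the hosted weights rather than anything this diff pins down:

```python
import torch
from diffusers import FluxPipeline  # new in 0.30 (see diffusers/pipelines/flux/ above)

# Assumed checkpoint id; any Flux-compatible repository should work the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional: trades speed for lower peak VRAM

image = pipe(
    "a tiny robot watering a bonsai tree",
    num_inference_steps=4,  # the distilled "schnell" variant targets few-step sampling
    guidance_scale=0.0,     # assumed: this variant is trained without classifier-free guidance
).images[0]
image.save("flux_sample.png")
```

The per-file diffs below cover two of the touched files, `diffusers/models/controlnet_xs.py` and `diffusers/models/downsampling.py`.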
diffusers/models/controlnet_xs.py CHANGED
```diff
@@ -29,6 +29,7 @@ from .attention_processor import (
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
+    FusedAttnProcessor2_0,
 )
 from .controlnet import ControlNetConditioningEmbedding
 from .embeddings import TimestepEmbedding, Timesteps
@@ -114,6 +115,7 @@ def get_down_block_adapter(
     cross_attention_dim: Optional[int] = 1024,
     add_downsample: bool = True,
     upcast_attention: Optional[bool] = False,
+    use_linear_projection: Optional[bool] = True,
 ):
     num_layers = 2  # only support sd + sdxl
 
@@ -152,7 +154,7 @@ def get_down_block_adapter(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                 )
@@ -200,6 +202,7 @@ def get_mid_block_adapter(
     num_attention_heads: Optional[int] = 1,
     cross_attention_dim: Optional[int] = 1024,
     upcast_attention: bool = False,
+    use_linear_projection: bool = True,
 ):
     # Before the midblock application, information is concatted from base to control.
     # Concat doesn't require change in number of channels
@@ -214,7 +217,7 @@ def get_mid_block_adapter(
         resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
         cross_attention_dim=cross_attention_dim,
         num_attention_heads=num_attention_heads,
-        use_linear_projection=True,
+        use_linear_projection=use_linear_projection,
         upcast_attention=upcast_attention,
     )
 
@@ -282,7 +285,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
         upcast_attention (`bool`, defaults to `True`):
             Whether the attention computation should always be upcasted.
         max_norm_num_groups (`int`, defaults to 32):
-            Maximum number of groups in group normal. The actual number will
+            Maximum number of groups in group normal. The actual number will be the largest divisor of the respective
             channels, that is <= max_norm_num_groups.
     """
 
@@ -308,6 +311,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         upcast_attention: bool = True,
         max_norm_num_groups: int = 32,
+        use_linear_projection: bool = True,
     ):
         super().__init__()
 
@@ -381,6 +385,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -393,6 +398,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
             num_attention_heads=num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # up
@@ -489,6 +495,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
             transformer_layers_per_block=unet.config.transformer_layers_per_block,
             upcast_attention=unet.config.upcast_attention,
             max_norm_num_groups=unet.config.norm_num_groups,
+            use_linear_projection=unet.config.use_linear_projection,
         )
 
         # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
@@ -538,6 +545,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         upcast_attention: bool = True,
+        use_linear_projection: bool = True,
         time_cond_proj_dim: Optional[int] = None,
         projection_class_embeddings_input_dim: Optional[int] = None,
         # additional controlnet configs
@@ -595,7 +603,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             time_embed_dim,
             cond_proj_dim=time_cond_proj_dim,
         )
-        self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+        if ctrl_learn_time_embedding:
+            self.ctrl_time_embedding = TimestepEmbedding(
+                in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+            )
+        else:
+            self.ctrl_time_embedding = None
 
         if addition_embed_type is None:
             self.base_add_time_proj = None
@@ -632,6 +645,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -647,6 +661,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # # Create up blocks
@@ -690,6 +705,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                     add_upsample=not is_final_block,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
+                    use_linear_projection=use_linear_projection,
                 )
             )
 
@@ -754,6 +770,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             "addition_embed_type",
             "addition_time_embed_dim",
             "upcast_attention",
+            "use_linear_projection",
             "time_cond_proj_dim",
             "projection_class_embeddings_input_dim",
         ]
@@ -864,7 +881,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -985,6 +1002,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
@@ -1219,6 +1238,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
         cross_attention_dim: Optional[int] = 1024,
         add_downsample: bool = True,
         upcast_attention: Optional[bool] = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         base_resnets = []
@@ -1270,7 +1290,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
                         in_channels=base_out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        use_linear_projection=True,
+                        use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         norm_num_groups=norm_num_groups,
                     )
@@ -1282,7 +1302,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
                         in_channels=ctrl_out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        use_linear_projection=True,
+                        use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                     )
@@ -1342,6 +1362,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
            ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
            cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
            upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+           use_linear_projection = base_downblock.attentions[0].use_linear_projection
        else:
            has_crossattn = False
            transformer_layers_per_block = None
@@ -1349,6 +1370,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
            ctrl_num_attention_heads = None
            cross_attention_dim = None
            upcast_attention = None
+           use_linear_projection = None
        add_downsample = base_downblock.downsamplers is not None
 
        # create model
@@ -1367,6 +1389,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
            cross_attention_dim=cross_attention_dim,
            add_downsample=add_downsample,
            upcast_attention=upcast_attention,
+           use_linear_projection=use_linear_projection,
        )
 
        # # load weights
@@ -1527,6 +1550,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
         ctrl_num_attention_heads: Optional[int] = 1,
         cross_attention_dim: Optional[int] = 1024,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
 
@@ -1541,7 +1565,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             resnet_groups=norm_num_groups,
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=base_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
 
@@ -1556,7 +1580,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             ),
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=ctrl_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
 
@@ -1590,6 +1614,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
         ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
         cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
         upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+        use_linear_projection = base_midblock.attentions[0].use_linear_projection
 
         # create model
         model = cls(
@@ -1603,6 +1628,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             ctrl_num_attention_heads=ctrl_num_attention_heads,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )
 
         # load weights
@@ -1677,6 +1703,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
         cross_attention_dim: int = 1024,
         add_upsample: bool = True,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         resnets = []
@@ -1714,7 +1741,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
                     in_channels=out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1753,12 +1780,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
            num_attention_heads = get_first_cross_attention(base_upblock).heads
            cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
            upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+           use_linear_projection = base_upblock.attentions[0].use_linear_projection
        else:
            has_crossattn = False
            transformer_layers_per_block = None
            num_attention_heads = None
            cross_attention_dim = None
            upcast_attention = None
+           use_linear_projection = None
        add_upsample = base_upblock.upsamplers is not None
 
        # create model
@@ -1776,6 +1805,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
            cross_attention_dim=cross_attention_dim,
            add_upsample=add_upsample,
            upcast_attention=upcast_attention,
+           use_linear_projection=use_linear_projection,
        )
 
        # load weights
```
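
Taken together, these hunks stop hardcoding `use_linear_projection=True` inside ControlNet-XS: the flag is threaded through every adapter and block constructor, read back from existing modules in the `from_modules` helpers, and mirrored from the base UNet's config in `from_unet`. The same file also makes `ctrl_time_embedding` conditional on `ctrl_learn_time_embedding` and has `fuse_qkv_projections` install the new `FusedAttnProcessor2_0`. A minimal sketch of what the config mirroring means in practice (the checkpoint id and `size_ratio` value are illustrative assumptions):

```python
import torch
from diffusers import ControlNetXSAdapter, UNet2DConditionModel

# SD 1.5's UNet is configured with use_linear_projection=False; after this change an
# adapter derived from it should mirror that setting instead of silently using True.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)

print(unet.config.use_linear_projection)     # False for SD 1.5
print(adapter.config.use_linear_projection)  # expected to match the base UNet
```

Matching the projection type keeps the adapter's transformer blocks structurally consistent with the base UNet, which is what the `from_modules` hunks above rely on when copying weights over.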
diffusers/models/downsampling.py CHANGED
```diff
@@ -285,6 +285,74 @@ class KDownsample2D(nn.Module):
         return F.conv2d(inputs, weight, stride=2)
 
 
+class CogVideoXDownsample3D(nn.Module):
+    # Todo: Wait for paper relase.
+    r"""
+    A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
+
+    Args:
+        in_channels (`int`):
+            Number of channels in the input image.
+        out_channels (`int`):
+            Number of channels produced by the convolution.
+        kernel_size (`int`, defaults to `3`):
+            Size of the convolving kernel.
+        stride (`int`, defaults to `2`):
+            Stride of the convolution.
+        padding (`int`, defaults to `0`):
+            Padding added to all four sides of the input.
+        compress_time (`bool`, defaults to `False`):
+            Whether or not to compress the time dimension.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 2,
+        padding: int = 0,
+        compress_time: bool = False,
+    ):
+        super().__init__()
+
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.compress_time = compress_time
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.compress_time:
+            batch_size, channels, frames, height, width = x.shape
+
+            # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
+            x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
+
+            if x.shape[-1] % 2 == 1:
+                x_first, x_rest = x[..., 0], x[..., 1:]
+                if x_rest.shape[-1] > 0:
+                    # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
+                    x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+                x = torch.cat([x_first[..., None], x_rest], dim=-1)
+                # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
+                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+            else:
+                # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
+                x = F.avg_pool1d(x, kernel_size=2, stride=2)
+                # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
+                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+        # Pad the tensor
+        pad = (0, 1, 0, 1)
+        x = F.pad(x, pad, mode="constant", value=0)
+        batch_size, channels, frames, height, width = x.shape
+        # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
+        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
+        x = self.conv(x)
+        # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
+        x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+        return x
+
+
 def downsample_2d(
     hidden_states: torch.Tensor,
     kernel: Optional[torch.Tensor] = None,
```
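
The new `CogVideoXDownsample3D` halves height and width with a strided 2D convolution applied frame by frame and, when `compress_time=True`, first halves the frame count with a 1D average pool (keeping the first frame intact when the count is odd). A quick shape check, assuming the class is imported from the module this diff adds it to:

```python
import torch
from diffusers.models.downsampling import CogVideoXDownsample3D  # added above

down = CogVideoXDownsample3D(in_channels=4, out_channels=8, compress_time=True)

x = torch.randn(1, 4, 8, 32, 32)  # (batch, channels, frames, height, width)
out = down(x)
print(out.shape)  # torch.Size([1, 8, 4, 16, 16]): frames, height, and width all halved

# With an odd frame count the first frame is passed through unpooled:
# 9 input frames -> 1 + avg_pool1d over the remaining 8 -> 5 output frames.
```

Height and width land on 16 because the forward pass right-pads each spatial dimension by one (32 -> 33) before the kernel-3, stride-2 convolution.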