diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
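
The hunks reproduced below come from diffusers/models/unets/unet_2d_blocks.py (+137 -128 in the list above); the other files are not excerpted here. Several entries in the list are entirely new modules in 0.28.0 (callbacks, ControlNet-XS, Marigold, PixArt-Sigma, the AnimateDiff SDXL pipeline, the video processor). A minimal sketch of what an upgrade surfaces at the top level; the class names are inferred from the new file paths, not from this diff, so treat them as assumptions:

```python
# pip install -U "diffusers==0.28.0"
import diffusers

print(diffusers.__version__)  # expected: 0.28.0

# Class names below are assumptions inferred from the new module paths above.
from diffusers import (
    AnimateDiffSDXLPipeline,              # pipelines/animatediff/pipeline_animatediff_sdxl.py
    MarigoldDepthPipeline,                # pipelines/marigold/pipeline_marigold_depth.py
    MarigoldNormalsPipeline,              # pipelines/marigold/pipeline_marigold_normals.py
    PixArtSigmaPipeline,                  # pipelines/pixart_alpha/pipeline_pixart_sigma.py
    StableDiffusionControlNetXSPipeline,  # pipelines/controlnet_xs/pipeline_controlnet_xs.py
)
from diffusers.video_processor import VideoProcessor  # video_processor.py
```

If an import fails on your install, check the 0.28.0 release notes for the exported name.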
```diff
@@ -561,7 +561,7 @@ class AutoencoderTinyBlock(nn.Module):
             ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
 
     Returns:
-        `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
+        `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
         `out_channels`.
     """
 
@@ -582,7 +582,7 @@ class AutoencoderTinyBlock(nn.Module):
         )
         self.fuse = nn.ReLU()
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.fuse(self.conv(x) + self.skip(x))
 
 
```
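
Both AutoencoderTinyBlock hunks, like most of this file's churn, only swap `torch.FloatTensor` annotations for `torch.Tensor`. Annotations are not enforced at runtime, so the change is behavior-neutral; a minimal sketch, assuming the block's constructor is `AutoencoderTinyBlock(in_channels, out_channels, act_fn)` (the constructor itself is not shown in this diff):

```python
import torch
from diffusers.models.unets.unet_2d_blocks import AutoencoderTinyBlock

# Assumed constructor signature: (in_channels, out_channels, act_fn).
block = AutoencoderTinyBlock(4, 4, "relu")
x = torch.randn(1, 4, 32, 32)  # a plain float32 tensor
out = block(x)                 # accepted by both 0.27.2 and 0.28.0
print(out.shape)               # torch.Size([1, 4, 32, 32])
```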
```diff
@@ -612,8 +612,8 @@ class UNetMidBlock2D(nn.Module):
         output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
 
     Returns:
-        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
-        in_channels, height, width)`.
+        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
+        height, width)`.
 
     """
 
@@ -731,7 +731,7 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
@@ -746,6 +746,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
+        out_channels: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +754,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_groups_out: Optional[int] = None,
         resnet_pre_norm: bool = True,
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
@@ -764,6 +766,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        out_channels = out_channels or in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +778,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * num_layers
 
+        resnet_groups_out = resnet_groups_out or resnet_groups
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
                 in_channels=in_channels,
-                out_channels=in_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
+                groups_out=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -794,11 +803,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
+                        norm_num_groups=resnet_groups_out,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         attention_type=attention_type,
@@ -808,8 +817,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     DualTransformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
@@ -817,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 )
             resnets.append(
                 ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
-                    groups=resnet_groups,
+                    groups=resnet_groups_out,
                     dropout=dropout,
                     time_embedding_norm=resnet_time_scale_shift,
                     non_linearity=resnet_act_fn,
@@ -837,16 +846,16 @@ class UNetMidBlock2DCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
```
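
The UNetMidBlock2DCrossAttn hunks above do change behavior: the constructor gains `out_channels` and `resnet_groups_out`, both defaulting to `None` so existing callers keep the old `out_channels == in_channels` behavior. A minimal sketch of the widened mid block, with illustrative shapes and otherwise default arguments:

```python
import torch
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn

# New 0.28.0 arguments (per the hunks above): out_channels and resnet_groups_out.
mid = UNetMidBlock2DCrossAttn(
    in_channels=32,
    temb_channels=128,
    out_channels=64,        # previously the mid block always kept in_channels
    resnet_groups_out=16,   # group-norm groups for the widened layers
    num_attention_heads=4,
    cross_attention_dim=64,
)
sample = torch.randn(1, 32, 8, 8)
temb = torch.randn(1, 128)
encoder_hidden_states = torch.randn(1, 77, 64)
out = mid(sample, temb, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # expected: torch.Size([1, 64, 8, 8])
```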
```diff
@@ -977,16 +986,16 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -1109,14 +1118,14 @@ class AttnDownBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         output_states = ()
 
@@ -1231,17 +1240,17 @@ class CrossAttnDownBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         output_states = ()
 
@@ -1353,8 +1362,8 @@ class DownBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
```
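
The `deprecation_message` repeated through these hunks points callers at the pipeline-level `cross_attention_kwargs` instead of a per-block `scale` argument. A hedged sketch of that calling convention (the model ID and LoRA path are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder path

# Pass `scale` once per call; the pipeline forwards it to the attention
# processors via cross_attention_kwargs instead of each block's forward().
image = pipe(
    "an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```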
```diff
@@ -1456,7 +1465,7 @@ class DownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None
 
-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1558,7 +1567,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None
 
-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1657,12 +1666,12 @@ class AttnSkipDownBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        skip_sample: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        skip_sample: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1748,12 +1757,12 @@ class SkipDownBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        skip_sample: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        skip_sample: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1841,8 +1850,8 @@ class ResnetDownsampleBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1977,16 +1986,16 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         output_states = ()
 
@@ -2088,8 +2097,8 @@ class KDownBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2192,16 +2201,16 @@ class KCrossAttnDownBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         output_states = ()
 
@@ -2349,13 +2358,13 @@ class AttnUpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2472,18 +2481,18 @@ class CrossAttnUpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -2607,13 +2616,13 @@ class UpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2732,7 +2741,7 @@ class UpDecoderBlock2D(nn.Module):
 
         self.resolution_idx = resolution_idx
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states, temb=temb)
 
@@ -2830,7 +2839,7 @@ class AttnUpDecoderBlock2D(nn.Module):
 
         self.resolution_idx = resolution_idx
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb=temb)
             hidden_states = attn(hidden_states, temb=temb)
@@ -2938,13 +2947,13 @@ class AttnSkipUpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         skip_sample=None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3050,13 +3059,13 @@ class SkipUpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         skip_sample=None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3157,13 +3166,13 @@ class ResnetUpsampleBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3301,18 +3310,18 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -3419,13 +3428,13 @@ class KUpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3549,15 +3558,15 @@ class KCrossAttnUpBlock2D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         res_hidden_states_tuple = res_hidden_states_tuple[-1]
         if res_hidden_states_tuple is not None:
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3675,26 +3684,26 @@ class KAttentionBlock(nn.Module):
             cross_attention_norm=cross_attention_norm,
         )
 
-    def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+    def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
         return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
 
-    def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+    def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
         return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         # TODO: mark emb as non-optional (self.norm2 requires it).
         # requires assessing impact of change to positional param interface.
-        emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         # 1. Self-Attention
         if self.add_self_attention:
```