diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
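Several entries in the listing above are brand-new modules rather than modifications, for example diffusers/callbacks.py, diffusers/video_processor.py, diffusers/models/controlnet_xs.py, diffusers/loaders/single_file_model.py, and the marigold and controlnet_xs pipeline packages (plus new pipeline files such as pipeline_pixart_sigma.py and pipeline_animatediff_sdxl.py). As a rough post-upgrade check, a sketch like the following simply confirms that the new modules import; the module names are taken from the listing, and it assumes diffusers 0.28.0 and torch are installed, since some of these modules import torch eagerly.

import importlib

# New modules in 0.28.0 according to the file listing above. This is only a smoke test:
# importlib raises ModuleNotFoundError if a module is missing from the installed wheel.
NEW_MODULES = [
    "diffusers.callbacks",
    "diffusers.video_processor",
    "diffusers.models.controlnet_xs",
    "diffusers.models.model_loading_utils",
    "diffusers.loaders.single_file_model",
    "diffusers.loaders.unet_loader_utils",
    "diffusers.pipelines.controlnet_xs",
    "diffusers.pipelines.marigold",
]

for name in NEW_MODULES:
    importlib.import_module(name)
    print(f"ok: {name}")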
diffusers/models/transformers/transformer_2d.py

@@ -35,12 +35,12 @@ class Transformer2DModelOutput(BaseOutput):
     The output of [`Transformer2DModel`].

     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
             The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
             distributions for the unnoised latent pixels.
     """

-    sample: torch.FloatTensor
+    sample: torch.Tensor


 class Transformer2DModel(ModelMixin, ConfigMixin):

@@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock"]

     @register_to_config
     def __init__(

@@ -100,8 +101,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         attention_type: str = "default",
         caption_channels: int = None,
         interpolation_scale: float = None,
+        use_additional_conditions: Optional[bool] = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(

@@ -112,13 +116,22 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )

+        # Set some common variables used across the board.
         self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
-
-
-
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
+        if use_additional_conditions is None:
+            if norm_type == "ada_norm_single" and sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions

         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration

@@ -129,7 +142,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
             deprecation_message = (
                 f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
-                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
                 " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                 " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                 " would be very nice if you could open a Pull request for the `transformer/config.json` file"

@@ -153,104 +166,165 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )

-        # 2.
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
         if self.is_input_continuous:
-            self.
-
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            if use_linear_projection:
-                self.proj_in = linear_cls(in_channels, inner_dim)
-            else:
-                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+            self._init_continuous_input(norm_type=norm_type)
         elif self.is_input_vectorized:
-
-
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)

-
-
-            self.
-
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)

-
-
-
-
-
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )

-
-            self.
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)

-
-
-
-
-
-
-
-
-
-
-
-
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )

-        # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
                 )
-                for
+                for _ in range(self.config.num_layers)
             ]
         )

-
-            self.
-
-
-
-
-
-
-
-
-
-
-            self.
-            self.
-
-
-            self.
-            self.
-            self.
-
-
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
         self.adaln_single = None
-        self.
-        if norm_type == "ada_norm_single":
-            self.use_additional_conditions = self.config.sample_size == 128
+        if self.config.norm_type == "ada_norm_single":
             # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
             # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )

         self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(
-
-
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )

     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):

@@ -272,9 +346,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         The [`Transformer2DModel`] forward method.

         Args:
-            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                 Input `hidden_states`.
-            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.LongTensor`, *optional*):

@@ -308,7 +382,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         """
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.

@@ -334,41 +408,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         # 1. Input
         if self.is_input_continuous:
-
+            batch_size, _, height, width = hidden_states.shape
             residual = hidden_states
-
-            hidden_states = self.norm(hidden_states)
-            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states)
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            else:
-                inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states)
-
+            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
             height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states = self.
-
-
-            if self.use_additional_conditions and added_cond_kwargs is None:
-                raise ValueError(
-                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
-                )
-            batch_size = hidden_states.shape[0]
-            timestep, embedded_timestep = self.adaln_single(
-                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-            )
+            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+            )

         # 2. Blocks
-        if self.caption_projection is not None:
-            batch_size = hidden_states.shape[0]
-            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:

@@ -406,51 +457,116 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         # 3. Output
         if self.is_input_continuous:
-
-            hidden_states
-
-
-
-
-
-
+            output = self._get_output_for_continuous_inputs(
+                hidden_states=hidden_states,
+                residual=residual,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                inner_dim=inner_dim,
+            )
         elif self.is_input_vectorized:
-
-
-
-
+            output = self._get_output_for_vectorized_inputs(hidden_states)
+        elif self.is_input_patches:
+            output = self._get_output_for_patched_inputs(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                class_labels=class_labels,
+                embedded_timestep=embedded_timestep,
+                height=height,
+                width=width,
+            )
+
+        if not return_dict:
+            return (output,)

-
-
+        return Transformer2DModelOutput(sample=output)
+
+    def _operate_on_continuous_inputs(self, hidden_states):
+        batch, _, height, width = hidden_states.shape
+        hidden_states = self.norm(hidden_states)
+
+        if not self.use_linear_projection:
+            hidden_states = self.proj_in(hidden_states)
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        else:
+            inner_dim = hidden_states.shape[1]
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            hidden_states = self.proj_in(hidden_states)

-
-
-
-
+        return hidden_states, inner_dim
+
+    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+        batch_size = hidden_states.shape[0]
+        hidden_states = self.pos_embed(hidden_states)
+        embedded_timestep = None
+
+        if self.adaln_single is not None:
+            if self.use_additional_conditions and added_cond_kwargs is None:
+                raise ValueError(
+                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-            height
-            hidden_states = hidden_states.reshape(
-                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            timestep, embedded_timestep = self.adaln_single(
+                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )
+
+        if self.caption_projection is not None:
+            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+        if not self.use_linear_projection:
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )
-            hidden_states =
-
-
+            hidden_states = self.proj_out(hidden_states)
+        else:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = (
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             )

-
-
+        output = hidden_states + residual
+        return output

-
+    def _get_output_for_vectorized_inputs(self, hidden_states):
+        hidden_states = self.norm_out(hidden_states)
+        logits = self.out(hidden_states)
+        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+        logits = logits.permute(0, 2, 1)
+        # log(p(x_0))
+        output = F.log_softmax(logits.double(), dim=1).float()
+        return output
+
+    def _get_output_for_patched_inputs(
+        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+    ):
+        if self.config.norm_type != "ada_norm_single":
+            conditioning = self.transformer_blocks[0].norm1.emb(
+                timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = self.proj_out_2(hidden_states)
+        elif self.config.norm_type == "ada_norm_single":
+            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            hidden_states = self.norm_out(hidden_states)
+            # Modulation
+            hidden_states = hidden_states * (1 + scale) + shift
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.squeeze(1)
+
+        # unpatchify
+        if self.adaln_single is None:
+            height = width = int(hidden_states.shape[1] ** 0.5)
+        hidden_states = hidden_states.reshape(
+            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+        )
+        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+        )
+        return output
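Taken together, the transformer_2d.py hunks above refactor Transformer2DModel so that __init__ reads its settings from self.config and delegates to _init_continuous_input, _init_vectorized_inputs, or _init_patched_inputs, with matching _operate_on_* and _get_output_for_* helpers used in forward. They also promote use_additional_conditions to an explicit constructor argument; previously it was inferred from sample_size == 128 when norm_type is "ada_norm_single". A minimal sketch of what that exposes from the outside follows; the configuration values are illustrative and not taken from any released checkpoint.

# Minimal sketch (illustrative values, not a released checkpoint config): the refactor keeps
# the same construction API, records inner_dim / use_additional_conditions on the module, and
# accepts use_additional_conditions explicitly instead of inferring it from sample_size.
import torch
from diffusers import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    num_layers=1,
    sample_size=16,
    patch_size=2,
    norm_type="ada_norm_single",
    use_additional_conditions=False,  # new keyword argument in 0.28.0
)

assert model.inner_dim == 2 * 8
assert model.use_additional_conditions is False
print(model.config.use_additional_conditions)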
diffusers/models/transformers/transformer_temporal.py

@@ -31,11 +31,11 @@ class TransformerTemporalModelOutput(BaseOutput):
     The output of [`TransformerTemporalModel`].

     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
     """

-    sample: torch.FloatTensor
+    sample: torch.Tensor


 class TransformerTemporalModel(ModelMixin, ConfigMixin):

@@ -120,7 +120,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.LongTensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         class_labels: torch.LongTensor = None,

@@ -132,7 +132,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
         The [`TransformerTemporal`] forward method.

         Args:
-            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                 Input hidden_states.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to

@@ -283,7 +283,7 @@ class TransformerSpatioTemporalModel(nn.Module):
     ):
         """
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input hidden_states.
             num_frames (`int`):
                 The number of frames to be processed per batch. This is used to reshape the hidden states.

@@ -294,8 +294,8 @@ class TransformerSpatioTemporalModel(nn.Module):
                 A tensor indicating whether the input contains only images. 1 indicates that the input contains only
                 images, 0 indicates that the input contains video frames.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
-                tuple.
+                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
+                plain tuple.

         Returns:
             [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:

@@ -311,10 +311,10 @@ class TransformerSpatioTemporalModel(nn.Module):
             time_context_first_timestep = time_context[None, :].reshape(
                 batch_size, num_frames, -1, time_context.shape[-1]
             )[:, 0]
-            time_context = time_context_first_timestep[None
-                height * width,
+            time_context = time_context_first_timestep[:, None].broadcast_to(
+                batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
             )
-            time_context = time_context.reshape(
+            time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])

             residual = hidden_states

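Besides the torch.FloatTensor to torch.Tensor annotation updates, the last hunk changes how TransformerSpatioTemporalModel expands time_context: the first-frame context is now broadcast over the batch dimension as well as over the height * width spatial positions before being flattened, which is what the new batch_size * height * width reshape expects. A standalone sketch of the shape arithmetic, with made-up sizes:

# Standalone shape check mirroring the new broadcast (illustrative sizes only).
import torch

batch_size, num_frames, seq_len, dim = 2, 4, 77, 1024
height, width = 8, 8

# (batch_size * num_frames, seq_len, dim), as seen inside the model's forward pass
time_context = torch.randn(batch_size * num_frames, seq_len, dim)

# keep only the first frame's context per batch element: (batch_size, seq_len, dim)
time_context_first_timestep = time_context[None, :].reshape(
    batch_size, num_frames, -1, time_context.shape[-1]
)[:, 0]

# 0.28.0 side of the hunk: broadcast across the batch *and* the spatial positions ...
time_context = time_context_first_timestep[:, None].broadcast_to(
    batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
)
# ... then flatten to one context row per spatial position of every batch element
time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
print(time_context.shape)  # torch.Size([128, 77, 1024])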
diffusers/models/unets/unet_1d.py

@@ -31,11 +31,11 @@ class UNet1DOutput(BaseOutput):
     The output of [`UNet1DModel`].

     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`):
            The hidden states output from the last layer of the model.
     """

-    sample: torch.FloatTensor
+    sample: torch.Tensor


 class UNet1DModel(ModelMixin, ConfigMixin):

@@ -194,7 +194,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):

     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         return_dict: bool = True,
     ) -> Union[UNet1DOutput, Tuple]:

@@ -202,9 +202,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         The [`UNet1DModel`] forward method.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

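The unet_1d.py hunks belong to the same sweep seen in the two files above, replacing torch.FloatTensor with torch.Tensor in annotations and docstrings; torch.FloatTensor names the float32 CPU tensor type specifically, so the broader annotation better matches what the modules actually accept. A tiny check of the new annotation (a sketch; it assumes diffusers 0.28.0 and torch are installed):

import torch
from diffusers.models.unets.unet_1d import UNet1DOutput

# Under 0.28.0 the output dataclass field is annotated as torch.Tensor (it was
# torch.FloatTensor in 0.27.2), so e.g. a float16 sample is consistent with the annotation.
out = UNet1DOutput(sample=torch.zeros(1, 2, 8, dtype=torch.float16))
assert isinstance(out.sample, torch.Tensor)
print(UNet1DOutput.__annotations__["sample"])  # <class 'torch.Tensor'>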