diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
@@ -31,11 +31,11 @@ class TransformerTemporalModelOutput(BaseOutput):
|
|
31
31
|
The output of [`TransformerTemporalModel`].
|
32
32
|
|
33
33
|
Args:
|
34
|
-
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
34
|
+
sample (`torch.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
35
35
|
The hidden states output conditioned on `encoder_hidden_states` input.
|
36
36
|
"""
|
37
37
|
|
38
|
-
sample: torch.FloatTensor
|
38
|
+
sample: torch.Tensor
|
39
39
|
|
40
40
|
|
41
41
|
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
@@ -120,7 +120,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
|
120
120
|
|
121
121
|
def forward(
|
122
122
|
self,
|
123
|
-
hidden_states: torch.FloatTensor,
|
123
|
+
hidden_states: torch.Tensor,
|
124
124
|
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
125
125
|
timestep: Optional[torch.LongTensor] = None,
|
126
126
|
class_labels: torch.LongTensor = None,
|
@@ -132,7 +132,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
|
132
132
|
The [`TransformerTemporal`] forward method.
|
133
133
|
|
134
134
|
Args:
|
135
|
-
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
135
|
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
|
136
136
|
Input hidden_states.
|
137
137
|
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
138
138
|
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
@@ -283,7 +283,7 @@ class TransformerSpatioTemporalModel(nn.Module):
|
|
283
283
|
):
|
284
284
|
"""
|
285
285
|
Args:
|
286
|
-
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
286
|
+
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
|
287
287
|
Input hidden_states.
|
288
288
|
num_frames (`int`):
|
289
289
|
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
@@ -294,8 +294,8 @@ class TransformerSpatioTemporalModel(nn.Module):
|
|
294
294
|
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
295
295
|
images, 0 indicates that the input contains video frames.
|
296
296
|
return_dict (`bool`, *optional*, defaults to `True`):
|
297
|
-
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
|
298
|
-
tuple.
|
297
|
+
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
|
298
|
+
plain tuple.
|
299
299
|
|
300
300
|
Returns:
|
301
301
|
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
@@ -311,10 +311,10 @@ class TransformerSpatioTemporalModel(nn.Module):
|
|
311
311
|
time_context_first_timestep = time_context[None, :].reshape(
|
312
312
|
batch_size, num_frames, -1, time_context.shape[-1]
|
313
313
|
)[:, 0]
|
314
|
-
time_context = time_context_first_timestep[None, :].broadcast_to(
|
315
|
-
height * width, batch_size, 1, time_context.shape[-1]
|
314
|
+
time_context = time_context_first_timestep[:, None].broadcast_to(
|
315
|
+
batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
|
316
316
|
)
|
317
|
-
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
317
|
+
time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
|
318
318
|
|
319
319
|
residual = hidden_states
|
320
320
|
|
@@ -31,11 +31,11 @@ class UNet1DOutput(BaseOutput):
|
|
31
31
|
The output of [`UNet1DModel`].
|
32
32
|
|
33
33
|
Args:
|
34
|
-
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
34
|
+
sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`):
|
35
35
|
The hidden states output from the last layer of the model.
|
36
36
|
"""
|
37
37
|
|
38
|
-
sample: torch.FloatTensor
|
38
|
+
sample: torch.Tensor
|
39
39
|
|
40
40
|
|
41
41
|
class UNet1DModel(ModelMixin, ConfigMixin):
|
@@ -194,7 +194,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
|
194
194
|
|
195
195
|
def forward(
|
196
196
|
self,
|
197
|
-
sample: torch.FloatTensor,
|
197
|
+
sample: torch.Tensor,
|
198
198
|
timestep: Union[torch.Tensor, float, int],
|
199
199
|
return_dict: bool = True,
|
200
200
|
) -> Union[UNet1DOutput, Tuple]:
|
@@ -202,9 +202,9 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
|
202
202
|
The [`UNet1DModel`] forward method.
|
203
203
|
|
204
204
|
Args:
|
205
|
-
sample (`torch.FloatTensor`):
|
205
|
+
sample (`torch.Tensor`):
|
206
206
|
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
207
|
-
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
207
|
+
timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
|
208
208
|
return_dict (`bool`, *optional*, defaults to `True`):
|
209
209
|
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
210
210
|
|
@@ -66,7 +66,7 @@ class DownResnetBlock1D(nn.Module):
|
|
66
66
|
if add_downsample:
|
67
67
|
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
68
68
|
|
69
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
69
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
70
70
|
output_states = ()
|
71
71
|
|
72
72
|
hidden_states = self.resnets[0](hidden_states, temb)
|
@@ -128,10 +128,10 @@ class UpResnetBlock1D(nn.Module):
|
|
128
128
|
|
129
129
|
def forward(
|
130
130
|
self,
|
131
|
-
hidden_states: torch.FloatTensor,
|
132
|
-
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
133
|
-
temb: Optional[torch.FloatTensor] = None,
|
134
|
-
) -> torch.FloatTensor:
|
131
|
+
hidden_states: torch.Tensor,
|
132
|
+
res_hidden_states_tuple: Optional[Tuple[torch.Tensor, ...]] = None,
|
133
|
+
temb: Optional[torch.Tensor] = None,
|
134
|
+
) -> torch.Tensor:
|
135
135
|
if res_hidden_states_tuple is not None:
|
136
136
|
res_hidden_states = res_hidden_states_tuple[-1]
|
137
137
|
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
@@ -161,7 +161,7 @@ class ValueFunctionMidBlock1D(nn.Module):
|
|
161
161
|
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
162
162
|
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
163
163
|
|
164
|
-
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
164
|
+
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
165
165
|
x = self.res1(x, temb)
|
166
166
|
x = self.down1(x)
|
167
167
|
x = self.res2(x, temb)
|
@@ -209,7 +209,7 @@ class MidResTemporalBlock1D(nn.Module):
|
|
209
209
|
if self.upsample and self.downsample:
|
210
210
|
raise ValueError("Block cannot downsample and upsample")
|
211
211
|
|
212
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
212
|
+
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
213
213
|
hidden_states = self.resnets[0](hidden_states, temb)
|
214
214
|
for resnet in self.resnets[1:]:
|
215
215
|
hidden_states = resnet(hidden_states, temb)
|
@@ -230,7 +230,7 @@ class OutConv1DBlock(nn.Module):
|
|
230
230
|
self.final_conv1d_act = get_activation(act_fn)
|
231
231
|
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
232
232
|
|
233
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
233
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
234
234
|
hidden_states = self.final_conv1d_1(hidden_states)
|
235
235
|
hidden_states = rearrange_dims(hidden_states)
|
236
236
|
hidden_states = self.final_conv1d_gn(hidden_states)
|
@@ -251,7 +251,7 @@ class OutValueFunctionBlock(nn.Module):
|
|
251
251
|
]
|
252
252
|
)
|
253
253
|
|
254
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
254
|
+
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
255
255
|
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
256
256
|
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
257
257
|
for layer in self.final_block:
|
@@ -288,7 +288,7 @@ class Downsample1d(nn.Module):
|
|
288
288
|
self.pad = kernel_1d.shape[0] // 2 - 1
|
289
289
|
self.register_buffer("kernel", kernel_1d)
|
290
290
|
|
291
|
-
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
291
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
292
292
|
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
293
293
|
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
294
294
|
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
@@ -305,7 +305,7 @@ class Upsample1d(nn.Module):
|
|
305
305
|
self.pad = kernel_1d.shape[0] // 2 - 1
|
306
306
|
self.register_buffer("kernel", kernel_1d)
|
307
307
|
|
308
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
308
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
309
309
|
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
310
310
|
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
311
311
|
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
@@ -335,7 +335,7 @@ class SelfAttention1d(nn.Module):
|
|
335
335
|
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
336
336
|
return new_projection
|
337
337
|
|
338
|
-
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
338
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
339
339
|
residual = hidden_states
|
340
340
|
batch, channel_dim, seq = hidden_states.shape
|
341
341
|
|
@@ -390,7 +390,7 @@ class ResConvBlock(nn.Module):
|
|
390
390
|
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
391
391
|
self.gelu_2 = nn.GELU()
|
392
392
|
|
393
|
-
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
393
|
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
394
394
|
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
395
395
|
|
396
396
|
hidden_states = self.conv_1(hidden_states)
|
@@ -435,7 +435,7 @@ class UNetMidBlock1D(nn.Module):
|
|
435
435
|
self.attentions = nn.ModuleList(attentions)
|
436
436
|
self.resnets = nn.ModuleList(resnets)
|
437
437
|
|
438
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
438
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
439
439
|
hidden_states = self.down(hidden_states)
|
440
440
|
for attn, resnet in zip(self.attentions, self.resnets):
|
441
441
|
hidden_states = resnet(hidden_states)
|
@@ -466,7 +466,7 @@ class AttnDownBlock1D(nn.Module):
|
|
466
466
|
self.attentions = nn.ModuleList(attentions)
|
467
467
|
self.resnets = nn.ModuleList(resnets)
|
468
468
|
|
469
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
469
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
470
470
|
hidden_states = self.down(hidden_states)
|
471
471
|
|
472
472
|
for resnet, attn in zip(self.resnets, self.attentions):
|
@@ -490,7 +490,7 @@ class DownBlock1D(nn.Module):
|
|
490
490
|
|
491
491
|
self.resnets = nn.ModuleList(resnets)
|
492
492
|
|
493
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
493
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
494
494
|
hidden_states = self.down(hidden_states)
|
495
495
|
|
496
496
|
for resnet in self.resnets:
|
@@ -512,7 +512,7 @@ class DownBlock1DNoSkip(nn.Module):
|
|
512
512
|
|
513
513
|
self.resnets = nn.ModuleList(resnets)
|
514
514
|
|
515
|
-
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
515
|
+
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
516
516
|
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
517
517
|
for resnet in self.resnets:
|
518
518
|
hidden_states = resnet(hidden_states)
|
@@ -542,10 +542,10 @@ class AttnUpBlock1D(nn.Module):
|
|
542
542
|
|
543
543
|
def forward(
|
544
544
|
self,
|
545
|
-
hidden_states: torch.FloatTensor,
|
546
|
-
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
547
|
-
temb: Optional[torch.FloatTensor] = None,
|
548
|
-
) -> torch.FloatTensor:
|
545
|
+
hidden_states: torch.Tensor,
|
546
|
+
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
|
547
|
+
temb: Optional[torch.Tensor] = None,
|
548
|
+
) -> torch.Tensor:
|
549
549
|
res_hidden_states = res_hidden_states_tuple[-1]
|
550
550
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
551
551
|
|
@@ -574,10 +574,10 @@ class UpBlock1D(nn.Module):
|
|
574
574
|
|
575
575
|
def forward(
|
576
576
|
self,
|
577
|
-
hidden_states: torch.FloatTensor,
|
578
|
-
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
579
|
-
temb: Optional[torch.FloatTensor] = None,
|
580
|
-
) -> torch.FloatTensor:
|
577
|
+
hidden_states: torch.Tensor,
|
578
|
+
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
|
579
|
+
temb: Optional[torch.Tensor] = None,
|
580
|
+
) -> torch.Tensor:
|
581
581
|
res_hidden_states = res_hidden_states_tuple[-1]
|
582
582
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
583
583
|
|
@@ -604,10 +604,10 @@ class UpBlock1DNoSkip(nn.Module):
|
|
604
604
|
|
605
605
|
def forward(
|
606
606
|
self,
|
607
|
-
hidden_states: torch.FloatTensor,
|
608
|
-
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
609
|
-
temb: Optional[torch.FloatTensor] = None,
|
610
|
-
) -> torch.FloatTensor:
|
607
|
+
hidden_states: torch.Tensor,
|
608
|
+
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
|
609
|
+
temb: Optional[torch.Tensor] = None,
|
610
|
+
) -> torch.Tensor:
|
611
611
|
res_hidden_states = res_hidden_states_tuple[-1]
|
612
612
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
613
613
|
|
@@ -30,11 +30,11 @@ class UNet2DOutput(BaseOutput):
|
|
30
30
|
The output of [`UNet2DModel`].
|
31
31
|
|
32
32
|
Args:
|
33
|
-
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
33
|
+
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
34
34
|
The hidden states output from the last layer of the model.
|
35
35
|
"""
|
36
36
|
|
37
|
-
sample: torch.FloatTensor
|
37
|
+
sample: torch.Tensor
|
38
38
|
|
39
39
|
|
40
40
|
class UNet2DModel(ModelMixin, ConfigMixin):
|
@@ -242,7 +242,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
242
242
|
|
243
243
|
def forward(
|
244
244
|
self,
|
245
|
-
sample: torch.FloatTensor,
|
245
|
+
sample: torch.Tensor,
|
246
246
|
timestep: Union[torch.Tensor, float, int],
|
247
247
|
class_labels: Optional[torch.Tensor] = None,
|
248
248
|
return_dict: bool = True,
|
@@ -251,10 +251,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
251
251
|
The [`UNet2DModel`] forward method.
|
252
252
|
|
253
253
|
Args:
|
254
|
-
sample (`torch.FloatTensor`):
|
254
|
+
sample (`torch.Tensor`):
|
255
255
|
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
256
|
-
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
257
|
-
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
256
|
+
timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
|
257
|
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
258
258
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
259
259
|
return_dict (`bool`, *optional*, defaults to `True`):
|
260
260
|
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|