diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_asym_kl.py

@@ -112,9 +112,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self.register_to_config(force_upcast=False)
 
     @apply_forward_hook
-    def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
@@ -126,11 +124,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
 
     def _decode(
         self,
-        z: torch.FloatTensor,
-        image: Optional[torch.FloatTensor] = None,
-        mask: Optional[torch.FloatTensor] = None,
+        z: torch.Tensor,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z, image, mask)
 
@@ -142,12 +140,12 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         generator: Optional[torch.Generator] = None,
-        image: Optional[torch.FloatTensor] = None,
-        mask: Optional[torch.FloatTensor] = None,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         decoded = self._decode(z, image, mask).sample
 
         if not return_dict:
@@ -157,16 +155,16 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
-        mask: Optional[torch.FloatTensor] = None,
+        sample: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
-            mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
+            sample (`torch.Tensor`): Input sample.
+            mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
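
The hunks above, like most of the signature changes in the autoencoder files that follow, relax public type hints from `torch.FloatTensor` to `torch.Tensor`; call sites do not change. A minimal sketch of the unchanged encode/decode round trip, assuming the `cross-attention/asymmetric-autoencoder-kl-x-1-5` checkpoint id and the tensor shapes below (neither is part of this diff):

```python
# Sketch only: checkpoint id and tensor shapes are assumptions for illustration.
import torch
from diffusers import AsymmetricAutoencoderKL

vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
image = torch.randn(1, 3, 256, 256)   # any torch.Tensor, as the new hints document
mask = torch.ones(1, 1, 256, 256)     # optional inpainting mask
latents = vae.encode(image).latent_dist.sample()
reconstruction = vae.decode(latents, image=image, mask=mask).sample
```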
diffusers/models/autoencoders/autoencoder_kl.py

@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalVAEMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -32,7 +32,7 @@ from ..modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
@@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
 
     @register_to_config
     def __init__(
@@ -236,13 +237,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
@@ -267,7 +268,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)
 
@@ -280,14 +281,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         return DecoderOutput(sample=dec)
 
     @apply_forward_hook
-    def decode(
-        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images.
 
         Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
+            z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
 
@@ -301,7 +300,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
-            decoded = self._decode(z).sample
+            decoded = self._decode(z, return_dict=False)[0]
 
         if not return_dict:
             return (decoded,)
@@ -320,7 +319,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -330,7 +329,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         output, but they should be much less noticeable.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
@@ -374,12 +373,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
 
         Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
+            z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
 
@@ -424,14 +423,14 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -453,8 +452,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
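
With this release, `AutoencoderKL` takes its single-file loader from the new `diffusers/loaders/single_file_model.py` module (`FromOriginalModelMixin`, replacing `FromOriginalVAEMixin`), and the new `_no_split_modules` attribute tells the model-splitting machinery which blocks must stay on a single device. A minimal sketch of single-file loading; the checkpoint path below is a placeholder, not something from this diff:

```python
# Sketch only: the .safetensors path is a placeholder.
import torch
from diffusers import AutoencoderKL

# from_single_file() is provided by FromOriginalModelMixin (loaders/single_file_model.py)
vae = AutoencoderKL.from_single_file("path/to/vae.safetensors", torch_dtype=torch.float16)
vae.to("cuda")
```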
diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

@@ -86,10 +86,10 @@ class TemporalDecoder(nn.Module):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
+        sample: torch.Tensor,
+        image_only_indicator: torch.Tensor,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         r"""The forward method of the `Decoder` class."""
 
         sample = self.conv_in(sample)
@@ -315,13 +315,13 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
@@ -341,15 +341,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         num_frames: int,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images.
 
         Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
+            z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
 
@@ -370,15 +370,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
         num_frames: int = 1,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
diffusers/models/autoencoders/autoencoder_tiny.py

@@ -102,6 +102,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
         decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
         act_fn: str = "relu",
+        upsample_fn: str = "nearest",
         latent_channels: int = 4,
         upsampling_scaling_factor: int = 2,
         num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
@@ -133,6 +134,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             block_out_channels=decoder_block_out_channels,
             upsampling_scaling_factor=upsampling_scaling_factor,
             act_fn=act_fn,
+            upsample_fn=upsample_fn,
         )
 
         self.latent_magnitude = latent_magnitude
@@ -155,11 +157,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
 
-    def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
         """raw latents -> [0, 1]"""
         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
 
-    def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
 
@@ -192,7 +194,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         """
         self.enable_tiling(False)
 
-    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -200,10 +202,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         tiles overlap and are blended together to form a smooth output.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
 
         Returns:
-            `torch.FloatTensor`: Encoded batch of images.
+            `torch.Tensor`: Encoded batch of images.
         """
         # scale of encoder output relative to input
         sf = self.spatial_scale_factor
@@ -240,7 +242,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
         return out
 
-    def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -248,10 +250,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         tiles overlap and are blended together to form a smooth output.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
 
         Returns:
-            `torch.FloatTensor`: Encoded batch of images.
+            `torch.Tensor`: Encoded batch of images.
         """
         # scale of decoder output relative to input
         sf = self.spatial_scale_factor
@@ -288,9 +290,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         return out
 
     @apply_forward_hook
-    def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
             output = [
                 self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
@@ -306,8 +306,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def decode(
-        self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
             output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
             output = torch.cat(output)
@@ -321,12 +321,12 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
         """
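
`AutoencoderTiny` gains an `upsample_fn` config option (default `"nearest"`) that is forwarded to `DecoderTiny` and ultimately to `nn.Upsample` (see the `vae.py` hunks below). A minimal sketch of constructing the model directly with the new argument; the weights here are random, so this only illustrates the API surface:

```python
# Sketch only: random weights, default block sizes; upsample_fn is the new config argument.
import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny(upsample_fn="nearest")
latents = vae.encode(torch.randn(1, 3, 256, 256)).latents
image = vae.decode(latents).sample
```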
diffusers/models/autoencoders/consistency_decoder_vae.py

@@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
        ...     "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
        ... ).to("cuda")
 
-        >>> pipe("horse", generator=torch.manual_seed(0)).images
+        >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
+        >>> image
         ```
     """
 
@@ -72,6 +73,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self,
         scaling_factor: float = 0.18215,
         latent_channels: int = 4,
+        sample_size: int = 32,
         encoder_act_fn: str = "silu",
         encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         encoder_double_z: bool = True,
@@ -153,6 +155,16 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
     # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
     def enable_tiling(self, use_tiling: bool = True):
         r"""
@@ -264,15 +276,15 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.
+                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
                 tuple.
 
         Returns:
@@ -300,11 +312,24 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
         num_inference_steps: int = 2,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+        """
+        Decodes the input latent vector `z` using the consistency decoder VAE model.
+
+        Args:
+            z (torch.Tensor): The input latent vector.
+            generator (Optional[torch.Generator]): The random number generator. Default is None.
+            return_dict (bool): Whether to return the output as a dictionary. Default is True.
+            num_inference_steps (int): The number of inference steps. Default is 2.
+
+        Returns:
+            Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
+
+        """
         z = (z * self.config.scaling_factor - self.means) / self.stds
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
@@ -345,7 +370,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -355,7 +380,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         output, but they should be much less noticeable.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
                 plain tuple.
@@ -402,14 +427,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
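
`ConsistencyDecoderVAE` now registers `sample_size` in its config and initialises the `tile_sample_min_size`, `tile_latent_min_size`, and `tile_overlap_factor` attributes that the tiling helpers (copied from `AutoencoderKL`) rely on, and `decode()` gains a docstring. A minimal sketch, assuming the `openai/consistency-decoder` checkpoint id used in the class docstring:

```python
# Sketch only: follows the usage example shown in the class docstring hunk above.
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
vae.enable_tiling()  # works now that the tile_* attributes are set in __init__

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")
image = pipe("horse", generator=torch.manual_seed(0)).images[0]
```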
diffusers/models/autoencoders/vae.py

@@ -36,11 +36,12 @@ class DecoderOutput(BaseOutput):
     Output of decoding method.
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
             The decoded output sample from the last layer of the model.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.Tensor
+    commit_loss: Optional[torch.FloatTensor] = None
 
 
 class Encoder(nn.Module):
@@ -90,7 +91,6 @@ class Encoder(nn.Module):
             padding=1,
         )
 
-        self.mid_block = None
         self.down_blocks = nn.ModuleList([])
 
         # down
@@ -137,7 +137,7 @@ class Encoder(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, sample: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `Encoder` class."""
 
         sample = self.conv_in(sample)
@@ -228,7 +228,6 @@ class Decoder(nn.Module):
             padding=1,
         )
 
-        self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
         temb_channels = in_channels if norm_type == "spatial" else None
@@ -284,9 +283,9 @@ class Decoder(nn.Module):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
-        latent_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `Decoder` class."""
 
         sample = self.conv_in(sample)
@@ -369,7 +368,7 @@ class UpSample(nn.Module):
         self.out_channels = out_channels
         self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `UpSample` class."""
         x = torch.relu(x)
         x = self.deconv(x)
@@ -418,7 +417,7 @@ class MaskConditionEncoder(nn.Module):
 
         self.layers = nn.Sequential(*layers)
 
-    def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
         r"""The forward method of the `MaskConditionEncoder` class."""
         out = {}
         for l in range(len(self.layers)):
@@ -474,7 +473,6 @@ class MaskConditionDecoder(nn.Module):
             padding=1,
         )
 
-        self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
         temb_channels = in_channels if norm_type == "spatial" else None
@@ -536,11 +534,11 @@ class MaskConditionDecoder(nn.Module):
 
     def forward(
         self,
-        z: torch.FloatTensor,
-        image: Optional[torch.FloatTensor] = None,
-        mask: Optional[torch.FloatTensor] = None,
-        latent_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        z: torch.Tensor,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `MaskConditionDecoder` class."""
         sample = z
         sample = self.conv_in(sample)
@@ -714,7 +712,7 @@ class VectorQuantizer(nn.Module):
         back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
         return back.reshape(ishape)
 
-    def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
         # reshape z -> (batch, height, width, channel) and flatten
         z = z.permute(0, 2, 3, 1).contiguous()
         z_flattened = z.view(-1, self.vq_embed_dim)
@@ -733,7 +731,7 @@ class VectorQuantizer(nn.Module):
         loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
 
         # preserve gradients
-        z_q: torch.FloatTensor = z + (z_q - z).detach()
+        z_q: torch.Tensor = z + (z_q - z).detach()
 
         # reshape back to match original input shape
         z_q = z_q.permute(0, 3, 1, 2).contiguous()
@@ -748,7 +746,7 @@ class VectorQuantizer(nn.Module):
 
         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
 
-    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
+    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
         # shape specifying (batch, height, width, channel)
         if self.remap is not None:
             indices = indices.reshape(shape[0], -1)  # add batch axis
@@ -756,7 +754,7 @@ class VectorQuantizer(nn.Module):
             indices = indices.reshape(-1)  # flatten again
 
         # get quantized latent vectors
-        z_q: torch.FloatTensor = self.embedding(indices)
+        z_q: torch.Tensor = self.embedding(indices)
 
         if shape is not None:
             z_q = z_q.view(shape)
@@ -779,7 +777,7 @@ class DiagonalGaussianDistribution(object):
                 self.mean, device=self.parameters.device, dtype=self.parameters.dtype
             )
 
-    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
         # make sure sample is on the same device as the parameters and has same dtype
         sample = randn_tensor(
             self.mean.shape,
@@ -876,7 +874,7 @@ class EncoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
         if self.training and self.gradient_checkpointing:
 
@@ -926,6 +924,7 @@ class DecoderTiny(nn.Module):
         block_out_channels: Tuple[int, ...],
         upsampling_scaling_factor: int,
         act_fn: str,
+        upsample_fn: str,
     ):
         super().__init__()
 
@@ -942,7 +941,7 @@ class DecoderTiny(nn.Module):
             layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
 
             if not is_final_block:
-                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
+                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))
 
             conv_out_channel = num_channels if not is_final_block else out_channels
             layers.append(
@@ -958,7 +957,7 @@ class DecoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `DecoderTiny` class."""
         # Clamp.
         x = torch.tanh(x / 3) * 3