diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +48 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/hooks/faster_cache.py +2 -2
- diffusers/hooks/group_offloading.py +128 -29
- diffusers/hooks/hooks.py +2 -2
- diffusers/hooks/layerwise_casting.py +3 -3
- diffusers/hooks/pyramid_attention_broadcast.py +1 -1
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +4 -0
- diffusers/loaders/ip_adapter.py +5 -14
- diffusers/loaders/lora_base.py +212 -111
- diffusers/loaders/lora_conversion_utils.py +275 -34
- diffusers/loaders/lora_pipeline.py +1554 -819
- diffusers/loaders/peft.py +52 -109
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +20 -4
- diffusers/loaders/single_file_utils.py +225 -5
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +1 -1
- diffusers/loaders/transformer_sd3.py +2 -2
- diffusers/loaders/unet.py +2 -16
- diffusers/loaders/unet_loader_utils.py +1 -1
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +15 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +4 -4
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +14 -10
- diffusers/models/auto_model.py +47 -10
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +16 -15
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +1 -1
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +10 -12
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/modeling_utils.py +44 -14
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +742 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +317 -25
- diffusers/models/transformers/transformer_cosmos.py +579 -0
- diffusers/models/transformers/transformer_flux.py +9 -11
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +2 -2
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +24 -8
- diffusers/models/transformers/transformer_wan_vace.py +393 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +2 -2
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/pipelines/__init__.py +37 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +6 -7
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +10 -17
- diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +3 -4
- diffusers/pipelines/pipeline_loading_utils.py +89 -13
- diffusers/pipelines/pipeline_utils.py +105 -33
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +17 -12
- diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
- diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +179 -1
- diffusers/quantizers/base.py +6 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +16 -13
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +8 -8
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -1
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
- diffusers/schedulers/scheduling_utils.py +1 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +13 -5
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +120 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
- diffusers/utils/dynamic_modules_utils.py +21 -3
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/import_utils.py +81 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +91 -8
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +59 -7
- diffusers/utils/torch_utils.py +25 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
- diffusers-0.34.0.dist-info/RECORD +639 -0
- diffusers-0.33.0.dist-info/RECORD +0 -608
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -829,7 +829,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
-        tile_latent_min_width = self.
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
         tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio

         if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -1067,7 +1067,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
         encoder_causal (`bool`, defaults to `True`):
             Whether the encoder should behave causally (future frames depend only on past frames) or not.
         decoder_causal (`bool`, defaults to `False`):
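The scaling-factor text restored here follows the usual latent-diffusion convention. A tiny, self-contained sketch of that round trip (the value and tensor shape below are illustrative assumptions, not taken from the package):

```python
import torch

# Latents are multiplied by `scaling_factor` before the diffusion model sees them
# (z = z * scaling_factor) and divided by it again before decoding (z = 1 / scaling_factor * z).
scaling_factor = 0.13025          # hypothetical value; every VAE config defines its own
z = torch.randn(1, 4, 64, 64)     # stand-in for an encoded latent

z_scaled = z * scaling_factor           # what the diffusion model is trained on
z_restored = z_scaled / scaling_factor  # what the VAE decoder receives

assert torch.allclose(z, z_restored)
```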
@@ -1285,7 +1285,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     ) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
-        tile_latent_min_width = self.
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
         tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio

         if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
@@ -428,7 +428,7 @@ class EasyAnimateMidBlock3d(nn.Module):

 class EasyAnimateEncoder(nn.Module):
     r"""
-    Causal encoder for 3D video-like data used in [EasyAnimate](https://
+    Causal encoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
     """

     _supports_gradient_checkpointing = True
@@ -544,7 +544,7 @@ class EasyAnimateEncoder(nn.Module):

 class EasyAnimateDecoder(nn.Module):
     r"""
-    Causal decoder for 3D video-like data used in [EasyAnimate](https://
+    Causal decoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
     """

     _supports_gradient_checkpointing = True
@@ -666,7 +666,7 @@ class EasyAnimateDecoder(nn.Module):
 class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
-    model is used in [EasyAnimate](https://
+    model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).

     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
@@ -887,7 +887,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
-        tile_latent_min_width = self.
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

         if self.use_tiling and (z.shape[-1] > tile_latent_min_height or z.shape[-2] > tile_latent_min_width):
             return self.tiled_decode(z, return_dict=return_dict)
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The Mochi team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -677,7 +677,7 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
     """

     _supports_gradient_checkpointing = True
@@ -909,7 +909,7 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
     def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         batch_size, num_channels, num_frames, height, width = z.shape
         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
-        tile_latent_min_width = self.
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
             return self.tiled_decode(z, return_dict=return_dict)
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -158,11 +158,11 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
         force_upcast (`bool`, *optional*, default to `True`):
             If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
-            can be fine-tuned / trained to a lower range without
-
+            can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
+            can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
     """

     _supports_gradient_checkpointing = True
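The restored `force_upcast` wording points at VAE checkpoints that were fine-tuned to stay numerically stable in float16, such as the `madebyollin/sdxl-vae-fp16-fix` repository linked in the docstring. A hedged sketch of using such a checkpoint (only the repository id comes from the diff; the rest is an assumed, standard `from_pretrained` call):

```python
import torch
from diffusers import AutoencoderKL

# A VAE fine-tuned for fp16 (referenced in the docstring above) can be loaded directly in
# half precision instead of relying on float32 upcasting at inference time.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
```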
@@ -730,27 +730,104 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )

+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+        # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+        self._cached_conv_counts = {
+            "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
+            if self.decoder is not None
+            else 0,
+            "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
+            if self.encoder is not None
+            else 0,
+        }
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
-
-
-            for m in model.modules():
-                if isinstance(m, WanCausalConv3d):
-                    count += 1
-            return count
-
-        self._conv_num = _count_conv3d(self.decoder)
+        # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+        self._conv_num = self._cached_conv_counts["decoder"]
         self._conv_idx = [0]
         self._feat_map = [None] * self._conv_num
         # cache encode
-        self._enc_conv_num =
+        self._enc_conv_num = self._cached_conv_counts["encoder"]
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num

-    def _encode(self, x: torch.Tensor)
+    def _encode(self, x: torch.Tensor):
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
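For orientation, here is a minimal usage sketch of the memory controls this hunk adds to `AutoencoderKLWan`. The checkpoint id, subfolder, and latent shape are assumptions for illustration; only the `enable_tiling`/`enable_slicing` calls come from the diff:

```python
import torch
from diffusers import AutoencoderKLWan

# Hypothetical repository id; any Wan 2.1 checkpoint exported in diffusers format with a
# `vae` subfolder would work the same way.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# New in 0.34: spatial tiling and batch slicing lower peak memory during decode.
vae.enable_tiling(tile_sample_min_height=256, tile_sample_min_width=256)
vae.enable_slicing()

latents = torch.randn(2, 16, 13, 60, 104)  # (batch, z_dim, frames, latent_h, latent_w) -- illustrative
video = vae.decode(latents).sample         # decoded tile by tile, one latent at a time
```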
@@ -764,8 +841,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 out = torch.cat([out, out_], 2)

         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc

@@ -785,18 +860,28 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

-    def _decode(self, z: torch.Tensor, return_dict: bool = True)
-
+    def _decode(self, z: torch.Tensor, return_dict: bool = True):
+        _, _, num_frame, height, width = z.shape
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

-
+        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        self.clear_cache()
         x = self.post_quant_conv(z)
-        for i in range(
+        for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
                 out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
@@ -826,12 +911,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
-
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
         if not return_dict:
             return (decoded,)
-
         return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        _, _, num_frames, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, self.tile_sample_stride_height):
+            row = []
+            for j in range(0, width, self.tile_sample_stride_width):
+                self.clear_cache()
+                time = []
+                frame_range = 1 + (num_frames - 1) // 4
+                for k in range(frame_range):
+                    self._enc_conv_idx = [0]
+                    if k == 0:
+                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                    else:
+                        tile = x[
+                            :,
+                            :,
+                            1 + 4 * (k - 1) : 1 + 4 * k,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                    tile = self.quant_conv(tile)
+                    time.append(tile)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        return enc
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        _, _, num_frames, height, width = z.shape
+        sample_height = height * self.spatial_compression_ratio
+        sample_width = width * self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                self.clear_cache()
+                time = []
+                for k in range(num_frames):
+                    self._conv_idx = [0]
+                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                    tile = self.post_quant_conv(tile)
+                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    time.append(decoded)
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+        self.clear_cache()
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=-1))
+
+        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
+
     def forward(
         self,
         sample: torch.Tensor,
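The `blend_v`/`blend_h` helpers added above cross-fade the overlapping band between neighbouring tiles with linear weights so no seam is visible. A standalone 1-D sketch of the same weighting (illustrative only, not code from the package):

```python
import torch

def blend_1d(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Same idea as AutoencoderKLWan.blend_h: the first `blend_extent` samples of `b` are
    # linearly cross-faded with the last `blend_extent` samples of `a`.
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        b[x] = a[-blend_extent + x] * (1 - x / blend_extent) + b[x] * (x / blend_extent)
    return b

left_tile_edge = torch.ones(8)  # right edge of the previous tile
right_tile = torch.zeros(8)     # current tile, overlapping the previous one by 4 samples
print(blend_1d(left_tile_edge, right_tile, blend_extent=4))
# -> 1.00, 0.75, 0.50, 0.25, then the untouched zeros of the current tile
```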
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -83,8 +83,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://
-            however, no such scaling factor was used, hence the value of 1.0 as the default.
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. For this
+            Autoencoder, however, no such scaling factor was used, hence the value of 1.0 as the default.
         force_upcast (`bool`, *optional*, default to `False`):
             If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
             can be fine-tuned / trained to a lower range without losing too much precision, in which case
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -255,7 +255,7 @@ class Decoder(nn.Module):
                 num_layers=self.layers_per_block + 1,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
-                prev_output_channel=
+                prev_output_channel=prev_output_channel,
                 add_upsample=not is_final_block,
                 resnet_eps=1e-6,
                 resnet_act_fn=act_fn,
@@ -744,6 +744,17 @@ class DiagonalGaussianDistribution(object):
         return self.mean


+class IdentityDistribution(object):
+    def __init__(self, parameters: torch.Tensor):
+        self.parameters = parameters
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        return self.parameters
+
+    def mode(self) -> torch.Tensor:
+        return self.parameters
+
+
 class EncoderTiny(nn.Module):
     r"""
     The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
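The new `IdentityDistribution` keeps the `sample()`/`mode()` interface of `DiagonalGaussianDistribution` but returns the encoder output untouched, which is what a deterministic autoencoder (presumably the Cosmos VAE added in this release) needs. A small comparison sketch; the import path mirrors where this hunk lives and the shapes are illustrative assumptions:

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution

params = torch.randn(1, 32, 8, 8)  # stand-in encoder output (mean and logvar stacked on dim 1)

gaussian = DiagonalGaussianDistribution(params)
identity = IdentityDistribution(params)

print(gaussian.sample().shape)  # torch.Size([1, 16, 8, 8]) -- splits channels, then samples
print(identity.sample().shape)  # torch.Size([1, 32, 8, 8]) -- parameters passed through as-is
```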
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -66,7 +66,7 @@ class VQModel(ModelMixin, ConfigMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
         norm_type (`str`, *optional*, defaults to `"group"`):
             Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
     """
diffusers/models/cache_utils.py CHANGED
diffusers/models/controlnet.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -9,6 +9,7 @@ if is_torch_available():
         HunyuanDiT2DControlNetModel,
         HunyuanDiT2DMultiControlNetModel,
     )
+    from .controlnet_sana import SanaControlNetModel
     from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
     from .controlnet_sparsectrl import (
         SparseControlNetConditioningEmbedding,
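With this re-export in place the new Sana ControlNet is importable from the `controlnets` subpackage. A minimal loading sketch (the placeholder repository id and dtype are assumptions; only the import path follows from the hunk):

```python
import torch
from diffusers.models.controlnets import SanaControlNetModel

# Standard ModelMixin loading pattern; replace the placeholder with a real checkpoint.
controlnet = SanaControlNetModel.from_pretrained(
    "<org>/<sana-controlnet-checkpoint>", torch_dtype=torch.bfloat16
)
```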
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -63,8 +63,8 @@ class ControlNetOutput(BaseOutput):

 class ControlNetConditioningEmbedding(nn.Module):
     """
-    Quoting from https://
-    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+    Quoting from https://huggingface.co/papers/2302.05543: "Stable Diffusion uses a pre-processing method similar to
+    VQ-GAN [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
     training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
     convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
     (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full