diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/models/downsampling.py
CHANGED
@@ -102,7 +102,6 @@ class Downsample2D(nn.Module):
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ class Downsample2D(nn.Module):
             raise ValueError(f"unknown norm_type: {norm_type}")
 
         if use_conv:
-            conv =
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:
@@ -130,7 +129,7 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv
 
-    def forward(self, hidden_states: torch.
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
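In `Downsample2D.forward` (and the other blocks touched in this release) the per-call `scale` argument is now ignored and scheduled for removal in 1.0.0. A minimal sketch of the replacement path named in the deprecation message, assuming an SD 1.5 checkpoint and a placeholder LoRA repository id:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/your-lora")  # placeholder repo id

# Pass the LoRA scale once, at the pipeline call, via `cross_attention_kwargs`
# instead of threading it through individual blocks such as Downsample2D.
image = pipe("an astronaut riding a horse", cross_attention_kwargs={"scale": 0.7}).images[0]
```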
@@ -181,24 +180,24 @@ class FirDownsample2D(nn.Module):
 
     def _downsample_2d(
         self,
-        hidden_states: torch.
-        weight: Optional[torch.
-        kernel: Optional[torch.
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.
+    ) -> torch.Tensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
         arbitrary order.
 
         Args:
-            hidden_states (`torch.
+            hidden_states (`torch.Tensor`):
                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight (`torch.
+            weight (`torch.Tensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel (`torch.
+            kernel (`torch.Tensor`, *optional*):
                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                 corresponds to average pooling.
             factor (`int`, *optional*, default to `2`):
@@ -207,7 +206,7 @@ class FirDownsample2D(nn.Module):
                 Scaling factor for signal magnitude.
 
         Returns:
-            output (`torch.
+            output (`torch.Tensor`):
                 Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                 datatype as `x`.
         """
@@ -245,7 +244,7 @@ class FirDownsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -287,11 +286,11 @@ class KDownsample2D(nn.Module):
 
 
 def downsample_2d(
-    hidden_states: torch.
-    kernel: Optional[torch.
+    hidden_states: torch.Tensor,
+    kernel: Optional[torch.Tensor] = None,
     factor: int = 2,
     gain: float = 1,
-) -> torch.
+) -> torch.Tensor:
     r"""Downsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -299,9 +298,9 @@ def downsample_2d(
     shape is a multiple of the downsampling factor.
 
     Args:
-        hidden_states (`torch.
+        hidden_states (`torch.Tensor`)
             Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel (`torch.
+        kernel (`torch.Tensor`, *optional*):
             FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
             corresponds to average pooling.
         factor (`int`, *optional*, default to `2`):
@@ -310,7 +309,7 @@ def downsample_2d(
            Scaling factor for signal magnitude.
 
     Returns:
-        output (`torch.
+        output (`torch.Tensor`):
            Tensor of the shape `[N, C, H // factor, W // factor]`
    """
 
diffusers/models/embeddings.py
CHANGED
@@ -16,10 +16,11 @@ from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from ..utils import deprecate
-from .activations import get_activation
+from .activations import FP32SiLU, get_activation
 from .attention_processor import Attention
 
 
@@ -135,6 +136,7 @@ class PatchEmbed(nn.Module):
         flatten=True,
         bias=True,
         interpolation_scale=1,
+        pos_embed_type="sincos",
     ):
         super().__init__()
 
@@ -156,10 +158,18 @@ class PatchEmbed(nn.Module):
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
         self.interpolation_scale = interpolation_scale
-
-
-
-
+        if pos_embed_type is None:
+            self.pos_embed = None
+        elif pos_embed_type == "sincos":
+            pos_embed = get_2d_sincos_pos_embed(
+                embed_dim,
+                int(num_patches**0.5),
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+            )
+            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+        else:
+            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
     def forward(self, latent):
         height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
@@ -169,6 +179,8 @@ class PatchEmbed(nn.Module):
         latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
         if self.layer_norm:
             latent = self.norm(latent)
+        if self.pos_embed is None:
+            return latent.to(latent.dtype)
 
         # Interpolate positional embeddings if needed.
         # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
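`PatchEmbed` now accepts a `pos_embed_type` argument; `None` skips the sincos positional embedding entirely and returns the projected patch tokens as-is. A small sketch with assumed sizes (all other constructor arguments keep their defaults):

```python
import torch
from diffusers.models.embeddings import PatchEmbed

# Assumed sizes: a 64x64, 4-channel latent split into 2x2 patches.
patchify = PatchEmbed(height=64, width=64, patch_size=2, in_channels=4, embed_dim=1152, pos_embed_type=None)
latent = torch.randn(1, 4, 64, 64)
tokens = patchify(latent)  # (1, 1024, 1152); no positional embedding is added when pos_embed_type=None
```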
@@ -187,6 +199,113 @@ class PatchEmbed(nn.Module):
         return (latent + pos_embed).to(latent.dtype)
 
 
+def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+    """
+    RoPE for image tokens with 2d structure.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size
+    crops_coords (`Tuple[int]`)
+        The top-left and bottom-right coordinates of the crop.
+    grid_size (`Tuple[int]`):
+        The grid size of the positional embedding.
+    use_real (`bool`):
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embdding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
+    start, stop = crops_coords
+    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    assert embed_dim % 4 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+
+    if use_real:
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+        return cos, sin
+    else:
+        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
+        return emb
+
+
+def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+    data type.
+
+    Args:
+        dim (`int`): Dimension of the frequency tensor.
+        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (`bool`, *optional*):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+    """
+    if isinstance(pos, int):
+        pos = np.arange(pos)
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
+    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        return freqs_cis
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensors contain rotary embeddings and are returned as real tensors.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    cos, sin = freqs_cis  # [S, D]
+    cos = cos[None, None]
+    sin = sin[None, None]
+    cos, sin = cos.to(x.device), sin.to(x.device)
+
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+    return out
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
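These rotary-embedding helpers arrive alongside the new HunyuanDiT support. A short usage sketch with assumed sizes (head dimension 64 over a 32x32 grid of image tokens), using only the functions added above:

```python
import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed, apply_rotary_emb

# Assumed sizes: head_dim must be divisible by 4; crops_coords is (top-left, bottom-right).
head_dim, grid_size = 64, (32, 32)
cos_sin = get_2d_rotary_pos_embed(head_dim, crops_coords=((0, 0), grid_size), grid_size=grid_size, use_real=True)

query = torch.randn(1, 8, 32 * 32, head_dim)  # [batch, heads, seq_len, head_dim]
query = apply_rotary_emb(query, cos_sin)      # same shape, now carrying 2D rotary position information
```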
@@ -199,9 +318,8 @@ class TimestepEmbedding(nn.Module):
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear
 
-        self.linear_1 =
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +332,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 =
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
 
         if post_act_fn is None:
             self.post_act = None
@@ -425,7 +543,7 @@ class TextImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
 
-    def forward(self, text_embeds: torch.
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         batch_size = text_embeds.shape[0]
 
         # image
@@ -451,7 +569,7 @@ class ImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.norm = nn.LayerNorm(cross_attention_dim)
 
-    def forward(self, image_embeds: torch.
+    def forward(self, image_embeds: torch.Tensor):
         batch_size = image_embeds.shape[0]
 
         # image
@@ -469,10 +587,26 @@ class IPAdapterFullImageProjection(nn.Module):
         self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
         self.norm = nn.LayerNorm(cross_attention_dim)
 
-    def forward(self, image_embeds: torch.
+    def forward(self, image_embeds: torch.Tensor):
         return self.norm(self.ff(image_embeds))
 
 
+class IPAdapterFaceIDImageProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.cross_attention_dim = cross_attention_dim
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.Tensor):
+        x = self.ff(image_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return self.norm(x)
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
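The new `IPAdapterFaceIDImageProjection` maps a single face-ID embedding to a short sequence of cross-attention tokens. A shape-level sketch with assumed dimensions:

```python
import torch
from diffusers.models.embeddings import IPAdapterFaceIDImageProjection

# Assumed sizes: 512-d face ID embeddings projected to 4 tokens in a 768-d cross-attention space.
proj = IPAdapterFaceIDImageProjection(image_embed_dim=512, cross_attention_dim=768, mult=1, num_tokens=4)
face_id_embeds = torch.randn(2, 512)
tokens = proj(face_id_embeds)  # (2, 4, 768): one 4-token conditioning sequence per input embedding
```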
@@ -492,6 +626,88 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         return conditioning
 
 
+class HunyuanDiTAttentionPool(nn.Module):
+    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.permute(1, 0, 2)  # NLC -> LNC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1],
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False,
+        )
+        return x.squeeze(0)
+
+
+class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.pooler = HunyuanDiTAttentionPool(
+            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+        )
+        # Here we use a default learned embedder layer for future extension.
+        self.style_embedder = nn.Embedding(1, embedding_dim)
+        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.extra_embedder = PixArtAlphaTextProjection(
+            in_features=extra_in_dim,
+            hidden_size=embedding_dim * 4,
+            out_features=embedding_dim,
+            act_fn="silu_fp32",
+        )
+
+    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, 256)
+
+        # extra condition1: text
+        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
+
+        # extra condition2: image meta size embdding
+        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+        # extra condition3: style embedding
+        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+        # Concatenate all extra vectors
+        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
+
+        return conditioning
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
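`HunyuanCombinedTimestepTextSizeStyleEmbedding` folds the timestep, pooled text states, six image meta sizes and a style index into one conditioning vector; the extra branch concatenates to `256 * 6 + embedding_dim + pooled_projection_dim` features before projection. A smoke-test sketch with assumed, deliberately small sizes (HunyuanDiT itself uses larger dimensions):

```python
import torch
from diffusers.models.embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding

emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
    embedding_dim=128, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048
)
timestep = torch.tensor([10])
text_states = torch.randn(1, 256, 2048)                                        # pooled to (1, 1024)
image_meta_size = torch.tensor([[1024.0, 1024.0, 1024.0, 1024.0, 0.0, 0.0]])   # six entries -> (1, 1536)
style = torch.zeros(1, dtype=torch.long)                                       # index of the single learned style
cond = emb(timestep, text_states, image_meta_size, style, hidden_dtype=torch.float32)  # (1, 128)
```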
@@ -515,7 +731,7 @@ class TextImageTimeEmbedding(nn.Module):
         self.text_norm = nn.LayerNorm(time_embed_dim)
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
 
-    def forward(self, text_embeds: torch.
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         # text
         time_text_embeds = self.text_proj(text_embeds)
         time_text_embeds = self.text_norm(time_text_embeds)
@@ -532,7 +748,7 @@ class ImageTimeEmbedding(nn.Module):
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
         self.image_norm = nn.LayerNorm(time_embed_dim)
 
-    def forward(self, image_embeds: torch.
+    def forward(self, image_embeds: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -562,7 +778,7 @@ class ImageHintTimeEmbedding(nn.Module):
             nn.Conv2d(256, 4, 3, padding=1),
         )
 
-    def forward(self, image_embeds: torch.
+    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -778,11 +994,18 @@ class PixArtAlphaTextProjection(nn.Module):
     Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
     """
 
-    def __init__(self, in_features, hidden_size,
+    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
         super().__init__()
+        if out_features is None:
+            out_features = hidden_size
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
-
-
+        if act_fn == "gelu_tanh":
+            self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu_fp32":
+            self.act_1 = FP32SiLU()
+        else:
+            raise ValueError(f"Unknown activation function: {act_fn}")
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
 
     def forward(self, caption):
         hidden_states = self.linear_1(caption)
@@ -795,17 +1018,15 @@ class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
     Args:
-
-
-
-
-
-
-
-
-
-        num_queries (int): The number of queries. Defaults to 8.
-        ffn_ratio (float): The expansion ratio of feedforward network hidden
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_queries (int):
+            The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
+        of feedforward network hidden
             layer channels. Defaults to 4.
     """
 
@@ -855,11 +1076,8 @@ class IPAdapterPlusImageProjection(nn.Module):
         """Forward pass.
 
         Args:
-        ----
            x (torch.Tensor): Input Tensor.
-
         Returns:
-        -------
            torch.Tensor: Output Tensor.
        """
        latents = self.latents.repeat(x.size(0), 1, 1)
@@ -879,12 +1097,125 @@ class IPAdapterPlusImageProjection(nn.Module):
        return self.norm_out(latents)
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+    """FacePerceiverResampler of IP-Adapter Plus.
+
+    Args:
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        ffproj_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels (for ID embeddings). Defaults to 4.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        output_dims: int = 768,
+        hidden_dims: int = 1280,
+        id_embeddings_dim: int = 512,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_tokens: int = 4,
+        num_queries: int = 8,
+        ffn_ratio: float = 4,
+        ffproj_ratio: int = 2,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.embed_dim = embed_dims
+        self.clip_embeds = None
+        self.shortcut = False
+        self.shortcut_scale = 1.0
+
+        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+        self.norm = nn.LayerNorm(embed_dims)
+
+        self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+        self.proj_out = nn.Linear(embed_dims, output_dims)
+        self.norm_out = nn.LayerNorm(output_dims)
+
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+
+    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            id_embeds (torch.Tensor): Input Tensor (ID embeds).
+        Returns:
+            torch.Tensor: Output Tensor.
+        """
+        id_embeds = id_embeds.to(self.clip_embeds.dtype)
+        id_embeds = self.proj(id_embeds)
+        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+        id_embeds = self.norm(id_embeds)
+        latents = id_embeds
+
+        clip_embeds = self.proj_in(self.clip_embeds)
+        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+        for block in self.layers:
+            residual = latents
+            latents = block(x, latents, residual)
+
+        latents = self.proj_out(latents)
+        out = self.norm_out(latents)
+        if self.shortcut:
+            out = id_embeds + self.shortcut_scale * out
+        return out
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
 
-    def forward(self, image_embeds: List[torch.
+    def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []
 
         # currently, we accept `image_embeds` as