diffusers-0.23.1-py3-none-any.whl → diffusers-0.25.0-py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -110,7 +110,10 @@ def jax_memory_efficient_attention(
|
|
110
110
|
)
|
111
111
|
|
112
112
|
_, res = jax.lax.scan(
|
113
|
-
f=chunk_scanner,
|
113
|
+
f=chunk_scanner,
|
114
|
+
init=0,
|
115
|
+
xs=None,
|
116
|
+
length=math.ceil(num_q / query_chunk_size), # start counter # stop counter
|
114
117
|
)
|
115
118
|
|
116
119
|
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
|
@@ -138,6 +141,7 @@ class FlaxAttention(nn.Module):
|
|
138
141
|
Parameters `dtype`
|
139
142
|
|
140
143
|
"""
|
144
|
+
|
141
145
|
query_dim: int
|
142
146
|
heads: int = 8
|
143
147
|
dim_head: int = 64
|
@@ -262,6 +266,7 @@ class FlaxBasicTransformerBlock(nn.Module):
|
|
262
266
|
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
263
267
|
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
264
268
|
"""
|
269
|
+
|
265
270
|
dim: int
|
266
271
|
n_heads: int
|
267
272
|
d_head: int
|
@@ -347,6 +352,7 @@ class FlaxTransformer2DModel(nn.Module):
|
|
347
352
|
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
348
353
|
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
349
354
|
"""
|
355
|
+
|
350
356
|
in_channels: int
|
351
357
|
n_heads: int
|
352
358
|
d_head: int
|
@@ -442,6 +448,7 @@ class FlaxFeedForward(nn.Module):
|
|
442
448
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
443
449
|
Parameters `dtype`
|
444
450
|
"""
|
451
|
+
|
445
452
|
dim: int
|
446
453
|
dropout: float = 0.0
|
447
454
|
dtype: jnp.dtype = jnp.float32
|
@@ -471,6 +478,7 @@ class FlaxGEGLU(nn.Module):
|
|
471
478
|
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
472
479
|
Parameters `dtype`
|
473
480
|
"""
|
481
|
+
|
474
482
|
dim: int
|
475
483
|
dropout: float = 0.0
|
476
484
|
dtype: jnp.dtype = jnp.float32
|
@@ -109,15 +109,19 @@ class Attention(nn.Module):
|
|
109
109
|
residual_connection: bool = False,
|
110
110
|
_from_deprecated_attn_block: bool = False,
|
111
111
|
processor: Optional["AttnProcessor"] = None,
|
112
|
+
out_dim: int = None,
|
112
113
|
):
|
113
114
|
super().__init__()
|
114
|
-
self.inner_dim = dim_head * heads
|
115
|
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
116
|
+
self.query_dim = query_dim
|
115
117
|
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
116
118
|
self.upcast_attention = upcast_attention
|
117
119
|
self.upcast_softmax = upcast_softmax
|
118
120
|
self.rescale_output_factor = rescale_output_factor
|
119
121
|
self.residual_connection = residual_connection
|
120
122
|
self.dropout = dropout
|
123
|
+
self.fused_projections = False
|
124
|
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
121
125
|
|
122
126
|
# we make use of this private variable to know whether this class is loaded
|
123
127
|
# with a deprecated state dict so that we can convert it on the fly
|
@@ -126,7 +130,7 @@ class Attention(nn.Module):
|
|
126
130
|
self.scale_qk = scale_qk
|
127
131
|
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
128
132
|
|
129
|
-
self.heads = heads
|
133
|
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
130
134
|
# for slice_size > 0 the attention score computation
|
131
135
|
# is split across the batch axis to save memory
|
132
136
|
# You can set slice_size with `set_attention_slice`
|
@@ -178,6 +182,7 @@ class Attention(nn.Module):
|
|
178
182
|
else:
|
179
183
|
linear_cls = LoRACompatibleLinear
|
180
184
|
|
185
|
+
self.linear_cls = linear_cls
|
181
186
|
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
182
187
|
|
183
188
|
if not self.only_cross_attention:
|
@@ -193,7 +198,7 @@ class Attention(nn.Module):
|
|
193
198
|
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
194
199
|
|
195
200
|
self.to_out = nn.ModuleList([])
|
196
|
-
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
|
201
|
+
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
|
197
202
|
self.to_out.append(nn.Dropout(dropout))
|
198
203
|
|
199
204
|
# set attention processor
|
@@ -690,6 +695,32 @@ class Attention(nn.Module):
|
|
690
695
|
|
691
696
|
return encoder_hidden_states
|
692
697
|
|
698
|
+
@torch.no_grad()
|
699
|
+
def fuse_projections(self, fuse=True):
|
700
|
+
is_cross_attention = self.cross_attention_dim != self.query_dim
|
701
|
+
device = self.to_q.weight.data.device
|
702
|
+
dtype = self.to_q.weight.data.dtype
|
703
|
+
|
704
|
+
if not is_cross_attention:
|
705
|
+
# fetch weight matrices.
|
706
|
+
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
707
|
+
in_features = concatenated_weights.shape[1]
|
708
|
+
out_features = concatenated_weights.shape[0]
|
709
|
+
|
710
|
+
# create a new single projection layer and copy over the weights.
|
711
|
+
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
|
712
|
+
self.to_qkv.weight.copy_(concatenated_weights)
|
713
|
+
|
714
|
+
else:
|
715
|
+
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
716
|
+
in_features = concatenated_weights.shape[1]
|
717
|
+
out_features = concatenated_weights.shape[0]
|
718
|
+
|
719
|
+
self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
|
720
|
+
self.to_kv.weight.copy_(concatenated_weights)
|
721
|
+
|
722
|
+
self.fused_projections = fuse
|
723
|
+
|
693
724
|
|
694
725
|
class AttnProcessor:
|
695
726
|
r"""
|
@@ -1182,9 +1213,6 @@ class AttnProcessor2_0:
|
|
1182
1213
|
scale: float = 1.0,
|
1183
1214
|
) -> torch.FloatTensor:
|
1184
1215
|
residual = hidden_states
|
1185
|
-
|
1186
|
-
args = () if USE_PEFT_BACKEND else (scale,)
|
1187
|
-
|
1188
1216
|
if attn.spatial_norm is not None:
|
1189
1217
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1190
1218
|
|
@@ -1251,6 +1279,103 @@ class AttnProcessor2_0:
|
|
1251
1279
|
return hidden_states
|
1252
1280
|
|
1253
1281
|
|
1282
|
+
class FusedAttnProcessor2_0:
|
1283
|
+
r"""
|
1284
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1285
|
+
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
|
1286
|
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
1287
|
+
|
1288
|
+
<Tip warning={true}>
|
1289
|
+
|
1290
|
+
This API is currently 🧪 experimental in nature and can change in future.
|
1291
|
+
|
1292
|
+
</Tip>
|
1293
|
+
"""
|
1294
|
+
|
1295
|
+
def __init__(self):
|
1296
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1297
|
+
raise ImportError(
|
1298
|
+
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
|
1299
|
+
)
|
1300
|
+
|
1301
|
+
def __call__(
|
1302
|
+
self,
|
1303
|
+
attn: Attention,
|
1304
|
+
hidden_states: torch.FloatTensor,
|
1305
|
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1306
|
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1307
|
+
temb: Optional[torch.FloatTensor] = None,
|
1308
|
+
scale: float = 1.0,
|
1309
|
+
) -> torch.FloatTensor:
|
1310
|
+
residual = hidden_states
|
1311
|
+
if attn.spatial_norm is not None:
|
1312
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1313
|
+
|
1314
|
+
input_ndim = hidden_states.ndim
|
1315
|
+
|
1316
|
+
if input_ndim == 4:
|
1317
|
+
batch_size, channel, height, width = hidden_states.shape
|
1318
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1319
|
+
|
1320
|
+
batch_size, sequence_length, _ = (
|
1321
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1322
|
+
)
|
1323
|
+
|
1324
|
+
if attention_mask is not None:
|
1325
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1326
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1327
|
+
# (batch, heads, source_length, target_length)
|
1328
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1329
|
+
|
1330
|
+
if attn.group_norm is not None:
|
1331
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1332
|
+
|
1333
|
+
args = () if USE_PEFT_BACKEND else (scale,)
|
1334
|
+
if encoder_hidden_states is None:
|
1335
|
+
qkv = attn.to_qkv(hidden_states, *args)
|
1336
|
+
split_size = qkv.shape[-1] // 3
|
1337
|
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
1338
|
+
else:
|
1339
|
+
if attn.norm_cross:
|
1340
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1341
|
+
query = attn.to_q(hidden_states, *args)
|
1342
|
+
|
1343
|
+
kv = attn.to_kv(encoder_hidden_states, *args)
|
1344
|
+
split_size = kv.shape[-1] // 2
|
1345
|
+
key, value = torch.split(kv, split_size, dim=-1)
|
1346
|
+
|
1347
|
+
inner_dim = key.shape[-1]
|
1348
|
+
head_dim = inner_dim // attn.heads
|
1349
|
+
|
1350
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1351
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1352
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1353
|
+
|
1354
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1355
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1356
|
+
hidden_states = F.scaled_dot_product_attention(
|
1357
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1358
|
+
)
|
1359
|
+
|
1360
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1361
|
+
hidden_states = hidden_states.to(query.dtype)
|
1362
|
+
|
1363
|
+
# linear proj
|
1364
|
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
1365
|
+
# dropout
|
1366
|
+
hidden_states = attn.to_out[1](hidden_states)
|
1367
|
+
|
1368
|
+
if input_ndim == 4:
|
1369
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1370
|
+
|
1371
|
+
if attn.residual_connection:
|
1372
|
+
hidden_states = hidden_states + residual
|
1373
|
+
|
1374
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1375
|
+
|
1376
|
+
return hidden_states
|
1377
|
+
|
1378
|
+
|
1254
1379
|
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
1255
1380
|
r"""
|
1256
1381
|
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
@@ -1975,6 +2100,250 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
|
1975
2100
|
return attn.processor(attn, hidden_states, *args, **kwargs)
|
1976
2101
|
|
1977
2102
|
|
2103
|
+
class IPAdapterAttnProcessor(nn.Module):
|
2104
|
+
r"""
|
2105
|
+
Attention processor for IP-Adapter.
|
2106
|
+
|
2107
|
+
Args:
|
2108
|
+
hidden_size (`int`):
|
2109
|
+
The hidden size of the attention layer.
|
2110
|
+
cross_attention_dim (`int`):
|
2111
|
+
The number of channels in the `encoder_hidden_states`.
|
2112
|
+
num_tokens (`int`, defaults to 4):
|
2113
|
+
The context length of the image features.
|
2114
|
+
scale (`float`, defaults to 1.0):
|
2115
|
+
the weight scale of image prompt.
|
2116
|
+
"""
|
2117
|
+
|
2118
|
+
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
|
2119
|
+
super().__init__()
|
2120
|
+
|
2121
|
+
self.hidden_size = hidden_size
|
2122
|
+
self.cross_attention_dim = cross_attention_dim
|
2123
|
+
self.num_tokens = num_tokens
|
2124
|
+
self.scale = scale
|
2125
|
+
|
2126
|
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
2127
|
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
2128
|
+
|
2129
|
+
def __call__(
|
2130
|
+
self,
|
2131
|
+
attn,
|
2132
|
+
hidden_states,
|
2133
|
+
encoder_hidden_states=None,
|
2134
|
+
attention_mask=None,
|
2135
|
+
temb=None,
|
2136
|
+
scale=1.0,
|
2137
|
+
):
|
2138
|
+
if scale != 1.0:
|
2139
|
+
logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
|
2140
|
+
residual = hidden_states
|
2141
|
+
|
2142
|
+
if attn.spatial_norm is not None:
|
2143
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2144
|
+
|
2145
|
+
input_ndim = hidden_states.ndim
|
2146
|
+
|
2147
|
+
if input_ndim == 4:
|
2148
|
+
batch_size, channel, height, width = hidden_states.shape
|
2149
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2150
|
+
|
2151
|
+
batch_size, sequence_length, _ = (
|
2152
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2153
|
+
)
|
2154
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
2155
|
+
|
2156
|
+
if attn.group_norm is not None:
|
2157
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
2158
|
+
|
2159
|
+
query = attn.to_q(hidden_states)
|
2160
|
+
|
2161
|
+
if encoder_hidden_states is None:
|
2162
|
+
encoder_hidden_states = hidden_states
|
2163
|
+
elif attn.norm_cross:
|
2164
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2165
|
+
|
2166
|
+
# split hidden states
|
2167
|
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
2168
|
+
encoder_hidden_states, ip_hidden_states = (
|
2169
|
+
encoder_hidden_states[:, :end_pos, :],
|
2170
|
+
encoder_hidden_states[:, end_pos:, :],
|
2171
|
+
)
|
2172
|
+
|
2173
|
+
key = attn.to_k(encoder_hidden_states)
|
2174
|
+
value = attn.to_v(encoder_hidden_states)
|
2175
|
+
|
2176
|
+
query = attn.head_to_batch_dim(query)
|
2177
|
+
key = attn.head_to_batch_dim(key)
|
2178
|
+
value = attn.head_to_batch_dim(value)
|
2179
|
+
|
2180
|
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
2181
|
+
hidden_states = torch.bmm(attention_probs, value)
|
2182
|
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
2183
|
+
|
2184
|
+
# for ip-adapter
|
2185
|
+
ip_key = self.to_k_ip(ip_hidden_states)
|
2186
|
+
ip_value = self.to_v_ip(ip_hidden_states)
|
2187
|
+
|
2188
|
+
ip_key = attn.head_to_batch_dim(ip_key)
|
2189
|
+
ip_value = attn.head_to_batch_dim(ip_value)
|
2190
|
+
|
2191
|
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
2192
|
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
2193
|
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
2194
|
+
|
2195
|
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
2196
|
+
|
2197
|
+
# linear proj
|
2198
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2199
|
+
# dropout
|
2200
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2201
|
+
|
2202
|
+
if input_ndim == 4:
|
2203
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2204
|
+
|
2205
|
+
if attn.residual_connection:
|
2206
|
+
hidden_states = hidden_states + residual
|
2207
|
+
|
2208
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2209
|
+
|
2210
|
+
return hidden_states
|
2211
|
+
|
2212
|
+
|
2213
|
+
class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
2214
|
+
r"""
|
2215
|
+
Attention processor for IP-Adapater for PyTorch 2.0.
|
2216
|
+
|
2217
|
+
Args:
|
2218
|
+
hidden_size (`int`):
|
2219
|
+
The hidden size of the attention layer.
|
2220
|
+
cross_attention_dim (`int`):
|
2221
|
+
The number of channels in the `encoder_hidden_states`.
|
2222
|
+
num_tokens (`int`, defaults to 4):
|
2223
|
+
The context length of the image features.
|
2224
|
+
scale (`float`, defaults to 1.0):
|
2225
|
+
the weight scale of image prompt.
|
2226
|
+
"""
|
2227
|
+
|
2228
|
+
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
|
2229
|
+
super().__init__()
|
2230
|
+
|
2231
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
2232
|
+
raise ImportError(
|
2233
|
+
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
2234
|
+
)
|
2235
|
+
|
2236
|
+
self.hidden_size = hidden_size
|
2237
|
+
self.cross_attention_dim = cross_attention_dim
|
2238
|
+
self.num_tokens = num_tokens
|
2239
|
+
self.scale = scale
|
2240
|
+
|
2241
|
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
2242
|
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
2243
|
+
|
2244
|
+
def __call__(
|
2245
|
+
self,
|
2246
|
+
attn,
|
2247
|
+
hidden_states,
|
2248
|
+
encoder_hidden_states=None,
|
2249
|
+
attention_mask=None,
|
2250
|
+
temb=None,
|
2251
|
+
scale=1.0,
|
2252
|
+
):
|
2253
|
+
if scale != 1.0:
|
2254
|
+
logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
|
2255
|
+
residual = hidden_states
|
2256
|
+
|
2257
|
+
if attn.spatial_norm is not None:
|
2258
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2259
|
+
|
2260
|
+
input_ndim = hidden_states.ndim
|
2261
|
+
|
2262
|
+
if input_ndim == 4:
|
2263
|
+
batch_size, channel, height, width = hidden_states.shape
|
2264
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2265
|
+
|
2266
|
+
batch_size, sequence_length, _ = (
|
2267
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2268
|
+
)
|
2269
|
+
|
2270
|
+
if attention_mask is not None:
|
2271
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
2272
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
2273
|
+
# (batch, heads, source_length, target_length)
|
2274
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
2275
|
+
|
2276
|
+
if attn.group_norm is not None:
|
2277
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
2278
|
+
|
2279
|
+
query = attn.to_q(hidden_states)
|
2280
|
+
|
2281
|
+
if encoder_hidden_states is None:
|
2282
|
+
encoder_hidden_states = hidden_states
|
2283
|
+
elif attn.norm_cross:
|
2284
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2285
|
+
|
2286
|
+
# split hidden states
|
2287
|
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
2288
|
+
encoder_hidden_states, ip_hidden_states = (
|
2289
|
+
encoder_hidden_states[:, :end_pos, :],
|
2290
|
+
encoder_hidden_states[:, end_pos:, :],
|
2291
|
+
)
|
2292
|
+
|
2293
|
+
key = attn.to_k(encoder_hidden_states)
|
2294
|
+
value = attn.to_v(encoder_hidden_states)
|
2295
|
+
|
2296
|
+
inner_dim = key.shape[-1]
|
2297
|
+
head_dim = inner_dim // attn.heads
|
2298
|
+
|
2299
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2300
|
+
|
2301
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2302
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2303
|
+
|
2304
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2305
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2306
|
+
hidden_states = F.scaled_dot_product_attention(
|
2307
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2308
|
+
)
|
2309
|
+
|
2310
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2311
|
+
hidden_states = hidden_states.to(query.dtype)
|
2312
|
+
|
2313
|
+
# for ip-adapter
|
2314
|
+
ip_key = self.to_k_ip(ip_hidden_states)
|
2315
|
+
ip_value = self.to_v_ip(ip_hidden_states)
|
2316
|
+
|
2317
|
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2318
|
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2319
|
+
|
2320
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2321
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2322
|
+
ip_hidden_states = F.scaled_dot_product_attention(
|
2323
|
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
2324
|
+
)
|
2325
|
+
|
2326
|
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2327
|
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
2328
|
+
|
2329
|
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
2330
|
+
|
2331
|
+
# linear proj
|
2332
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2333
|
+
# dropout
|
2334
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2335
|
+
|
2336
|
+
if input_ndim == 4:
|
2337
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2338
|
+
|
2339
|
+
if attn.residual_connection:
|
2340
|
+
hidden_states = hidden_states + residual
|
2341
|
+
|
2342
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2343
|
+
|
2344
|
+
return hidden_states
|
2345
|
+
|
2346
|
+
|
1978
2347
|
LORA_ATTENTION_PROCESSORS = (
|
1979
2348
|
LoRAAttnProcessor,
|
1980
2349
|
LoRAAttnProcessor2_0,
|
@@ -1998,11 +2367,14 @@ CROSS_ATTENTION_PROCESSORS = (
|
|
1998
2367
|
LoRAAttnProcessor,
|
1999
2368
|
LoRAAttnProcessor2_0,
|
2000
2369
|
LoRAXFormersAttnProcessor,
|
2370
|
+
IPAdapterAttnProcessor,
|
2371
|
+
IPAdapterAttnProcessor2_0,
|
2001
2372
|
)
|
2002
2373
|
|
2003
2374
|
AttentionProcessor = Union[
|
2004
2375
|
AttnProcessor,
|
2005
2376
|
AttnProcessor2_0,
|
2377
|
+
FusedAttnProcessor2_0,
|
2006
2378
|
XFormersAttnProcessor,
|
2007
2379
|
SlicedAttnProcessor,
|
2008
2380
|
AttnAddedKVProcessor,
|
@@ -0,0 +1,5 @@
|
|
1
|
+
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
2
|
+
from .autoencoder_kl import AutoencoderKL
|
3
|
+
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
4
|
+
from .autoencoder_tiny import AutoencoderTiny
|
5
|
+
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
@@ -16,10 +16,10 @@ from typing import Optional, Tuple, Union
|
|
16
16
|
import torch
|
17
17
|
import torch.nn as nn
|
18
18
|
|
19
|
-
from
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
19
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
20
|
+
from ...utils.accelerate_utils import apply_forward_hook
|
21
|
+
from ..modeling_outputs import AutoencoderKLOutput
|
22
|
+
from ..modeling_utils import ModelMixin
|
23
23
|
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
|
24
24
|
|
25
25
|
|
@@ -65,11 +65,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
|
65
65
|
self,
|
66
66
|
in_channels: int = 3,
|
67
67
|
out_channels: int = 3,
|
68
|
-
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
69
|
-
down_block_out_channels: Tuple[int] = (64,),
|
68
|
+
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
69
|
+
down_block_out_channels: Tuple[int, ...] = (64,),
|
70
70
|
layers_per_down_block: int = 1,
|
71
|
-
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
72
|
-
up_block_out_channels: Tuple[int] = (64,),
|
71
|
+
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
72
|
+
up_block_out_channels: Tuple[int, ...] = (64,),
|
73
73
|
layers_per_up_block: int = 1,
|
74
74
|
act_fn: str = "silu",
|
75
75
|
latent_channels: int = 4,
|
@@ -108,8 +108,13 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
|
108
108
|
self.use_slicing = False
|
109
109
|
self.use_tiling = False
|
110
110
|
|
111
|
+
self.register_to_config(block_out_channels=up_block_out_channels)
|
112
|
+
self.register_to_config(force_upcast=False)
|
113
|
+
|
111
114
|
@apply_forward_hook
|
112
|
-
def encode(
|
115
|
+
def encode(
|
116
|
+
self, x: torch.FloatTensor, return_dict: bool = True
|
117
|
+
) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
|
113
118
|
h = self.encoder(x)
|
114
119
|
moments = self.quant_conv(h)
|
115
120
|
posterior = DiagonalGaussianDistribution(moments)
|
@@ -125,7 +130,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
|
125
130
|
image: Optional[torch.FloatTensor] = None,
|
126
131
|
mask: Optional[torch.FloatTensor] = None,
|
127
132
|
return_dict: bool = True,
|
128
|
-
) -> Union[DecoderOutput, torch.FloatTensor]:
|
133
|
+
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
129
134
|
z = self.post_quant_conv(z)
|
130
135
|
dec = self.decoder(z, image, mask)
|
131
136
|
|
@@ -142,7 +147,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
|
142
147
|
image: Optional[torch.FloatTensor] = None,
|
143
148
|
mask: Optional[torch.FloatTensor] = None,
|
144
149
|
return_dict: bool = True,
|
145
|
-
) -> Union[DecoderOutput, torch.FloatTensor]:
|
150
|
+
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
146
151
|
decoded = self._decode(z, image, mask).sample
|
147
152
|
|
148
153
|
if not return_dict:
|
@@ -157,7 +162,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
|
|
157
162
|
sample_posterior: bool = False,
|
158
163
|
return_dict: bool = True,
|
159
164
|
generator: Optional[torch.Generator] = None,
|
160
|
-
) -> Union[DecoderOutput, torch.FloatTensor]:
|
165
|
+
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
161
166
|
r"""
|
162
167
|
Args:
|
163
168
|
sample (`torch.FloatTensor`): Input sample.
|
@@ -11,41 +11,27 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
from dataclasses import dataclass
|
15
14
|
from typing import Dict, Optional, Tuple, Union
|
16
15
|
|
17
16
|
import torch
|
18
17
|
import torch.nn as nn
|
19
18
|
|
20
|
-
from
|
21
|
-
from
|
22
|
-
from
|
23
|
-
from ..
|
24
|
-
from .attention_processor import (
|
19
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
20
|
+
from ...loaders import FromOriginalVAEMixin
|
21
|
+
from ...utils.accelerate_utils import apply_forward_hook
|
22
|
+
from ..attention_processor import (
|
25
23
|
ADDED_KV_ATTENTION_PROCESSORS,
|
26
24
|
CROSS_ATTENTION_PROCESSORS,
|
25
|
+
Attention,
|
27
26
|
AttentionProcessor,
|
28
27
|
AttnAddedKVProcessor,
|
29
28
|
AttnProcessor,
|
30
29
|
)
|
31
|
-
from
|
30
|
+
from ..modeling_outputs import AutoencoderKLOutput
|
31
|
+
from ..modeling_utils import ModelMixin
|
32
32
|
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
33
33
|
|
34
34
|
|
35
|
-
@dataclass
|
36
|
-
class AutoencoderKLOutput(BaseOutput):
|
37
|
-
"""
|
38
|
-
Output of AutoencoderKL encoding method.
|
39
|
-
|
40
|
-
Args:
|
41
|
-
latent_dist (`DiagonalGaussianDistribution`):
|
42
|
-
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
|
43
|
-
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
44
|
-
"""
|
45
|
-
|
46
|
-
latent_dist: "DiagonalGaussianDistribution"
|
47
|
-
|
48
|
-
|
49
35
|
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
50
36
|
r"""
|
51
37
|
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
@@ -322,13 +308,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
|
322
308
|
|
323
309
|
return DecoderOutput(sample=decoded)
|
324
310
|
|
325
|
-
def blend_v(self, a, b, blend_extent):
|
311
|
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
326
312
|
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
327
313
|
for y in range(blend_extent):
|
328
314
|
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
329
315
|
return b
|
330
316
|
|
331
|
-
def blend_h(self, a, b, blend_extent):
|
317
|
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
332
318
|
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
333
319
|
for x in range(blend_extent):
|
334
320
|
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
@@ -463,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
|
463
449
|
return (dec,)
|
464
450
|
|
465
451
|
return DecoderOutput(sample=dec)
|
452
|
+
|
453
|
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
454
|
+
def fuse_qkv_projections(self):
|
455
|
+
"""
|
456
|
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
457
|
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
458
|
+
|
459
|
+
<Tip warning={true}>
|
460
|
+
|
461
|
+
This API is 🧪 experimental.
|
462
|
+
|
463
|
+
</Tip>
|
464
|
+
"""
|
465
|
+
self.original_attn_processors = None
|
466
|
+
|
467
|
+
for _, attn_processor in self.attn_processors.items():
|
468
|
+
if "Added" in str(attn_processor.__class__.__name__):
|
469
|
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
470
|
+
|
471
|
+
self.original_attn_processors = self.attn_processors
|
472
|
+
|
473
|
+
for module in self.modules():
|
474
|
+
if isinstance(module, Attention):
|
475
|
+
module.fuse_projections(fuse=True)
|
476
|
+
|
477
|
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
478
|
+
def unfuse_qkv_projections(self):
|
479
|
+
"""Disables the fused QKV projection if enabled.
|
480
|
+
|
481
|
+
<Tip warning={true}>
|
482
|
+
|
483
|
+
This API is 🧪 experimental.
|
484
|
+
|
485
|
+
</Tip>
|
486
|
+
|
487
|
+
"""
|
488
|
+
if self.original_attn_processors is not None:
|
489
|
+
self.set_attn_processor(self.original_attn_processors)
|