diffusers-0.23.1-py3-none-any.whl → diffusers-0.25.0-py3-none-any.whl
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
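The listing above summarizes the 0.23.1 → 0.25.0 changes: new pipelines (aMUSEd, Kandinsky 3, Stable Video Diffusion, text-to-video zero SDXL), the monolithic `diffusers/loaders.py` split into a `diffusers/loaders/` package (IP-Adapter, LoRA, single-file, textual inversion, and UNet loaders), several older pipelines and schedulers relocated under `deprecated/` subpackages, and the VAEs regrouped under `diffusers/models/autoencoders/`. A quick post-upgrade sanity check might look like the sketch below; it assumes the new pipelines and loader mixins are re-exported at the package top level, as the `diffusers/__init__.py` and `diffusers/loaders/__init__.py` changes suggest. The hunks that follow reproduce the more substantive source changes from the diff.

```python
# pip install --upgrade "diffusers==0.25.0"
import diffusers

print(diffusers.__version__)  # expected: "0.25.0"

# Pipelines added in this release (assumed to be top-level exports).
from diffusers import AmusedPipeline, Kandinsky3Pipeline, StableVideoDiffusionPipeline  # noqa: F401

# Loader mixins now live in the diffusers.loaders package rather than diffusers/loaders.py.
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin  # noqa: F401
```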
diffusers/models/transformer_2d.py

@@ -20,9 +20,9 @@ from torch import nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
-from .embeddings import
+from .embeddings import PatchEmbed, PixArtAlphaTextProjection
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -233,10 +235,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         self.caption_projection = None
         if caption_channels is not None:
-            self.caption_projection =
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

         self.gradient_checkpointing = False

+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
+                    create_custom_forward(block),
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-
+                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
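With `_supports_gradient_checkpointing = True` and the reentrant-free checkpoint wrapper above, `Transformer2DModel` can now take part in activation checkpointing during training. A minimal sketch, using only the public `ModelMixin.enable_gradient_checkpointing()` helper; the sizes are illustrative, not a real model configuration:

```python
import torch
from diffusers.models.transformer_2d import Transformer2DModel

# Tiny, illustrative configuration; production models are far larger.
model = Transformer2DModel(num_attention_heads=2, attention_head_dim=16, in_channels=32, num_layers=1)
model.enable_gradient_checkpointing()  # flips gradient_checkpointing on via _set_gradient_checkpointing
model.train()

sample = torch.randn(1, 32, 16, 16)  # (batch, channels, height, width), continuous input
out = model(sample).sample
out.mean().backward()  # block activations are recomputed during backward instead of being stored
```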
diffusers/models/transformer_temporal.py

@@ -19,8 +19,10 @@ from torch import nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .attention import BasicTransformerBlock
+from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
+from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
+from .resnet import AlphaBlender


 @dataclass
@@ -195,3 +197,183 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
             return (output,)

         return TransformerTemporalModelOutput(sample=output)
+
+
+class TransformerSpatioTemporalModel(nn.Module):
+    """
+    A Transformer model for video-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        out_channels (`int`, *optional*):
+            The number of channels in the output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+    """
+
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: int = 320,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        cross_attention_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+
+        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = inner_dim
+
+        # 2. Define input layers
+        self.in_channels = in_channels
+        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        # 3. Define transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                for d in range(num_layers)
+            ]
+        )
+
+        time_mix_inner_dim = inner_dim
+        self.temporal_transformer_blocks = nn.ModuleList(
+            [
+                TemporalBasicTransformerBlock(
+                    inner_dim,
+                    time_mix_inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        time_embed_dim = in_channels * 4
+        self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
+        self.time_proj = Timesteps(in_channels, True, 0)
+        self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+
+        # 4. Define output layers
+        self.out_channels = in_channels if out_channels is None else out_channels
+        # TODO: should use out_channels for continuous projections
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input hidden_states.
+            num_frames (`int`):
+                The number of frames to be processed per batch. This is used to reshape the hidden states.
+            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
+                A tensor indicating whether the input contains only images. 1 indicates that the input contains only
+                images, 0 indicates that the input contains video frames.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                returned, otherwise a `tuple` where the first element is the sample tensor.
+        """
+        # 1. Input
+        batch_frames, _, height, width = hidden_states.shape
+        num_frames = image_only_indicator.shape[-1]
+        batch_size = batch_frames // num_frames
+
+        time_context = encoder_hidden_states
+        time_context_first_timestep = time_context[None, :].reshape(
+            batch_size, num_frames, -1, time_context.shape[-1]
+        )[:, 0]
+        time_context = time_context_first_timestep[None, :].broadcast_to(
+            height * width, batch_size, 1, time_context.shape[-1]
+        )
+        time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
+
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
+        hidden_states = self.proj_in(hidden_states)
+
+        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
+        num_frames_emb = num_frames_emb.repeat(batch_size, 1)
+        num_frames_emb = num_frames_emb.reshape(-1)
+        t_emb = self.time_proj(num_frames_emb)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=hidden_states.dtype)
+
+        emb = self.time_pos_embed(t_emb)
+        emb = emb[:, None, :]
+
+        # 2. Blocks
+        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    block,
+                    hidden_states,
+                    None,
+                    encoder_hidden_states,
+                    None,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+
+            hidden_states_mix = hidden_states
+            hidden_states_mix = hidden_states_mix + emb
+
+            hidden_states_mix = temporal_block(
+                hidden_states_mix,
+                num_frames=num_frames,
+                encoder_hidden_states=time_context,
+            )
+            hidden_states = self.time_mixer(
+                x_spatial=hidden_states,
+                x_temporal=hidden_states_mix,
+                image_only_indicator=image_only_indicator,
+            )

+        # 3. Output
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+        output = hidden_states + residual
+
+        if not return_dict:
+            return (output,)
+
+        return TransformerTemporalModelOutput(sample=output)
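The new `TransformerSpatioTemporalModel` interleaves the existing spatial `BasicTransformerBlock`s with `TemporalBasicTransformerBlock`s and blends the two paths with `AlphaBlender`; it is the transformer used inside the Stable Video Diffusion UNet (`unet_spatio_temporal_condition.py` in the listing above). A smoke-test sketch with toy sizes; the real SVD dimensions are larger, and the import path is assumed from the file layout shown in this diff:

```python
import torch
from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel

batch_size, num_frames, channels, height, width = 1, 2, 320, 8, 8
model = TransformerSpatioTemporalModel(
    num_attention_heads=8, attention_head_dim=40, in_channels=channels, cross_attention_dim=1024
)

# Frames are flattened into the batch dimension, as in the SVD UNet.
hidden_states = torch.randn(batch_size * num_frames, channels, height, width)
encoder_hidden_states = torch.randn(batch_size * num_frames, 1, 1024)
image_only_indicator = torch.zeros(batch_size, num_frames)  # 0 = video frame, 1 = image-only

sample = model(hidden_states, encoder_hidden_states, image_only_indicator=image_only_indicator).sample
print(sample.shape)  # torch.Size([2, 320, 8, 8])
```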
diffusers/models/unet_2d_blocks_flax.py

@@ -45,6 +45,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -125,6 +126,7 @@ class FlaxDownBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -190,6 +192,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -275,6 +278,7 @@ class FlaxUpBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -339,6 +343,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
diffusers/models/unet_2d_condition.py

@@ -25,6 +25,7 @@ from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)

+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
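`fuse_qkv_projections()` and `unfuse_qkv_projections()` concatenate the query/key/value weights of every `Attention` module into a single projection, which can help inference throughput (for example together with `torch.compile`). A hedged usage sketch calling the new methods directly on a pipeline's UNet; the checkpoint name and prompt are only illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.unet.fuse_qkv_projections()    # experimental: fuses q/k/v (self-attn) and k/v (cross-attn)
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
pipe.unet.unfuse_qkv_projections()  # restores the original attention processors
```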
diffusers/models/unet_2d_condition.py (continued)

@@ -1022,6 +1059,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
         # 2. pre-process
         sample = self.conv_in(sample)

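The `ip_image_proj` branch is what routes IP-Adapter image embeddings into the UNet: the projected image tokens are concatenated onto `encoder_hidden_states`, so cross-attention sees both text and image conditioning. On the pipeline side this pairs with the new `IPAdapterMixin.load_ip_adapter()` (see `diffusers/loaders/ip_adapter.py` in the listing above). A hedged sketch; the adapter repository, subfolder, and weight names follow the commonly published IP-Adapter checkpoints and are illustrative rather than guaranteed:

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_image = load_image("ip_reference.png")  # any reference image used as conditioning
image = pipe(
    prompt="a polar bear, best quality",
    ip_adapter_image=ip_image,
    num_inference_steps=25,
).images[0]
```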
diffusers/models/unet_2d_condition_flax.py

@@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     sample_size: int = 32
     in_channels: int = 4
     out_channels: int = 4
-    down_block_types: Tuple[str] = (
+    down_block_types: Tuple[str, ...] = (
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "DownBlock2D",
     )
-    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
     only_cross_attention: Union[bool, Tuple[bool]] = False
-    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: Union[int, Tuple[int]] = 8
-    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
+    attention_head_dim: Union[int, Tuple[int, ...]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     use_linear_projection: bool = False
@@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
     split_head_dim: bool = False
-    transformer_layers_per_block: Union[int, Tuple[int]] = 1
+    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
     addition_embed_type: Optional[str] = None
     addition_time_embed_dim: Optional[int] = None
     addition_embed_type_num_heads: int = 64
@@ -158,7 +158,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         }
         return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

-    def setup(self):
+    def setup(self) -> None:
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4

@@ -320,15 +320,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):

     def __call__(
         self,
-        sample,
-        timesteps,
-        encoder_hidden_states,
+        sample: jnp.ndarray,
+        timesteps: Union[jnp.ndarray, float, int],
+        encoder_hidden_states: jnp.ndarray,
         added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
-        down_block_additional_residuals=None,
-        mid_block_additional_residual=None,
+        down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
+        mid_block_additional_residual: Optional[jnp.ndarray] = None,
         return_dict: bool = True,
         train: bool = False,
-    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
+    ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]:
         r"""
         Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor