diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
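Beyond the per-file counts, the largest structural addition above is the new `diffusers/quantizers` package (a base quantizer, bitsandbytes, GGUF, and torchao backends, plus a shared `quantization_config.py`). As a rough sketch of how the new configs are used — the checkpoint ID and keyword values here are illustrative, not taken from this diff:

```python
# Minimal sketch: 4-bit bitsandbytes quantization via the new diffusers.quantizers
# package. Assumes `bitsandbytes` is installed; the model ID is only an example.
import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # example checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```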
diffusers/models/adapter.py
CHANGED
@@ -30,10 +30,10 @@ class MultiAdapter(ModelMixin):
     MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
     user-assigned weighting.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
+    or saving.
 
-    Parameters:
+    Args:
         adapters (`List[T2IAdapter]`, *optional*, defaults to None):
             A list of `T2IAdapter` model instances.
     """
@@ -77,11 +77,13 @@ class MultiAdapter(ModelMixin):
         r"""
         Args:
             xs (`torch.Tensor`):
-                (batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
-                `channel` should equal to `num_adapter` * "number of channel of image".
+                A tensor of shape (batch, channel, height, width) representing input images for multiple adapter
+                models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
+                `num_adapter` * number of channel per image.
+
             adapter_weights (`List[float]`, *optional*, defaults to None):
-                List of floats representing the weight which will be multiply to each adapter's output before adding
-                them together.
+                A list of floats representing the weights which will be multiplied by each adapter's output before
+                summing them together. If `None`, equal weights will be used for all adapters.
         """
         if adapter_weights is None:
             adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
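The rewritten docstring makes the merge rule explicit: each adapter's output is scaled by its weight and the results are summed, with every weight defaulting to `1 / num_adapter`. A self-contained sketch of that arithmetic, using made-up feature maps:

```python
# Sketch of the weighting described above; the feature maps are fabricated.
import torch

num_adapter = 2
outputs = [torch.ones(1, 320, 64, 64), 3 * torch.ones(1, 320, 64, 64)]
weights = torch.tensor([1 / num_adapter] * num_adapter)  # the `None` default

merged = sum(w * out for w, out in zip(weights, outputs))
print(merged.mean())  # tensor(2.) — the equally weighted average
```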
@@ -109,24 +111,24 @@ class MultiAdapter(ModelMixin):
         variant: Optional[str] = None,
     ):
         """
-        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
         `[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
 
-        Arguments:
+        Args:
             save_directory (`str` or `os.PathLike`):
-                Directory to which to save. Will be created if it doesn't exist.
-            is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful when in distributed training like
-                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
-                the main process to avoid race conditions.
+                The directory where the model will be saved. If the directory does not exist, it will be created.
+            is_main_process (`bool`, optional, defaults=True):
+                Indicates whether current process is the main process or not. Useful for distributed training (e.g.,
+                TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only
+                for the main process to avoid race conditions.
             save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+                Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
+                `torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment
+                variable.
+            safe_serialization (`bool`, optional, defaults=True):
+                If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
             variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
         """
         idx = 0
         model_path_to_save = save_directory
@@ -145,19 +147,17 @@ class MultiAdapter(ModelMixin):
     @classmethod
     def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
-        Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
+        Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
 
         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
-        the model, you should first set it back in training mode with `model.train()`.
+        the model, set it back to training mode using `model.train()`.
 
-        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-        task.
+        Warnings:
+            *Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
+            with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
+            from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.
 
-        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
-        weights are discarded.
-
-        Parameters:
+        Args:
             pretrained_model_path (`os.PathLike`):
                 A path to a *directory* containing model weights saved using
                 [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
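Together with `save_pretrained` above, the documented contract is a plain round trip: each sub-adapter is written under the chosen directory and reassembled on load. A minimal sketch, assuming two default-configured `T2IAdapter`s and a scratch directory:

```python
# Sketch of the save/load round trip; adapter configs and the path are illustrative.
from diffusers import MultiAdapter, T2IAdapter

adapters = MultiAdapter([T2IAdapter(in_channels=3), T2IAdapter(in_channels=3)])
adapters.save_pretrained("./my_model_directory")

reloaded = MultiAdapter.from_pretrained("./my_model_directory")
assert reloaded.num_adapter == 2
```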
@@ -175,20 +175,20 @@ class MultiAdapter(ModelMixin):
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
-                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
-                GPU and the available CPU RAM if unset.
+                A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                 Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                 setting this argument to `True` will raise an error.
             variant (`str`, *optional*):
-                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
-                ignored when using `from_flax`.
+                If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
+                be ignored when using `from_flax`.
             use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
-                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
-                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+                If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is
+                installed. If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`,
+                `safetensors` is not used.
         """
         idx = 0
         adapters = []
@@ -223,22 +223,22 @@ class T2IAdapter(ModelMixin, ConfigMixin):
     and
     [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
+    downloading or saving.
 
-    Parameters:
-        in_channels (`int`, *optional*, defaults to 3):
-            Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
-            image as *control image*.
+    Args:
+        in_channels (`int`, *optional*, defaults to `3`):
+            The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
+            image.
         channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
-            The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
-            also determine the number of downsample blocks in the Adapter.
-        num_res_blocks (`int`, *optional*, defaults to 2):
+            The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
+            determines the number of downsample blocks in the adapter.
+        num_res_blocks (`int`, *optional*, defaults to `2`):
             Number of ResNet blocks in each downsample block.
-        downscale_factor (`int`, *optional*, defaults to 8):
+        downscale_factor (`int`, *optional*, defaults to `8`):
             A factor that determines the total downscale factor of the Adapter.
         adapter_type (`str`, *optional*, defaults to `full_adapter`):
-            The type of Adapter to use. Choose between `full_adapter` and `full_adapter_xl` or `light_adapter`.
+            Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
     """
 
     @register_to_config
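The tidied argument list maps one-to-one onto construction. A sketch using the documented defaults (the input resolution is arbitrary):

```python
# Sketch: building a T2IAdapter with the documented defaults and running a
# control image through it; the 512x512 input size is arbitrary.
import torch

from diffusers import T2IAdapter

adapter = T2IAdapter(
    in_channels=3,
    channels=(320, 640, 1280, 1280),
    num_res_blocks=2,
    downscale_factor=8,
    adapter_type="full_adapter",
)

control = torch.randn(1, 3, 512, 512)
features = adapter(control)  # one feature map per entry in `channels`
print([tuple(f.shape) for f in features])
```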
@@ -393,7 +393,7 @@ class AdapterBlock(nn.Module):
     An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
     `FullAdapterXL` models.
 
-    Parameters:
+    Args:
         in_channels (`int`):
             Number of channels of AdapterBlock's input.
         out_channels (`int`):
@@ -401,7 +401,7 @@ class AdapterBlock(nn.Module):
         num_res_blocks (`int`):
             Number of ResNet blocks in the AdapterBlock.
         down (`bool`, *optional*, defaults to `False`):
-            Whether to perform downsampling on AdapterBlock's input.
+            If `True`, perform downsampling on AdapterBlock's input.
     """
 
     def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -440,7 +440,7 @@ class AdapterResnetBlock(nn.Module):
     r"""
     An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
 
-    Parameters:
+    Args:
         channels (`int`):
             Number of channels of AdapterResnetBlock's input and output.
     """
@@ -518,7 +518,7 @@ class LightAdapterBlock(nn.Module):
     A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
     `LightAdapter` model.
 
-    Parameters:
+    Args:
         in_channels (`int`):
             Number of channels of LightAdapterBlock's input.
         out_channels (`int`):
@@ -526,7 +526,7 @@ class LightAdapterBlock(nn.Module):
         num_res_blocks (`int`):
             Number of LightAdapterResnetBlocks in the LightAdapterBlock.
         down (`bool`, *optional*, defaults to `False`):
-            Whether to perform downsampling on LightAdapterBlock's input.
+            If `True`, perform downsampling on LightAdapterBlock's input.
     """
 
     def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
@@ -561,7 +561,7 @@ class LightAdapterResnetBlock(nn.Module):
     A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
     architecture than `AdapterResnetBlock`.
 
-    Parameters:
+    Args:
         channels (`int`):
             Number of channels of LightAdapterResnetBlock's input and output.
     """
diffusers/models/attention.py
CHANGED
@@ -19,10 +19,10 @@ from torch import nn
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
 logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
     processing of `context` conditions.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        use_dual_attention: bool = False,
+    ):
         super().__init__()
 
+        self.use_dual_attention = use_dual_attention
         self.context_pre_only = context_pre_only
         context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
 
-        self.norm1 = AdaLayerNormZero(dim)
+        if use_dual_attention:
+            self.norm1 = SD35AdaLayerNormZeroX(dim)
+        else:
+            self.norm1 = AdaLayerNormZero(dim)
 
         if context_norm_type == "ada_norm_continous":
             self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ class JointTransformerBlock(nn.Module):
             raise ValueError(
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
+
         if hasattr(F, "scaled_dot_product_attention"):
             processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
             )
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -134,8 +148,25 @@ class JointTransformerBlock(nn.Module):
             context_pre_only=context_pre_only,
             bias=True,
             processor=processor,
+            qk_norm=qk_norm,
+            eps=1e-6,
         )
 
+        if use_dual_attention:
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=None,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                out_dim=dim,
+                bias=True,
+                processor=processor,
+                qk_norm=qk_norm,
+                eps=1e-6,
+            )
+        else:
+            self.attn2 = None
+
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
@@ -157,9 +188,19 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         if self.context_pre_only:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -170,13 +211,20 @@ class JointTransformerBlock(nn.Module):
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         if self._chunk_size is not None:
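The three hunks above thread an optional second self-attention stream through `JointTransformerBlock` (the SD3.5-style dual attention): `SD35AdaLayerNormZeroX` emits an extra `(norm_hidden_states2, gate_msa2)` pair, `attn2` consumes it, and its gated output is added back to `hidden_states`. A minimal sketch of exercising the new path — all dimensions here are illustrative, not taken from any released checkpoint:

```python
# Sketch: JointTransformerBlock with the new dual-attention path enabled.
# Requires torch >= 2.0 (scaled_dot_product_attention); sizes are illustrative.
import torch

from diffusers.models.attention import JointTransformerBlock

block = JointTransformerBlock(
    dim=1536,
    num_attention_heads=24,
    attention_head_dim=64,
    context_pre_only=False,
    qk_norm="rms_norm",
    use_dual_attention=True,
)

hidden_states = torch.randn(1, 1024, 1536)        # image tokens
encoder_hidden_states = torch.randn(1, 77, 1536)  # text tokens
temb = torch.randn(1, 1536)                       # timestep embedding

encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
print(hidden_states.shape)  # torch.Size([1, 1024, 1536])
```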
@@ -972,15 +1020,32 @@ class FreeNoiseTransformerBlock(nn.Module):
         return frame_indices
 
     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
-        if weighting_scheme == "pyramid":
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
             if num_frames % 2 == 0:
                 # num_frames = 4 => [1, 2, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
                 weights = weights + weights[::-1]
             else:
                 # num_frames = 5 => [1, 2, 3, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
-                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * mid
+                weights = weights + list(range(mid, 0, -1))
         else:
             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
 
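For orientation, here is a standalone re-computation of the weighting schemes, mirroring the control flow of the hunk instead of importing the private method; the printed values follow the comments in the diff:

```python
# Standalone re-implementation of the weighting schemes added above.
def get_frame_weights(num_frames: int, weighting_scheme: str = "pyramid"):
    if weighting_scheme == "flat":
        return [1.0] * num_frames
    if weighting_scheme == "pyramid":
        if num_frames % 2 == 0:
            mid = num_frames // 2
            weights = list(range(1, mid + 1))
            return weights + weights[::-1]
        mid = (num_frames + 1) // 2
        weights = list(range(1, mid))
        return weights + [mid] + weights[::-1]
    if weighting_scheme == "delayed_reverse_sawtooth":
        # Shown for the even case only; see the hunk above for the odd branch.
        mid = num_frames // 2
        return [0.01] * (mid - 1) + [mid] + list(range(mid, 0, -1))
    raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")

print(get_frame_weights(4, "flat"))                      # [1.0, 1.0, 1.0, 1.0]
print(get_frame_weights(5, "pyramid"))                   # [1, 2, 3, 2, 1]
print(get_frame_weights(4, "delayed_reverse_sawtooth"))  # [0.01, 2, 2, 1]
```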
@@ -1087,8 +1152,26 @@ class FreeNoiseTransformerBlock(nn.Module):
                 accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
                 num_times_accumulated[:, frame_start:frame_end] += weights
 
-            hidden_states = torch.where(
-                num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+            # TODO(aryan): Maybe this could be done in a better way.
+            #
+            # Previously, this was:
+            #     hidden_states = torch.where(
+            #         num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+            #     )
+            #
+            # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
+            # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+            # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
+            # looked into this deeply because other memory optimizations led to more pronounced reductions.
+            hidden_states = torch.cat(
+                [
+                    torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                    for accumulated_split, num_times_split in zip(
+                        accumulated_values.split(self.context_length, dim=1),
+                        num_times_accumulated.split(self.context_length, dim=1),
+                    )
+                ],
+                dim=1,
             ).to(dtype)
 
             # 3. Feed-forward
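The rewrite computes the very same masked average, but per `context_length`-sized window along the frame axis, trading one large temporary for several small ones. A quick sketch checking that the two forms agree; the window size and tensor shapes are arbitrary:

```python
# Sketch: the split/where/cat form matches the previous single torch.where.
import torch

context_length = 16  # arbitrary window size
accumulated_values = torch.randn(2, 64, 320)
num_times_accumulated = torch.randint(0, 3, (2, 64, 1)).float().expand_as(accumulated_values)

old = torch.where(
    num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
)
new = torch.cat(
    [
        torch.where(n > 0, a / n, a)
        for a, n in zip(
            accumulated_values.split(context_length, dim=1),
            num_times_accumulated.split(context_length, dim=1),
        )
    ],
    dim=1,
)
torch.testing.assert_close(old, new)
```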
@@ -1146,6 +1229,8 @@ class FeedForward(nn.Module):
             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
         elif activation_fn == "swiglu":
             act_fn = SwiGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "linear-silu":
+            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
 
         self.net = nn.ModuleList([])
         # project in
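With this branch in place, passing `activation_fn="linear-silu"` to `FeedForward` selects the `LinearActivation` module imported at the top of the file. A minimal usage sketch (sizes are illustrative):

```python
# Sketch: FeedForward with the new "linear-silu" activation option.
import torch

from diffusers.models.attention import FeedForward

ff = FeedForward(dim=128, dim_out=128, mult=4, activation_fn="linear-silu")
out = ff(torch.randn(2, 77, 128))
print(out.shape)  # torch.Size([2, 77, 128])
```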
diffusers/models/attention_flax.py
CHANGED
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-
             hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim:
|