diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
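Most of this release is a reorganization rather than new functionality: the monolithic diffusers/loaders.py (removed, -3336 lines) becomes the diffusers/loaders package, legacy pipelines and schedulers move under deprecated/ subpackages, and the VAE models move into diffusers/models/autoencoders. A minimal sketch of what the new layout looks like at import time (paths inferred from the file list above; not an official migration guide):

# Sketch, assuming diffusers 0.25.0 is installed; paths inferred from the file list above.
# Deprecated pipelines now live one package level deeper:
from diffusers.pipelines.deprecated.alt_diffusion import AltDiffusionPipeline

# loaders.py is now a package; its mixins are split across submodules:
from diffusers.loaders.textual_inversion import TextualInversionLoaderMixin
from diffusers.loaders.ip_adapter import IPAdapterMixin  # IP-Adapter support is new in 0.25

The diff below shows the mechanical consequence of the moves: relative imports inside relocated modules gain one level (... becomes ....).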
diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py
RENAMED
@@ -7,20 +7,20 @@ import torch.nn.functional as F
 
 from diffusers.utils import deprecate
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...models import ModelMixin
-from ...models.activations import get_activation
-from ...models.attention import Attention
-from ...models.attention_processor import (
+from ....configuration_utils import ConfigMixin, register_to_config
+from ....models import ModelMixin
+from ....models.activations import get_activation
+from ....models.attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     AttnProcessor,
 )
-from ...models.dual_transformer_2d import DualTransformer2DModel
-from ...models.embeddings import (
+from ....models.dual_transformer_2d import DualTransformer2DModel
+from ....models.embeddings import (
     GaussianFourierProjection,
     ImageHintTimeEmbedding,
     ImageProjection,
@@ -31,10 +31,10 @@ from ...models.embeddings import (
     TimestepEmbedding,
     Timesteps,
 )
-from ...models.transformer_2d import Transformer2DModel
-from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import apply_freeu
+from ....models.transformer_2d import Transformer2DModel
+from ....models.unet_2d_condition import UNet2DConditionOutput
+from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ....utils.torch_utils import apply_freeu
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -50,6 +50,9 @@ def get_down_block(
     resnet_eps,
     resnet_act_fn,
     num_attention_heads,
+    transformer_layers_per_block,
+    attention_type,
+    attention_head_dim,
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
@@ -113,6 +116,10 @@ def get_up_block(
     resnet_eps,
     resnet_act_fn,
     num_attention_heads,
+    transformer_layers_per_block,
+    resolution_idx,
+    attention_type,
+    attention_head_dim,
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
@@ -425,10 +432,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         if num_attention_heads is not None:
             raise ValueError(
-                "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
-                " because of a naming issue as described in"
-                " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
-                " `num_attention_heads` will only be supported in diffusers v0.19."
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
             )
 
         # If `num_attention_heads` is not defined (which is the case for most models)
@@ -442,44 +446,37 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
-                "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:"
-                f" {down_block_types}. `up_block_types`: {up_block_types}."
+                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
             )
 
         if len(block_out_channels) != len(down_block_types):
             raise ValueError(
-                "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:"
-                f" {block_out_channels}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
             )
 
         if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
             raise ValueError(
-                "Must provide the same number of `only_cross_attention` as `down_block_types`."
-                f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
         if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
             raise ValueError(
-                "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
-                f" {num_attention_heads}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )
 
         if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
             raise ValueError(
-                "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
-                f" {attention_head_dim}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
             )
 
         if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
             raise ValueError(
-                "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:"
-                f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
             )
 
         if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
             raise ValueError(
-                "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
-                f" {layers_per_block}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
             )
         if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
             for layer_number_per_block in transformer_layers_per_block:
@@ -897,8 +894,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             processor = AttnProcessor()
         else:
             raise ValueError(
-                "Cannot call `set_default_attn_processor` when attention processors are of type"
-                f" {next(iter(self.attn_processors.values()))}"
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
             )
 
         self.set_attn_processor(processor, _remove_lora=True)
@@ -1004,6 +1000,42 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)
 
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
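The fuse/unfuse pair added above mirrors the fused-QKV API this release also adds to the regular UNet and attention processors. A minimal usage sketch (the checkpoint id is illustrative, not part of this diff):

# Sketch: exercising the fused-QKV API added in 0.25 (checkpoint id illustrative).
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
unet.fuse_qkv_projections()    # merge q/k/v linears into one projection per self-attention block
# ... run inference ...
unet.unfuse_qkv_projections()  # restore the original, unfused attention processors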
@@ -1166,8 +1198,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
-                    " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )
 
             image_embs = added_cond_kwargs.get("image_embeds")
@@ -1177,14 +1208,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # SDXL - style
             if "text_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
-                    " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                 )
             text_embeds = added_cond_kwargs.get("text_embeds")
             if "time_ids" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
-                    " the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                 )
             time_ids = added_cond_kwargs.get("time_ids")
             time_embeds = self.add_time_proj(time_ids.flatten())
@@ -1196,8 +1225,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # Kandinsky 2.2 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
-                    " keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )
             image_embs = added_cond_kwargs.get("image_embeds")
             aug_emb = self.add_embedding(image_embs)
@@ -1205,8 +1233,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # Kandinsky 2.2 - style
             if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
-                    " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                 )
             image_embs = added_cond_kwargs.get("image_embeds")
             hint = added_cond_kwargs.get("hint")
@@ -1224,8 +1251,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # Kadinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which"
-                    " requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                 )
 
             image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1234,11 +1260,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             # Kandinsky 2.2 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
-                    " the keyword argument `image_embeds` to be passed in `added_conditions`"
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
         # 2. pre-process
         sample = self.conv_in(sample)
 
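The new 'ip_image_proj' branch is the IP-Adapter path: projected image embeddings are concatenated to the text tokens along the sequence dimension, so cross-attention attends to both. A shape sketch (sizes illustrative):

# Shape sketch for the 'ip_image_proj' branch above (sizes illustrative):
# encoder_hidden_states:           (batch, 77, cross_attention_dim)               - text tokens
# encoder_hid_proj(image_embeds):  (batch, num_image_tokens, cross_attention_dim) - image tokens
# after torch.cat([...], dim=1):   (batch, 77 + num_image_tokens, cross_attention_dim)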
@@ -1264,10 +1298,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             deprecate(
                 "T2I should not use down_block_additional_residuals",
                 "1.3.0",
-                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated"
-                " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used"
-                " for ControlNet. Please make sure use"
-                " `down_intrablock_additional_residuals` instead. ",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                       and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                       for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
                 standard_warn=False,
             )
             down_intrablock_additional_residuals = down_block_additional_residuals
@@ -2102,8 +2135,7 @@ class UNetMidBlockFlat(nn.Module):
 
         if attention_head_dim is None:
             logger.warn(
-                "It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to"
-                f" `in_channels`: {in_channels}."
+                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
             )
             attention_head_dim = in_channels
 
diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py
RENAMED
@@ -5,10 +5,10 @@ import PIL.Image
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 
-from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging
-from ..pipeline_utils import DiffusionPipeline
+from ....models import AutoencoderKL, UNet2DConditionModel
+from ....schedulers import KarrasDiffusionSchedulers
+from ....utils import logging
+from ...pipeline_utils import DiffusionPipeline
 from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
 from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
 from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py
RENAMED
@@ -26,12 +26,12 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )
 
-from ...image_processor import VaeImageProcessor
-from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
-from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
-from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ....image_processor import VaeImageProcessor
+from ....models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
+from ....schedulers import KarrasDiffusionSchedulers
+from ....utils import deprecate, logging
+from ....utils.torch_utils import randn_tensor
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
 
 
@@ -58,6 +58,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
     """
+
     model_cpu_offload_seq = "bert->unet->vqvae"
 
     tokenizer: CLIPTokenizer
diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py
RENAMED
@@ -21,12 +21,12 @@ import torch
 import torch.utils.checkpoint
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
-from ...image_processor import VaeImageProcessor
-from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
-from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ....image_processor import VaeImageProcessor
+from ....models import AutoencoderKL, UNet2DConditionModel
+from ....schedulers import KarrasDiffusionSchedulers
+from ....utils import deprecate, logging
+from ....utils.torch_utils import randn_tensor
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -52,6 +52,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
     """
+
     model_cpu_offload_seq = "bert->unet->vqvae"
 
     image_feature_extractor: CLIPImageProcessor
diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py
RENAMED
@@ -19,12 +19,12 @@ import torch
 import torch.utils.checkpoint
 from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
-from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
-from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
-from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ....image_processor import VaeImageProcessor
+from ....models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
+from ....schedulers import KarrasDiffusionSchedulers
+from ....utils import deprecate, logging
+from ....utils.torch_utils import randn_tensor
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
 
 
@@ -51,6 +51,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
     """
+
     model_cpu_offload_seq = "bert->unet->vqvae"
 
     tokenizer: CLIPTokenizer
diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py
RENAMED
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING
 
-from ...utils import (
+from ....utils import (
     DIFFUSERS_SLOW_IMPORT,
     OptionalDependencyNotAvailable,
     _LazyModule,
@@ -16,7 +16,7 @@ try:
     if not (is_transformers_available() and is_torch_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from ...utils.dummy_torch_and_transformers_objects import (
+    from ....utils.dummy_torch_and_transformers_objects import (
         LearnedClassifierFreeSamplingEmbeddings,
         VQDiffusionPipeline,
     )
@@ -36,7 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         if not (is_transformers_available() and is_torch_available()):
             raise OptionalDependencyNotAvailable()
     except OptionalDependencyNotAvailable:
-        from ...utils.dummy_torch_and_transformers_objects import (
+        from ....utils.dummy_torch_and_transformers_objects import (
             LearnedClassifierFreeSamplingEmbeddings,
             VQDiffusionPipeline,
         )
diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py
RENAMED
@@ -17,11 +17,11 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...models import ModelMixin, Transformer2DModel, VQModel
-from ...schedulers import VQDiffusionScheduler
-from ...utils import logging
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ....configuration_utils import ConfigMixin, register_to_config
+from ....models import ModelMixin, Transformer2DModel, VQModel
+from ....schedulers import VQDiffusionScheduler
+from ....utils import logging
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
CHANGED
@@ -181,7 +181,7 @@ class KandinskyV22Pipeline(DiffusionPipeline):
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
CHANGED
@@ -283,7 +283,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -759,7 +759,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
             prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
                 list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
-                the `._callback_tensor_inputs` attribute of your pipeine class.
+                the `._callback_tensor_inputs` attribute of your pipeline class.
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -768,7 +768,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
 
         Examples:
diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
CHANGED
@@ -255,7 +255,7 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
CHANGED
@@ -362,7 +362,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
CHANGED
@@ -423,7 +423,7 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeine class.
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
diffusers/pipelines/kandinsky3/__init__.py
ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]
+    _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_kandinsky3 import Kandinsky3Pipeline
+        from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
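This is the standard lazy-module boilerplate used across diffusers: importing the package only registers _import_structure, and the torch/transformers-backed submodules load on first attribute access. A small sketch (assumes torch and transformers are installed):

# Sketch: _LazyModule defers the real import until a symbol is touched.
import diffusers.pipelines.kandinsky3 as k3  # cheap: no pipeline module imported yet
pipe_cls = k3.Kandinsky3Pipeline             # first access triggers the submodule import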
diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py
ADDED
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+import argparse
+import fnmatch
+
+from safetensors.torch import load_file
+
+from diffusers import Kandinsky3UNet
+
+
+MAPPING = {
+    "to_time_embed.1": "time_embedding.linear_1",
+    "to_time_embed.3": "time_embedding.linear_2",
+    "in_layer": "conv_in",
+    "out_layer.0": "conv_norm_out",
+    "out_layer.2": "conv_out",
+    "down_samples": "down_blocks",
+    "up_samples": "up_blocks",
+    "projection_lin": "encoder_hid_proj.projection_linear",
+    "projection_ln": "encoder_hid_proj.projection_norm",
+    "feature_pooling": "add_time_condition",
+    "to_query": "to_q",
+    "to_key": "to_k",
+    "to_value": "to_v",
+    "output_layer": "to_out.0",
+    "self_attention_block": "attentions.0",
+}
+
+DYNAMIC_MAP = {
+    "resnet_attn_blocks.*.0": "resnets_in.*",
+    "resnet_attn_blocks.*.1": ("attentions.*", 1),
+    "resnet_attn_blocks.*.2": "resnets_out.*",
+}
+# MAPPING = {}
+
+
+def convert_state_dict(unet_state_dict):
+    """
+    Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
+    Args:
+        unet_model (torch.nn.Module): The original U-Net model.
+        unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
+
+    Returns:
+        OrderedDict: The converted state dictionary.
+    """
+    # Example of renaming logic (this will vary based on your model's architecture)
+    converted_state_dict = {}
+    for key in unet_state_dict:
+        new_key = key
+        for pattern, new_pattern in MAPPING.items():
+            new_key = new_key.replace(pattern, new_pattern)
+
+        for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
+            has_matched = False
+            if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
+                star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
+
+                if isinstance(dyn_new_pattern, tuple):
+                    new_star = star + dyn_new_pattern[-1]
+                    dyn_new_pattern = dyn_new_pattern[0]
+                else:
+                    new_star = star
+
+                pattern = dyn_pattern.replace("*", str(star))
+                new_pattern = dyn_new_pattern.replace("*", str(new_star))
+
+                new_key = new_key.replace(pattern, new_pattern)
+                has_matched = True
+
+        converted_state_dict[new_key] = unet_state_dict[key]
+
+    return converted_state_dict
+
+
+def main(model_path, output_path):
+    # Load your original U-Net model
+    unet_state_dict = load_file(model_path)
+
+    # Initialize your Kandinsky3UNet model
+    config = {}
+
+    # Convert the state dict
+    converted_state_dict = convert_state_dict(unet_state_dict)
+
+    unet = Kandinsky3UNet(config)
+    unet.load_state_dict(converted_state_dict)
+
+    unet.save_pretrained(output_path)
+    print(f"Converted model saved to {output_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
+    parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
+    parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
+
+    args = parser.parse_args()
+    main(args.model_path, args.output_path)