diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +16 -2
- diffusers/configuration_utils.py +1 -0
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +4 -5
- diffusers/image_processor.py +186 -14
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +157 -0
- diffusers/loaders/lora.py +1415 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +631 -0
- diffusers/loaders/textual_inversion.py +459 -0
- diffusers/loaders/unet.py +735 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +12 -1
- diffusers/models/attention.py +165 -14
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +286 -1
- diffusers/models/autoencoder_asym_kl.py +14 -9
- diffusers/models/autoencoder_kl.py +3 -18
- diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/autoencoder_tiny.py +20 -24
- diffusers/models/consistency_decoder_vae.py +37 -30
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +2 -1
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +27 -19
- diffusers/models/normalization.py +2 -2
- diffusers/models/resnet.py +390 -59
- diffusers/models/transformer_2d.py +20 -3
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +9 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandi3.py +589 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/vae.py +63 -13
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +3 -1
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +65 -12
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
- diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +6 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
- diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +4 -2
- diffusers/pipelines/pipeline_utils.py +33 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
- diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
- diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/__init__.py +64 -21
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
- diffusers/schedulers/__init__.py +2 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +1 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
- diffusers/schedulers/scheduling_deis_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
- diffusers/schedulers/scheduling_euler_discrete.py +40 -13
- diffusers/schedulers/scheduling_heun_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +1 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
- diffusers/utils/__init__.py +1 -0
- diffusers/utils/constants.py +8 -7
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
- diffusers/utils/dynamic_modules_utils.py +4 -4
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/logging.py +10 -10
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/torch_utils.py +2 -2
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
- diffusers/loaders.py +0 -3336
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
@@ -425,10 +425,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
425
425
|
|
426
426
|
if num_attention_heads is not None:
|
427
427
|
raise ValueError(
|
428
|
-
"At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
|
429
|
-
" because of a naming issue as described in"
|
430
|
-
" https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
|
431
|
-
" `num_attention_heads` will only be supported in diffusers v0.19."
|
428
|
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
432
429
|
)
|
433
430
|
|
434
431
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
@@ -442,44 +439,37 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
442
439
|
# Check inputs
|
443
440
|
if len(down_block_types) != len(up_block_types):
|
444
441
|
raise ValueError(
|
445
|
-
"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:"
|
446
|
-
f" {down_block_types}. `up_block_types`: {up_block_types}."
|
442
|
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
447
443
|
)
|
448
444
|
|
449
445
|
if len(block_out_channels) != len(down_block_types):
|
450
446
|
raise ValueError(
|
451
|
-
"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:"
|
452
|
-
f" {block_out_channels}. `down_block_types`: {down_block_types}."
|
447
|
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
453
448
|
)
|
454
449
|
|
455
450
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
456
451
|
raise ValueError(
|
457
|
-
"Must provide the same number of `only_cross_attention` as `down_block_types`."
|
458
|
-
f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
452
|
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
459
453
|
)
|
460
454
|
|
461
455
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
462
456
|
raise ValueError(
|
463
|
-
"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
|
464
|
-
f" {num_attention_heads}. `down_block_types`: {down_block_types}."
|
457
|
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
465
458
|
)
|
466
459
|
|
467
460
|
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
468
461
|
raise ValueError(
|
469
|
-
"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
|
470
|
-
f" {attention_head_dim}. `down_block_types`: {down_block_types}."
|
462
|
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
471
463
|
)
|
472
464
|
|
473
465
|
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
474
466
|
raise ValueError(
|
475
|
-
"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:"
|
476
|
-
f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
467
|
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
477
468
|
)
|
478
469
|
|
479
470
|
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
480
471
|
raise ValueError(
|
481
|
-
"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
|
482
|
-
f" {layers_per_block}. `down_block_types`: {down_block_types}."
|
472
|
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
483
473
|
)
|
484
474
|
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
485
475
|
for layer_number_per_block in transformer_layers_per_block:
|
@@ -897,8 +887,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
897
887
|
processor = AttnProcessor()
|
898
888
|
else:
|
899
889
|
raise ValueError(
|
900
|
-
"Cannot call `set_default_attn_processor` when attention processors are of type"
|
901
|
-
f" {next(iter(self.attn_processors.values()))}"
|
890
|
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
902
891
|
)
|
903
892
|
|
904
893
|
self.set_attn_processor(processor, _remove_lora=True)
|
@@ -1166,8 +1155,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1166
1155
|
# Kandinsky 2.1 - style
|
1167
1156
|
if "image_embeds" not in added_cond_kwargs:
|
1168
1157
|
raise ValueError(
|
1169
|
-
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
|
1170
|
-
" the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1158
|
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1171
1159
|
)
|
1172
1160
|
|
1173
1161
|
image_embs = added_cond_kwargs.get("image_embeds")
|
@@ -1177,14 +1165,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1177
1165
|
# SDXL - style
|
1178
1166
|
if "text_embeds" not in added_cond_kwargs:
|
1179
1167
|
raise ValueError(
|
1180
|
-
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
|
1181
|
-
" the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
1168
|
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
1182
1169
|
)
|
1183
1170
|
text_embeds = added_cond_kwargs.get("text_embeds")
|
1184
1171
|
if "time_ids" not in added_cond_kwargs:
|
1185
1172
|
raise ValueError(
|
1186
|
-
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
|
1187
|
-
" the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
1173
|
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
1188
1174
|
)
|
1189
1175
|
time_ids = added_cond_kwargs.get("time_ids")
|
1190
1176
|
time_embeds = self.add_time_proj(time_ids.flatten())
|
@@ -1196,8 +1182,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1196
1182
|
# Kandinsky 2.2 - style
|
1197
1183
|
if "image_embeds" not in added_cond_kwargs:
|
1198
1184
|
raise ValueError(
|
1199
|
-
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
|
1200
|
-
" keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1185
|
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
1201
1186
|
)
|
1202
1187
|
image_embs = added_cond_kwargs.get("image_embeds")
|
1203
1188
|
aug_emb = self.add_embedding(image_embs)
|
@@ -1205,8 +1190,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1205
1190
|
# Kandinsky 2.2 - style
|
1206
1191
|
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
1207
1192
|
raise ValueError(
|
1208
|
-
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
|
1209
|
-
" the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
1193
|
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
1210
1194
|
)
|
1211
1195
|
image_embs = added_cond_kwargs.get("image_embeds")
|
1212
1196
|
hint = added_cond_kwargs.get("hint")
|
@@ -1224,8 +1208,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1224
1208
|
# Kadinsky 2.1 - style
|
1225
1209
|
if "image_embeds" not in added_cond_kwargs:
|
1226
1210
|
raise ValueError(
|
1227
|
-
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which"
|
1228
|
-
" requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1211
|
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1229
1212
|
)
|
1230
1213
|
|
1231
1214
|
image_embeds = added_cond_kwargs.get("image_embeds")
|
@@ -1234,11 +1217,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1234
1217
|
# Kandinsky 2.2 - style
|
1235
1218
|
if "image_embeds" not in added_cond_kwargs:
|
1236
1219
|
raise ValueError(
|
1237
|
-
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
|
1238
|
-
" the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1220
|
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1239
1221
|
)
|
1240
1222
|
image_embeds = added_cond_kwargs.get("image_embeds")
|
1241
1223
|
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
1224
|
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
1225
|
+
if "image_embeds" not in added_cond_kwargs:
|
1226
|
+
raise ValueError(
|
1227
|
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
1228
|
+
)
|
1229
|
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
1230
|
+
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
|
1231
|
+
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
|
1232
|
+
|
1242
1233
|
# 2. pre-process
|
1243
1234
|
sample = self.conv_in(sample)
|
1244
1235
|
|
@@ -1264,10 +1255,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1264
1255
|
deprecate(
|
1265
1256
|
"T2I should not use down_block_additional_residuals",
|
1266
1257
|
"1.3.0",
|
1267
|
-
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated
|
1268
|
-
|
1269
|
-
|
1270
|
-
" `down_intrablock_additional_residuals` instead. ",
|
1258
|
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
1259
|
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
1260
|
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
1271
1261
|
standard_warn=False,
|
1272
1262
|
)
|
1273
1263
|
down_intrablock_additional_residuals = down_block_additional_residuals
|
@@ -2102,8 +2092,7 @@ class UNetMidBlockFlat(nn.Module):
|
|
2102
2092
|
|
2103
2093
|
if attention_head_dim is None:
|
2104
2094
|
logger.warn(
|
2105
|
-
"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to"
|
2106
|
-
f" `in_channels`: {in_channels}."
|
2095
|
+
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
2107
2096
|
)
|
2108
2097
|
attention_head_dim = in_channels
|
2109
2098
|
|
@@ -58,6 +58,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
|
58
58
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
59
59
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
60
60
|
"""
|
61
|
+
|
61
62
|
model_cpu_offload_seq = "bert->unet->vqvae"
|
62
63
|
|
63
64
|
tokenizer: CLIPTokenizer
|
@@ -52,6 +52,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
|
52
52
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
53
53
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
54
54
|
"""
|
55
|
+
|
55
56
|
model_cpu_offload_seq = "bert->unet->vqvae"
|
56
57
|
|
57
58
|
image_feature_extractor: CLIPImageProcessor
|
@@ -51,6 +51,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
|
51
51
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
52
52
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
53
53
|
"""
|
54
|
+
|
54
55
|
model_cpu_offload_seq = "bert->unet->vqvae"
|
55
56
|
|
56
57
|
tokenizer: CLIPTokenizer
|
@@ -17,6 +17,8 @@ import torch
|
|
17
17
|
import torch.nn as nn
|
18
18
|
|
19
19
|
from ...models.attention_processor import Attention
|
20
|
+
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
21
|
+
from ...utils import USE_PEFT_BACKEND
|
20
22
|
|
21
23
|
|
22
24
|
class WuerstchenLayerNorm(nn.LayerNorm):
|
@@ -32,7 +34,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
|
|
32
34
|
class TimestepBlock(nn.Module):
|
33
35
|
def __init__(self, c, c_timestep):
|
34
36
|
super().__init__()
|
35
|
-
|
37
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
38
|
+
self.mapper = linear_cls(c_timestep, c * 2)
|
36
39
|
|
37
40
|
def forward(self, x, t):
|
38
41
|
a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
|
@@ -42,10 +45,14 @@ class TimestepBlock(nn.Module):
|
|
42
45
|
class ResBlock(nn.Module):
|
43
46
|
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
|
44
47
|
super().__init__()
|
45
|
-
|
48
|
+
|
49
|
+
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
50
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
51
|
+
|
52
|
+
self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
|
46
53
|
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
47
54
|
self.channelwise = nn.Sequential(
|
48
|
-
|
55
|
+
linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
|
49
56
|
)
|
50
57
|
|
51
58
|
def forward(self, x, x_skip=None):
|
@@ -73,10 +80,13 @@ class GlobalResponseNorm(nn.Module):
|
|
73
80
|
class AttnBlock(nn.Module):
|
74
81
|
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
|
75
82
|
super().__init__()
|
83
|
+
|
84
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
85
|
+
|
76
86
|
self.self_attn = self_attn
|
77
87
|
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
|
78
88
|
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
|
79
|
-
self.kv_mapper = nn.Sequential(nn.SiLU(),
|
89
|
+
self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
|
80
90
|
|
81
91
|
def forward(self, x, kv):
|
82
92
|
kv = self.kv_mapper(kv)
|
@@ -28,8 +28,9 @@ from ...models.attention_processor import (
|
|
28
28
|
AttnAddedKVProcessor,
|
29
29
|
AttnProcessor,
|
30
30
|
)
|
31
|
+
from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
31
32
|
from ...models.modeling_utils import ModelMixin
|
32
|
-
from ...utils import is_torch_version
|
33
|
+
from ...utils import USE_PEFT_BACKEND, is_torch_version
|
33
34
|
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
|
34
35
|
|
35
36
|
|
@@ -40,12 +41,15 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
|
40
41
|
@register_to_config
|
41
42
|
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
|
42
43
|
super().__init__()
|
44
|
+
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
45
|
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
46
|
+
|
43
47
|
self.c_r = c_r
|
44
|
-
self.projection =
|
48
|
+
self.projection = conv_cls(c_in, c, kernel_size=1)
|
45
49
|
self.cond_mapper = nn.Sequential(
|
46
|
-
|
50
|
+
linear_cls(c_cond, c),
|
47
51
|
nn.LeakyReLU(0.2),
|
48
|
-
|
52
|
+
linear_cls(c, c),
|
49
53
|
)
|
50
54
|
|
51
55
|
self.blocks = nn.ModuleList()
|
@@ -55,7 +59,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
|
55
59
|
self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
|
56
60
|
self.out = nn.Sequential(
|
57
61
|
WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
|
58
|
-
|
62
|
+
conv_cls(c, c_in * 2, kernel_size=1),
|
59
63
|
)
|
60
64
|
|
61
65
|
self.gradient_checkpointing = False
|
@@ -269,7 +269,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
|
269
269
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
270
270
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
271
271
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
272
|
-
`._callback_tensor_inputs` attribute of your
|
272
|
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
273
273
|
|
274
274
|
Examples:
|
275
275
|
|
@@ -234,7 +234,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
|
234
234
|
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
|
235
235
|
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
|
236
236
|
list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
|
237
|
-
the `._callback_tensor_inputs` attribute of your
|
237
|
+
the `._callback_tensor_inputs` attribute of your pipeline class.
|
238
238
|
callback_on_step_end (`Callable`, *optional*):
|
239
239
|
A function that calls at the end of each denoising steps during the inference. The function is called
|
240
240
|
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
@@ -243,7 +243,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
|
243
243
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
244
244
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
245
245
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
246
|
-
`._callback_tensor_inputs` attribute of your
|
246
|
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
247
247
|
|
248
248
|
Examples:
|
249
249
|
|
@@ -349,7 +349,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
|
|
349
349
|
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
350
350
|
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
351
351
|
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
352
|
-
`._callback_tensor_inputs` attribute of your
|
352
|
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
353
353
|
|
354
354
|
Examples:
|
355
355
|
|
diffusers/schedulers/__init__.py
CHANGED
@@ -38,6 +38,7 @@ except OptionalDependencyNotAvailable:
|
|
38
38
|
_dummy_modules.update(get_objects_from_module(dummy_pt_objects))
|
39
39
|
|
40
40
|
else:
|
41
|
+
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
|
41
42
|
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
|
42
43
|
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
|
43
44
|
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
|
@@ -56,12 +57,10 @@ else:
|
|
56
57
|
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
|
57
58
|
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
|
58
59
|
_import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
|
59
|
-
_import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
|
60
60
|
_import_structure["scheduling_lcm"] = ["LCMScheduler"]
|
61
61
|
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
|
62
62
|
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
|
63
63
|
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
|
64
|
-
_import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
|
65
64
|
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
|
66
65
|
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
|
67
66
|
_import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
|
@@ -129,6 +128,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
129
128
|
except OptionalDependencyNotAvailable:
|
130
129
|
from ..utils.dummy_pt_objects import * # noqa F403
|
131
130
|
else:
|
131
|
+
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
|
132
132
|
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
|
133
133
|
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
134
134
|
from .scheduling_ddim import DDIMScheduler
|
@@ -147,12 +147,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
147
147
|
from .scheduling_ipndm import IPNDMScheduler
|
148
148
|
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
|
149
149
|
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
|
150
|
-
from .scheduling_karras_ve import KarrasVeScheduler
|
151
150
|
from .scheduling_lcm import LCMScheduler
|
152
151
|
from .scheduling_pndm import PNDMScheduler
|
153
152
|
from .scheduling_repaint import RePaintScheduler
|
154
153
|
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
155
|
-
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
156
154
|
from .scheduling_unclip import UnCLIPScheduler
|
157
155
|
from .scheduling_unipc_multistep import UniPCMultistepScheduler
|
158
156
|
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
@@ -0,0 +1,50 @@
|
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
|
3
|
+
from ...utils import (
|
4
|
+
DIFFUSERS_SLOW_IMPORT,
|
5
|
+
OptionalDependencyNotAvailable,
|
6
|
+
_LazyModule,
|
7
|
+
get_objects_from_module,
|
8
|
+
is_torch_available,
|
9
|
+
is_transformers_available,
|
10
|
+
)
|
11
|
+
|
12
|
+
|
13
|
+
_dummy_objects = {}
|
14
|
+
_import_structure = {}
|
15
|
+
|
16
|
+
try:
|
17
|
+
if not (is_transformers_available() and is_torch_available()):
|
18
|
+
raise OptionalDependencyNotAvailable()
|
19
|
+
except OptionalDependencyNotAvailable:
|
20
|
+
from ...utils import dummy_pt_objects # noqa F403
|
21
|
+
|
22
|
+
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
|
23
|
+
else:
|
24
|
+
_import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
|
25
|
+
_import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
|
26
|
+
|
27
|
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
28
|
+
try:
|
29
|
+
if not is_torch_available():
|
30
|
+
raise OptionalDependencyNotAvailable()
|
31
|
+
|
32
|
+
except OptionalDependencyNotAvailable:
|
33
|
+
from ..utils.dummy_pt_objects import * # noqa F403
|
34
|
+
else:
|
35
|
+
from .scheduling_karras_ve import KarrasVeScheduler
|
36
|
+
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
37
|
+
|
38
|
+
|
39
|
+
else:
|
40
|
+
import sys
|
41
|
+
|
42
|
+
sys.modules[__name__] = _LazyModule(
|
43
|
+
__name__,
|
44
|
+
globals()["__file__"],
|
45
|
+
_import_structure,
|
46
|
+
module_spec=__spec__,
|
47
|
+
)
|
48
|
+
|
49
|
+
for name, value in _dummy_objects.items():
|
50
|
+
setattr(sys.modules[__name__], name, value)
|
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
|
|
19
19
|
import numpy as np
|
20
20
|
import torch
|
21
21
|
|
22
|
-
from
|
23
|
-
from
|
24
|
-
from
|
25
|
-
from
|
22
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
23
|
+
from ...utils import BaseOutput
|
24
|
+
from ...utils.torch_utils import randn_tensor
|
25
|
+
from ..scheduling_utils import SchedulerMixin
|
26
26
|
|
27
27
|
|
28
28
|
@dataclass
|
@@ -19,9 +19,9 @@ from typing import Union
|
|
19
19
|
|
20
20
|
import torch
|
21
21
|
|
22
|
-
from
|
23
|
-
from
|
24
|
-
from
|
22
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
23
|
+
from ...utils.torch_utils import randn_tensor
|
24
|
+
from ..scheduling_utils import SchedulerMixin
|
25
25
|
|
26
26
|
|
27
27
|
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
@@ -79,9 +79,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
|
79
79
|
|
80
80
|
# TODO(Patrick) better comments + non-PyTorch
|
81
81
|
# postprocess model score
|
82
|
-
log_mean_coeff = (
|
83
|
-
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
|
84
|
-
)
|
82
|
+
log_mean_coeff = -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
|
85
83
|
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
|
86
84
|
std = std.flatten()
|
87
85
|
while len(std.shape) < len(score.shape):
|
@@ -208,9 +208,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|
208
208
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
209
209
|
elif beta_schedule == "scaled_linear":
|
210
210
|
# this schedule is very specific to the latent diffusion model.
|
211
|
-
self.betas = (
|
212
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
213
|
-
)
|
211
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
214
212
|
elif beta_schedule == "squaredcos_cap_v2":
|
215
213
|
# Glide cosine schedule
|
216
214
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -204,9 +204,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
204
204
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
205
205
|
elif beta_schedule == "scaled_linear":
|
206
206
|
# this schedule is very specific to the latent diffusion model.
|
207
|
-
self.betas = (
|
208
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
209
|
-
)
|
207
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
210
208
|
elif beta_schedule == "squaredcos_cap_v2":
|
211
209
|
# Glide cosine schedule
|
212
210
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -215,9 +215,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
215
215
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
216
216
|
elif beta_schedule == "scaled_linear":
|
217
217
|
# this schedule is very specific to the latent diffusion model.
|
218
|
-
self.betas = (
|
219
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
220
|
-
)
|
218
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
221
219
|
elif beta_schedule == "squaredcos_cap_v2":
|
222
220
|
# Glide cosine schedule
|
223
221
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -160,9 +160,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|
160
160
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
161
161
|
elif beta_schedule == "scaled_linear":
|
162
162
|
# this schedule is very specific to the latent diffusion model.
|
163
|
-
self.betas = (
|
164
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
165
|
-
)
|
163
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
166
164
|
elif beta_schedule == "squaredcos_cap_v2":
|
167
165
|
# Glide cosine schedule
|
168
166
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -170,9 +170,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
170
170
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
171
171
|
elif beta_schedule == "scaled_linear":
|
172
172
|
# this schedule is very specific to the latent diffusion model.
|
173
|
-
self.betas = (
|
174
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
175
|
-
)
|
173
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
176
174
|
elif beta_schedule == "squaredcos_cap_v2":
|
177
175
|
# Glide cosine schedule
|
178
176
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -149,9 +149,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
149
149
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
150
150
|
elif beta_schedule == "scaled_linear":
|
151
151
|
# this schedule is very specific to the latent diffusion model.
|
152
|
-
self.betas = (
|
153
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
154
|
-
)
|
152
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
155
153
|
elif beta_schedule == "squaredcos_cap_v2":
|
156
154
|
# Glide cosine schedule
|
157
155
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -325,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
325
323
|
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
326
324
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
327
325
|
|
328
|
-
|
329
|
-
|
326
|
+
# Hack to make sure that other schedulers which copy this function don't break
|
327
|
+
# TODO: Add this logic to the other schedulers
|
328
|
+
if hasattr(self.config, "sigma_min"):
|
329
|
+
sigma_min = self.config.sigma_min
|
330
|
+
else:
|
331
|
+
sigma_min = None
|
332
|
+
|
333
|
+
if hasattr(self.config, "sigma_max"):
|
334
|
+
sigma_max = self.config.sigma_max
|
335
|
+
else:
|
336
|
+
sigma_max = None
|
337
|
+
|
338
|
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
339
|
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
330
340
|
|
331
341
|
rho = 7.0 # 7.0 is the value used in the paper
|
332
342
|
ramp = np.linspace(0, 1, num_inference_steps)
|
@@ -176,9 +176,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
176
176
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
177
177
|
elif beta_schedule == "scaled_linear":
|
178
178
|
# this schedule is very specific to the latent diffusion model.
|
179
|
-
self.betas = (
|
180
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
181
|
-
)
|
179
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
182
180
|
elif beta_schedule == "squaredcos_cap_v2":
|
183
181
|
# Glide cosine schedule
|
184
182
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -360,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
360
358
|
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
361
359
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
362
360
|
|
363
|
-
|
364
|
-
|
361
|
+
# Hack to make sure that other schedulers which copy this function don't break
|
362
|
+
# TODO: Add this logic to the other schedulers
|
363
|
+
if hasattr(self.config, "sigma_min"):
|
364
|
+
sigma_min = self.config.sigma_min
|
365
|
+
else:
|
366
|
+
sigma_min = None
|
367
|
+
|
368
|
+
if hasattr(self.config, "sigma_max"):
|
369
|
+
sigma_max = self.config.sigma_max
|
370
|
+
else:
|
371
|
+
sigma_max = None
|
372
|
+
|
373
|
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
374
|
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
365
375
|
|
366
376
|
rho = 7.0 # 7.0 is the value used in the paper
|
367
377
|
ramp = np.linspace(0, 1, num_inference_steps)
|
@@ -171,9 +171,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
171
171
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
172
172
|
elif beta_schedule == "scaled_linear":
|
173
173
|
# this schedule is very specific to the latent diffusion model.
|
174
|
-
self.betas = (
|
175
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
176
|
-
)
|
174
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
177
175
|
elif beta_schedule == "squaredcos_cap_v2":
|
178
176
|
# Glide cosine schedule
|
179
177
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -360,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
360
358
|
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
361
359
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
362
360
|
|
363
|
-
|
364
|
-
|
361
|
+
# Hack to make sure that other schedulers which copy this function don't break
|
362
|
+
# TODO: Add this logic to the other schedulers
|
363
|
+
if hasattr(self.config, "sigma_min"):
|
364
|
+
sigma_min = self.config.sigma_min
|
365
|
+
else:
|
366
|
+
sigma_min = None
|
367
|
+
|
368
|
+
if hasattr(self.config, "sigma_max"):
|
369
|
+
sigma_max = self.config.sigma_max
|
370
|
+
else:
|
371
|
+
sigma_max = None
|
372
|
+
|
373
|
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
374
|
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
365
375
|
|
366
376
|
rho = 7.0 # 7.0 is the value used in the paper
|
367
377
|
ramp = np.linspace(0, 1, num_inference_steps)
|