diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
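For readers who want to reproduce a listing like this locally, the sketch below diffs every Python module shared by two wheels using only the standard library. It assumes both `.whl` files have already been downloaded into the working directory; the filenames are illustrative, not part of any registry API.

```python
# Minimal sketch: print a unified diff of each .py file shared by two wheels.
import difflib
import zipfile

OLD = "diffusers-0.27.1-py3-none-any.whl"  # assumed local download
NEW = "diffusers-0.28.0-py3-none-any.whl"  # assumed local download

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    shared = sorted(set(old_whl.namelist()) & set(new_whl.namelist()))
    for name in shared:
        if not name.endswith(".py"):
            continue
        old_src = old_whl.read(name).decode("utf-8").splitlines(keepends=True)
        new_src = new_whl.read(name).decode("utf-8").splitlines(keepends=True)
        hunks = list(difflib.unified_diff(old_src, new_src, fromfile=name, tofile=name))
        if hunks:
            print("".join(hunks), end="")
```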
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
CHANGED

```diff
@@ -60,14 +60,14 @@ class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):
     Output class for Stable Diffusion pipelines.
 
     Args:
-        latents (`torch.FloatTensor`)
+        latents (`torch.Tensor`)
            inverted latents tensor
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """
 
-    latents: torch.FloatTensor
+    latents: torch.Tensor
     images: Union[List[PIL.Image.Image], np.ndarray]
 
 
@@ -377,8 +377,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
@@ -410,8 +410,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -431,10 +431,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -661,7 +661,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -702,7 +707,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
 
     @torch.no_grad()
-    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
+    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.Tensor:
         num_prompts = len(prompt)
         embeds = []
         for i in range(0, num_prompts, batch_size):
@@ -822,13 +827,13 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         cross_attention_guidance_amount: float = 0.1,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
@@ -871,14 +876,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
@@ -892,7 +897,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
@@ -1107,12 +1112,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_inference_steps: int = 50,
         guidance_scale: float = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         cross_attention_guidance_amount: float = 0.1,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         lambda_auto_corr: float = 20.0,
@@ -1127,7 +1132,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
-            image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+            image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
                image latents as `image`, if passing latents directly, it will not be encoded again.
            num_inference_steps (`int`, *optional*, defaults to 50):
@@ -1142,11 +1147,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            cross_attention_guidance_amount (`float`, defaults to 0.1):
@@ -1159,7 +1164,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
```
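The `prepare_latents` hunk above does more than reflow the tuple: wrapping `height` and `width` in `int(...)` keeps the latent shape integral even when a caller computes them as floats. A minimal sketch of the failure mode this guards against (the values are illustrative, not pipeline code):

```python
import torch

vae_scale_factor = 8
height = width = 512.0  # e.g. the result of an upstream float computation

# Without the cast, floor division of a float yields a float dimension...
old_shape = (1, 4, height // vae_scale_factor, width // vae_scale_factor)
print(old_shape)  # (1, 4, 64.0, 64.0)
# ...and torch.randn(old_shape) raises TypeError: size must be a tuple of ints.

# The 0.28.0 version casts first, so the shape stays integral:
new_shape = (1, 4, int(height) // vae_scale_factor, int(width) // vae_scale_factor)
latents = torch.randn(new_shape)
print(latents.shape)  # torch.Size([1, 4, 64, 64])
```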
diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
CHANGED

```diff
@@ -363,6 +363,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]
 
     @register_to_config
     def __init__(
@@ -531,7 +532,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         elif encoder_hid_dim_type == "text_image_proj":
             # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
             self.encoder_hid_proj = TextImageProjection(
                 text_embed_dim=encoder_hid_dim,
                 image_embed_dim=cross_attention_dim,
@@ -591,7 +592,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
@@ -816,7 +817,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
@@ -1000,8 +1001,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -1047,7 +1048,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
@@ -1065,10 +1066,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         The [`UNetFlatConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, channel, height, width)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
@@ -1112,8 +1113,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         Returns:
            [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -1257,7 +1258,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
-            # Kadinsky 2.1 - style
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
@@ -1589,8 +1590,8 @@ class DownBlockFlat(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         for resnet in self.resnets:
@@ -1718,14 +1719,14 @@ class CrossAttnDownBlockFlat(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         blocks = list(zip(self.resnets, self.attentions))
@@ -1836,13 +1837,13 @@ class UpBlockFlat(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1993,18 +1994,18 @@ class CrossAttnUpBlockFlat(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -2103,8 +2104,8 @@ class UNetMidBlockFlat(nn.Module):
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
 
    Returns:
-        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
-        in_channels, height, width)`.
+        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
+        height, width)`.
 
    """
 
@@ -2222,7 +2223,7 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if attn is not None:
@@ -2238,6 +2239,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
+        out_channels: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -2245,6 +2247,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_groups_out: Optional[int] = None,
         resnet_pre_norm: bool = True,
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
@@ -2256,6 +2259,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ):
         super().__init__()
 
+        out_channels = out_channels or in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -2264,14 +2271,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * num_layers
 
+        resnet_groups_out = resnet_groups_out or resnet_groups
+
         # there is always at least one resnet
         resnets = [
             ResnetBlockFlat(
                 in_channels=in_channels,
-                out_channels=in_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
+                groups_out=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -2286,11 +2296,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
             attentions.append(
                 Transformer2DModel(
                     num_attention_heads,
-                    in_channels // num_attention_heads,
-                    in_channels=in_channels,
+                    out_channels // num_attention_heads,
+                    in_channels=out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    norm_num_groups=resnet_groups,
+                    norm_num_groups=resnet_groups_out,
                     use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     attention_type=attention_type,
@@ -2300,8 +2310,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
             attentions.append(
                 DualTransformer2DModel(
                     num_attention_heads,
-                    in_channels // num_attention_heads,
-                    in_channels=in_channels,
+                    out_channels // num_attention_heads,
+                    in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
                     norm_num_groups=resnet_groups,
@@ -2309,11 +2319,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
             )
         resnets.append(
             ResnetBlockFlat(
-                in_channels=in_channels,
-                out_channels=in_channels,
+                in_channels=out_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
-                groups=resnet_groups,
+                groups=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -2329,16 +2339,16 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -2470,16 +2480,16 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
```
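The new `_no_split_modules` attribute at the top of this file mirrors the one added to the regular UNet models in this release, and it is the hook that big-model dispatching uses: modules of the listed classes must land whole on a single device so their internal residual connections never straddle a device boundary. A rough sketch of how such a list is consumed, using accelerate's planner directly on the regular UNet (the two-device memory budget is an illustrative assumption):

```python
# Sketch only: plan a two-device placement that respects _no_split_modules.
from accelerate import infer_auto_device_map, init_empty_weights
from diffusers import UNet2DConditionModel

with init_empty_weights():
    unet = UNet2DConditionModel()  # default config; no real weights allocated

device_map = infer_auto_device_map(
    unet,
    max_memory={0: "4GiB", 1: "4GiB"},  # assumed budget for the sketch
    no_split_module_classes=unet._no_split_modules,
)
print(device_map)  # module name -> device, with listed blocks kept intact
```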
diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
CHANGED

```diff
@@ -81,7 +81,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
     @torch.no_grad()
     def image_variation(
         self,
-        image: Union[torch.FloatTensor, PIL.Image.Image],
+        image: Union[torch.Tensor, PIL.Image.Image],
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -90,10 +90,10 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""
@@ -123,7 +123,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
@@ -134,7 +134,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
@@ -202,10 +202,10 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""
@@ -235,7 +235,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
@@ -246,7 +246,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
@@ -311,10 +311,10 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""
@@ -344,7 +344,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
@@ -355,7 +355,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
```
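Every hunk in this pipeline is the same mechanical substitution: `torch.FloatTensor` becomes `torch.Tensor` in type hints and docstrings. Nothing changes at runtime, since these are annotations only; the broader hint is simply the accurate one once half-precision or CUDA tensors are involved. A quick sketch of the distinction:

```python
import torch

t = torch.randn(2, 4)                        # float32 on CPU
assert isinstance(t, torch.Tensor)
assert isinstance(t, torch.FloatTensor)      # legacy type still matches fp32 CPU

h = torch.randn(2, 4, dtype=torch.float16)   # anything else...
assert isinstance(h, torch.Tensor)
assert not isinstance(h, torch.FloatTensor)  # ...falls outside the legacy type
```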
diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
CHANGED

```diff
@@ -348,7 +348,12 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -390,10 +395,10 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: int = 1,
         **kwargs,
     ):
@@ -424,7 +429,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
@@ -434,7 +439,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
```
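All of the callback docstrings touched above now describe the hook as `callback(step: int, timestep: int, latents: torch.Tensor)`. A minimal conforming callback looks like the sketch below; the commented `pipe(...)` call is hypothetical usage, not taken from the diff.

```python
import torch

def log_step(step: int, timestep: int, latents: torch.Tensor) -> None:
    # Invoked every `callback_steps` steps with the current noisy latents.
    print(f"step={step} timestep={timestep} latents={tuple(latents.shape)}")

# Hypothetical usage with any of the pipelines above:
# pipe(prompt="...", callback=log_step, callback_steps=5)
```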
|