diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,67 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_flax_available,
9
+ is_torch_available,
10
+ is_transformers_available,
11
+ )
12
+
13
+
14
_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # torch and/or transformers are unavailable: register the dummy objects from
    # `dummy_torch_and_transformers_objects` in place of the real pipelines.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # Submodule name -> public names it provides; handed to `_LazyModule` below.
    _import_structure.update(
        {
            "pipeline_pag_controlnet_sd": ["StableDiffusionControlNetPAGPipeline"],
            "pipeline_pag_controlnet_sd_xl": ["StableDiffusionXLControlNetPAGPipeline"],
            "pipeline_pag_hunyuandit": ["HunyuanDiTPAGPipeline"],
            "pipeline_pag_kolors": ["KolorsPAGPipeline"],
            "pipeline_pag_pixart_sigma": ["PixArtSigmaPAGPipeline"],
            "pipeline_pag_sd": ["StableDiffusionPAGPipeline"],
            "pipeline_pag_sd_3": ["StableDiffusion3PAGPipeline"],
            "pipeline_pag_sd_animatediff": ["AnimateDiffPAGPipeline"],
            "pipeline_pag_sd_xl": ["StableDiffusionXLPAGPipeline"],
            "pipeline_pag_sd_xl_img2img": ["StableDiffusionXLPAGImg2ImgPipeline"],
            "pipeline_pag_sd_xl_inpaint": ["StableDiffusionXLPAGInpaintPipeline"],
        }
    )

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import every public name directly (type checkers / slow-import mode).
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
        from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
        from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
        from .pipeline_pag_kolors import KolorsPAGPipeline
        from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
        from .pipeline_pag_sd import StableDiffusionPAGPipeline
        from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
        from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
        from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
        from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
        from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline

else:
    import sys

    # Swap this module for a `_LazyModule` built from `_import_structure` (presumably it imports the
    # submodules on first attribute access - TODO confirm against `_LazyModule`), then attach any dummies.
    lazy_module = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    sys.modules[__name__] = lazy_module
    for dummy_name, dummy_obj in _dummy_objects.items():
        setattr(lazy_module, dummy_name, dummy_obj)
@@ -0,0 +1,237 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Dict, List, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ...models.attention_processor import (
22
+ Attention,
23
+ AttentionProcessor,
24
+ PAGCFGIdentitySelfAttnProcessor2_0,
25
+ PAGIdentitySelfAttnProcessor2_0,
26
+ )
27
+ from ...utils import logging
28
+
29
+
30
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class PAGMixin:
    r"""Mixin class for [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377v1).

    The host pipeline is expected to expose its denoiser as `self.unet` or `self.transformer`, to call
    `set_pag_applied_layers` before PAG is used, and to set `self._pag_scale` / `self._pag_adaptive_scale`, which the
    scale-related properties read.
    """

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.

        Args:
            pag_applied_layers (`List[str]`):
                Layer identifiers, each a plain substring or a simple regex searched for in the denoiser's module
                names. Every matching self-attention module gets the PAG processor.
            do_classifier_free_guidance (`bool`):
                Whether classifier-free guidance is enabled. Selects the first registered PAG processor when `True`
                and the second one when `False`.

        Raises:
            ValueError: If no PAG attention processors have been registered, if the pipeline has neither a `unet`
                nor a `transformer`, or if an entry of `pag_applied_layers` matches no self-attention module.
        """
        # `getattr` so that calling this before `set_pag_applied_layers` produces the explanatory error below
        # rather than an AttributeError.
        pag_attn_processors = getattr(self, "_pag_attn_processors", None)
        if pag_attn_processors is None:
            raise ValueError(
                "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters."
            )

        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        # Same denoiser lookup as the `pag_attn_processors` property.
        if hasattr(self, "unet"):
            model: nn.Module = self.unet
        elif hasattr(self, "transformer"):
            model: nn.Module = self.transformer
        else:
            raise ValueError("No denoiser module found.")

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check if the module is a self-attention module, i.e. an `Attention` that is not cross-attention.
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        for layer_id in pag_applied_layers:
            # A module is selected when:
            #   (1) it is a self-attention layer;
            #   (2) `layer_id` matches any part of its name (regex search, so partial matches are allowed);
            #   (3) the match is not a partial numeric match: a match ending in a digit must not be immediately
            #       followed by another digit, so layer_id="blocks.1" selects "blocks.1.attn1" but not
            #       "blocks.10.attn1".
            # Rule (3) lives inside the regex, so the engine can still backtrack into longer alternatives such as
            # "blocks.(1|10)". `(?<!\d)` means the match does not end in a digit; `(?!\d)` means the next character
            # is not a digit. The non-capturing group keeps any top-level `|` in `layer_id` scoped to the id.
            layer_pattern = re.compile(rf"(?:{layer_id})(?:(?<!\d)|(?!\d))")
            target_modules = []

            for name, module in model.named_modules():
                if is_self_attn(module) and layer_pattern.search(name) is not None:
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc

    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.

        With adaptive scaling enabled, the scale is `pag_scale - pag_adaptive_scale * (1000 - t)`, clamped at 0.
        Otherwise the constant `pag_scale` is returned. NOTE(review): 1000 presumably corresponds to the
        scheduler's number of training timesteps - confirm for schedulers configured differently.
        """

        if self.do_pag_adaptive_scaling:
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        The batch layout must match `_prepare_perturbed_attention_guidance`: `[uncond, cond, perturbed]` chunks
        when CFG is enabled, `[cond, perturbed]` chunks otherwise.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.

        Returns:
            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        return noise_pred

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepares the perturbed attention guidance for the PAG model.

        Duplicates `cond` along dim 0 (one copy for the regular branch, one for the perturbed branch) and, when CFG
        is enabled, prepends `uncond`, giving `[uncond, cond, cond]` or `[cond, cond]`.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor. Only used when CFG is enabled.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
        """

        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    def set_pag_applied_layers(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
                PAG is to be applied. A few ways of expected usage are as follows:
                  - Single layers specified as - "blocks.{layer_index}"
                  - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
                  - Multiple layers as a block name - "mid"
                  - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors:
                (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
                PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
                processor is used for PAG with classifier-free guidance enabled (unconditional, conditional and
                perturbed branches in one batch). The second is used for PAG with CFG disabled (conditional and
                perturbed branches only). Passing `None` skips validation and clears the registered processors.

        Raises:
            ValueError: If `pag_attn_processors` is neither `None` nor a 2-tuple, or if any layer identifier is not
                a string.
        """

        if not hasattr(self, "_pag_attn_processors"):
            self._pag_attn_processors = None

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for layer in pag_applied_layers:
            if not isinstance(layer, str):
                raise ValueError(f"Expected either a string or a list of string but got type {type(layer)}")

        self.pag_applied_layers = pag_applied_layers
        # NOTE: the default processor instances are created once, when this method is defined, so every call that
        # relies on the default shares the same two processor objects.
        self._pag_attn_processors = pag_attn_processors

    @property
    def pag_scale(self) -> float:
        r"""Get the scale factor for the perturbed attention guidance."""
        return self._pag_scale

    @property
    def pag_adaptive_scale(self) -> float:
        r"""Get the adaptive scale factor for the perturbed attention guidance."""
        return self._pag_adaptive_scale

    @property
    def do_pag_adaptive_scaling(self) -> bool:
        r"""Check if the adaptive scaling is enabled for the perturbed attention guidance."""
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def do_perturbed_attention_guidance(self) -> bool:
        r"""Check if the perturbed attention guidance is enabled."""
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model
            with the key as the name of the layer. Empty when no PAG processors have been registered (including
            before `set_pag_applied_layers` is called).

        Raises:
            ValueError: If the pipeline has neither a `unet` nor a `transformer`.
        """

        registered = getattr(self, "_pag_attn_processors", None)
        if registered is None:
            return {}

        # Processors are recognised by class, so any instance of a registered PAG processor class counts.
        valid_attn_processors = {x.__class__ for x in registered}

        processors = {}
        # We could have iterated through the self.components.items() and checked if a component is
        # `ModelMixin` subclassed but that can include a VAE too.
        if hasattr(self, "unet"):
            denoiser_module = self.unet
        elif hasattr(self, "transformer"):
            denoiser_module = self.transformer
        else:
            raise ValueError("No denoiser module found.")

        for name, proc in denoiser_module.attn_processors.items():
            if proc.__class__ in valid_attn_processors:
                processors[name] = proc

        return processors