diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff shows the content changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/hooks/hooks.py CHANGED
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
  import torch

  from ..utils.logging import get_logger
+ from ..utils.torch_utils import unwrap_module


  logger = get_logger(__name__) # pylint: disable=invalid-name


+ class BaseState:
+     def reset(self, *args, **kwargs) -> None:
+         raise NotImplementedError(
+             "BaseState::reset is not implemented. Please implement this method in the derived class."
+         )
+
+
+ class StateManager:
+     def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+         self._state_cls = state_cls
+         self._init_args = init_args if init_args is not None else ()
+         self._init_kwargs = init_kwargs if init_kwargs is not None else {}
+         self._state_cache = {}
+         self._current_context = None
+
+     def get_state(self):
+         if self._current_context is None:
+             raise ValueError("No context is set. Please set a context before retrieving the state.")
+         if self._current_context not in self._state_cache.keys():
+             self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
+         return self._state_cache[self._current_context]
+
+     def set_context(self, name: str) -> None:
+         self._current_context = name
+
+     def reset(self, *args, **kwargs) -> None:
+         for name, state in list(self._state_cache.items()):
+             state.reset(*args, **kwargs)
+             self._state_cache.pop(name)
+         self._current_context = None
+
+
  class ModelHook:
      r"""
      A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
              raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
          return module

+     def _set_context(self, module: torch.nn.Module, name: str) -> None:
+         # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
+         for attr_name in dir(self):
+             attr = getattr(self, attr_name)
+             if isinstance(attr, StateManager):
+                 attr.set_context(name)
+         return module
+

  class HookFunctionReference:
      def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
              hook.reset_state(self._module_ref)

          if recurse:
-             for module_name, module in self._module_ref.named_modules():
+             for module_name, module in unwrap_module(self._module_ref).named_modules():
                  if module_name == "":
                      continue
+                 module = unwrap_module(module)
                  if hasattr(module, "_diffusers_hook"):
                      module._diffusers_hook.reset_stateful_hooks(recurse=False)

@@ -223,6 +265,19 @@ class HookRegistry:
              module._diffusers_hook = cls(module)
          return module._diffusers_hook

+     def _set_context(self, name: Optional[str] = None) -> None:
+         for hook_name in reversed(self._hook_order):
+             hook = self.hooks[hook_name]
+             if hook._is_stateful:
+                 hook._set_context(self._module_ref, name)
+
+         for module_name, module in unwrap_module(self._module_ref).named_modules():
+             if module_name == "":
+                 continue
+             module = unwrap_module(module)
+             if hasattr(module, "_diffusers_hook"):
+                 module._diffusers_hook._set_context(name)
+
      def __repr__(self) -> str:
          registry_repr = ""
          for i, hook_name in enumerate(self._hook_order):
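
The new `BaseState`/`StateManager` pair gives stateful hooks a per-context state cache: `HookRegistry._set_context` propagates a context name to every `StateManager` attribute of a stateful hook, and `get_state` lazily creates one state object per context. Below is a minimal sketch of that behaviour using a hypothetical `CacheState` class; `StateManager` lives in the internal module `diffusers.hooks.hooks` and may change without notice.

```python
from diffusers.hooks.hooks import BaseState, StateManager


class CacheState(BaseState):
    """Hypothetical per-context state holding cached activations."""

    def __init__(self):
        self.hidden_states = None

    def reset(self):
        self.hidden_states = None


manager = StateManager(CacheState)

# Each named context (e.g. the conditional / unconditional branches of a guider)
# lazily gets its own CacheState instance.
manager.set_context("cond")
manager.get_state().hidden_states = "cond activations"

manager.set_context("uncond")
print(manager.get_state().hidden_states)  # None -- a fresh state for this context

# reset() calls reset() on every cached state, then clears the cache and context.
manager.reset()
```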
diffusers/hooks/layer_skip.py ADDED
@@ -0,0 +1,263 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import asdict, dataclass
+ from typing import Callable, List, Optional
+
+ import torch
+
+ from ..utils import get_logger
+ from ..utils.torch_utils import unwrap_module
+ from ._common import (
+     _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
+     _ATTENTION_CLASSES,
+     _FEEDFORWARD_CLASSES,
+     _get_submodule_from_fqn,
+ )
+ from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
+ from .hooks import HookRegistry, ModelHook
+
+
+ logger = get_logger(__name__) # pylint: disable=invalid-name
+
+ _LAYER_SKIP_HOOK = "layer_skip_hook"
+
+
+ # Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
+ # either remove or make it serializable
+ @dataclass
+ class LayerSkipConfig:
+     r"""
+     Configuration for skipping internal transformer blocks when executing a transformer model.
+
+     Args:
+         indices (`List[int]`):
+             The indices of the transformer blocks to skip.
+         fqn (`str`, defaults to `"auto"`):
+             The fully qualified name identifying the stack of transformer blocks. Typically, this is
+             `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+             For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+             provide the correct fqn.
+         skip_attention (`bool`, defaults to `True`):
+             Whether to skip attention blocks.
+         skip_ff (`bool`, defaults to `True`):
+             Whether to skip feed-forward blocks.
+         skip_attention_scores (`bool`, defaults to `False`):
+             Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
+             projections as the output of scaled dot product attention.
+         dropout (`float`, defaults to `1.0`):
+             The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
+             meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
+             skipped layers are fully retained, which is equivalent to not skipping any layers.
+     """
+
+     indices: List[int]
+     fqn: str = "auto"
+     skip_attention: bool = True
+     skip_attention_scores: bool = False
+     skip_ff: bool = True
+     dropout: float = 1.0
+
+     def __post_init__(self):
+         if not (0 <= self.dropout <= 1):
+             raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
+         if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
+             raise ValueError(
+                 "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+             )
+
+     def to_dict(self):
+         return asdict(self)
+
+     @staticmethod
+     def from_dict(data: dict) -> "LayerSkipConfig":
+         return LayerSkipConfig(**data)
+
+
+ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
+     def __torch_function__(self, func, types, args=(), kwargs=None):
+         if kwargs is None:
+             kwargs = {}
+         if func is torch.nn.functional.scaled_dot_product_attention:
+             query = kwargs.get("query", None)
+             key = kwargs.get("key", None)
+             value = kwargs.get("value", None)
+             query = query if query is not None else args[0]
+             key = key if key is not None else args[1]
+             value = value if value is not None else args[2]
+             # If the Q sequence length does not match KV sequence length, methods like
+             # Perturbed Attention Guidance cannot be used (because the caller expects
+             # the same sequence length as Q, but if we return V here, it will not match).
+             # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
+             # the overall effect would be that of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
+             if query.shape[2] == value.shape[2]:
+                 return value
+         return func(*args, **kwargs)
+
+
+ class AttentionProcessorSkipHook(ModelHook):
+     def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
+         self.skip_processor_output_fn = skip_processor_output_fn
+         self.skip_attention_scores = skip_attention_scores
+         self.dropout = dropout
+
+     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+         if self.skip_attention_scores:
+             if not math.isclose(self.dropout, 1.0):
+                 raise ValueError(
+                     "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+                 )
+             with AttentionScoreSkipFunctionMode():
+                 output = self.fn_ref.original_forward(*args, **kwargs)
+         else:
+             if math.isclose(self.dropout, 1.0):
+                 output = self.skip_processor_output_fn(module, *args, **kwargs)
+             else:
+                 output = self.fn_ref.original_forward(*args, **kwargs)
+                 output = torch.nn.functional.dropout(output, p=self.dropout)
+         return output
+
+
+ class FeedForwardSkipHook(ModelHook):
+     def __init__(self, dropout: float):
+         super().__init__()
+         self.dropout = dropout
+
+     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+         if math.isclose(self.dropout, 1.0):
+             output = kwargs.get("hidden_states", None)
+             if output is None:
+                 output = kwargs.get("x", None)
+             if output is None and len(args) > 0:
+                 output = args[0]
+         else:
+             output = self.fn_ref.original_forward(*args, **kwargs)
+             output = torch.nn.functional.dropout(output, p=self.dropout)
+         return output
+
+
+ class TransformerBlockSkipHook(ModelHook):
+     def __init__(self, dropout: float):
+         super().__init__()
+         self.dropout = dropout
+
+     def initialize_hook(self, module):
+         self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+         return module
+
+     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+         if math.isclose(self.dropout, 1.0):
+             original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+             if self._metadata.return_encoder_hidden_states_index is None:
+                 output = original_hidden_states
+             else:
+                 original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+                     "encoder_hidden_states", args, kwargs
+                 )
+                 output = (original_hidden_states, original_encoder_hidden_states)
+         else:
+             output = self.fn_ref.original_forward(*args, **kwargs)
+             output = torch.nn.functional.dropout(output, p=self.dropout)
+         return output
+
+
+ def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
+     r"""
+     Apply layer skipping to internal layers of a transformer.
+
+     Args:
+         module (`torch.nn.Module`):
+             The transformer model to which the layer skip hook should be applied.
+         config (`LayerSkipConfig`):
+             The configuration for the layer skip hook.
+
+     Example:
+
+     ```python
+     >>> import torch
+     >>> from diffusers import CogVideoXTransformer3DModel, LayerSkipConfig, apply_layer_skip
+
+     >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+     ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+     ... )
+     >>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
+     >>> apply_layer_skip(transformer, config)
+     ```
+     """
+     _apply_layer_skip_hook(module, config)
+
+
+ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
+     name = name or _LAYER_SKIP_HOOK
+
+     if config.skip_attention and config.skip_attention_scores:
+         raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
+     if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
+         raise ValueError(
+             "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+         )
+
+     if config.fqn == "auto":
+         for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+             if hasattr(module, identifier):
+                 config.fqn = identifier
+                 break
+         else:
+             raise ValueError(
+                 "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+                 "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+             )
+
+     transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+     if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
+         raise ValueError(
+             f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
+             f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
+         )
+     if len(config.indices) == 0:
+         raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
+
+     blocks_found = False
+     for i, block in enumerate(transformer_blocks):
+         if i not in config.indices:
+             continue
+
+         blocks_found = True
+
+         if config.skip_attention and config.skip_ff:
+             logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
+             registry = HookRegistry.check_if_exists_or_initialize(block)
+             hook = TransformerBlockSkipHook(config.dropout)
+             registry.register_hook(hook, name)
+
+         elif config.skip_attention or config.skip_attention_scores:
+             for submodule_name, submodule in block.named_modules():
+                 if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
+                     logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                     output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
+                     registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                     hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
+                     registry.register_hook(hook, name)
+
+             if config.skip_ff:
+                 for submodule_name, submodule in block.named_modules():
+                     if isinstance(submodule, _FEEDFORWARD_CLASSES):
+                         logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                         registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                         hook = FeedForwardSkipHook(config.dropout)
+                         registry.register_hook(hook, name)
+
+     if not blocks_found:
+         raise ValueError(
+             f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+             f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+         )
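
Beyond the docstring example above, the config exposes two distinct skipping modes. A minimal sketch, assuming `LayerSkipConfig` and `apply_layer_skip` are re-exported from `diffusers.hooks` (they are defined in `diffusers/hooks/layer_skip.py`):

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import LayerSkipConfig, apply_layer_skip

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Default mode: skip both attention and feed-forward, so the listed blocks pass
# their inputs through unchanged (dropout=1.0 discards the block outputs entirely).
apply_layer_skip(transformer, LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks"))

# Alternative mode: keep the blocks, but make scaled dot product attention return
# the value projections directly (perturbed-attention style). `skip_attention`
# must be False when `skip_attention_scores` is True.
pag_like = LayerSkipConfig(
    indices=[10, 20],
    fqn="transformer_blocks",
    skip_attention=False,
    skip_attention_scores=True,
    skip_ff=False,
)
```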
diffusers/hooks/layerwise_casting.py CHANGED
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Type, Union
  import torch

  from ..utils import get_logger, is_peft_available, is_peft_version
+ from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
  from .hooks import HookRegistry, ModelHook


@@ -27,12 +28,6 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
  # fmt: off
  _LAYERWISE_CASTING_HOOK = "layerwise_casting"
  _PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
- SUPPORTED_PYTORCH_LAYERS = (
-     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-     torch.nn.Linear,
- )
-
  DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
  # fmt: on

@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
          logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
          return

-     if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+     if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
          logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
          apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
          return
diffusers/hooks/pyramid_attention_broadcast.py CHANGED
@@ -18,8 +18,15 @@ from typing import Any, Callable, Optional, Tuple, Union

  import torch

+ from ..models.attention import AttentionModuleMixin
  from ..models.attention_processor import Attention, MochiAttention
  from ..utils import logging
+ from ._common import (
+     _ATTENTION_CLASSES,
+     _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+     _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+     _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ )
  from .hooks import HookRegistry, ModelHook


@@ -27,10 +34,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


  _PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
- _ATTENTION_CLASSES = (Attention, MochiAttention)
- _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
- _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
- _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")


  @dataclass
@@ -60,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
          cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
              The range of timesteps to skip in the cross-attention layer. The attention computations will be
              conditionally skipped if the current timestep is within the specified range.
-         spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+         spatial_attention_block_identifiers (`Tuple[str, ...]`):
              The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
-         temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+         temporal_attention_block_identifiers (`Tuple[str, ...]`):
              The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
-         cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+         cross_attention_block_identifiers (`Tuple[str, ...]`):
              The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
      """

@@ -76,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
      temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
      cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)

-     spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-     temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-     cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+     spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+     temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+     cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS

      current_timestep_callback: Callable[[], int] = None

@@ -227,7 +230,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
          config.spatial_attention_block_skip_range = 2

      for name, submodule in module.named_modules():
-         if not isinstance(submodule, _ATTENTION_CLASSES):
+         if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
              # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
              # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
              # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
diffusers/hooks/smoothed_energy_guidance_utils.py ADDED
@@ -0,0 +1,167 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import asdict, dataclass
+ from typing import List, Optional
+
+ import torch
+ import torch.nn.functional as F
+
+ from ..utils import get_logger
+ from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
+ from .hooks import HookRegistry, ModelHook
+
+
+ logger = get_logger(__name__) # pylint: disable=invalid-name
+
+ _SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
+
+
+ @dataclass
+ class SmoothedEnergyGuidanceConfig:
+     r"""
+     Configuration for applying Smoothed Energy Guidance to the self-attention layers of a transformer model.
+
+     Args:
+         indices (`List[int]`):
+             The indices of the transformer blocks in which smoothed energy guidance should be applied.
+         fqn (`str`, defaults to `"auto"`):
+             The fully qualified name identifying the stack of transformer blocks. Typically, this is
+             `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+             For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+             provide the correct fqn.
+         _query_proj_identifiers (`List[str]`, defaults to `None`):
+             The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
+             `None`, `to_q` is used by default.
+     """
+
+     indices: List[int]
+     fqn: str = "auto"
+     _query_proj_identifiers: List[str] = None
+
+     def to_dict(self):
+         return asdict(self)
+
+     @staticmethod
+     def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
+         return SmoothedEnergyGuidanceConfig(**data)
+
+
+ class SmoothedEnergyGuidanceHook(ModelHook):
+     def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
+         super().__init__()
+         self.blur_sigma = blur_sigma
+         self.blur_threshold_inf = blur_threshold_inf
+
+     def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
+         # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
+         kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
+         smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
+         return smoothed_output
+
+
+ def _apply_smoothed_energy_guidance_hook(
+     module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
+ ) -> None:
+     name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
+
+     if config.fqn == "auto":
+         for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+             if hasattr(module, identifier):
+                 config.fqn = identifier
+                 break
+         else:
+             raise ValueError(
+                 "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+                 "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+             )
+
+     if config._query_proj_identifiers is None:
+         config._query_proj_identifiers = ["to_q"]
+
+     transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+     blocks_found = False
+     for i, block in enumerate(transformer_blocks):
+         if i not in config.indices:
+             continue
+
+         blocks_found = True
+
+         for submodule_name, submodule in block.named_modules():
+             if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
+                 continue
+             for identifier in config._query_proj_identifiers:
+                 query_proj = getattr(submodule, identifier, None)
+                 if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
+                     continue
+                 logger.debug(
+                     f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
+                 )
+                 registry = HookRegistry.check_if_exists_or_initialize(query_proj)
+                 hook = SmoothedEnergyGuidanceHook(blur_sigma)
+                 registry.register_hook(hook, name)
+
+     if not blocks_found:
+         raise ValueError(
+             f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+             f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+         )
+
+
+ # Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
+ def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
+     """
+     This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur.
+     However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this
+     implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
+     assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
+     Smoothed Energy Guidance.
+
+     SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
+     future without warning or guarantee of reproducibility.
+     """
+     assert query.ndim == 3
+
+     is_inf = sigma > sigma_threshold_inf
+     batch_size, seq_len, embed_dim = query.shape
+
+     seq_len_sqrt = int(math.sqrt(seq_len))
+     num_square_tokens = seq_len_sqrt * seq_len_sqrt
+     query_slice = query[:, :num_square_tokens, :]
+     query_slice = query_slice.permute(0, 2, 1)
+     query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
+
+     if is_inf:
+         kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
+         kernel_size_half = (kernel_size - 1) / 2
+
+         x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
+         pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+         kernel1d = pdf / pdf.sum()
+         kernel1d = kernel1d.to(query)
+         kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
+         kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
+
+         padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+         query_slice = F.pad(query_slice, padding, mode="reflect")
+         query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
+     else:
+         query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
+
+     query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
+     query_slice = query_slice.permute(0, 2, 1)
+     query[:, :num_square_tokens, :] = query_slice.clone()
+
+     return query
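
A small sketch of the private `_gaussian_blur_2d` helper above, only to illustrate the expected shapes: the query sequence is treated as a square grid of visual tokens and smoothed spatially, and the output keeps the input shape. This is internal, experimental API and may change without notice.

```python
import math

import torch

from diffusers.hooks.smoothed_energy_guidance_utils import _gaussian_blur_2d

batch_size, height, width, dim = 2, 16, 16, 64
query = torch.randn(batch_size, height * width, dim)  # (batch, seq_len, embed_dim)

blur_sigma = 3.0
# Same kernel-size formula the hook uses in post_forward.
kernel_size = math.ceil(6 * blur_sigma) + 1 - math.ceil(6 * blur_sigma) % 2

blurred = _gaussian_blur_2d(query, kernel_size, blur_sigma, 9999.9)
print(blurred.shape)  # torch.Size([2, 256, 64]) -- same shape as the input query
```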
diffusers/hooks/utils.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+
+ from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+
+
+ def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
+     module_list_with_transformer_blocks = []
+     for name, submodule in module.named_modules():
+         name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
+         is_modulelist = isinstance(submodule, torch.nn.ModuleList)
+         if name_endswith_identifier and is_modulelist:
+             module_list_with_transformer_blocks.append((name, submodule))
+     return module_list_with_transformer_blocks
+
+
+ def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
+     attention_layers = []
+     for name, submodule in module.named_modules():
+         if isinstance(submodule, _ATTENTION_CLASSES):
+             attention_layers.append((name, submodule))
+     return attention_layers
+
+
+ def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
+     feedforward_layers = []
+     for name, submodule in module.named_modules():
+         if isinstance(submodule, _FEEDFORWARD_CLASSES):
+             feedforward_layers.append((name, submodule))
+     return feedforward_layers
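
The discovery helpers above are simple name- and type-based walks over `named_modules()`: the block helper matches `ModuleList` attributes whose names end with a known identifier, while the attention and feed-forward helpers match against the internal `_ATTENTION_CLASSES` / `_FEEDFORWARD_CLASSES` tuples. A tiny sketch of the block helper on a toy module, assuming `"transformer_blocks"` is among `_ALL_TRANSFORMER_BLOCK_IDENTIFIERS` as the layer-skip docstring indicates:

```python
import torch

from diffusers.hooks.utils import _get_identifiable_transformer_blocks_in_module


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A ModuleList whose attribute name ends with a known identifier is picked up.
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])
        self.proj_out = torch.nn.Linear(8, 8)


blocks = _get_identifiable_transformer_blocks_in_module(ToyTransformer())
print([name for name, _ in blocks])  # expected: ['transformer_blocks']
```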