diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/hooks/hooks.py
CHANGED
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 
 from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class BaseState:
+    def reset(self, *args, **kwargs) -> None:
+        raise NotImplementedError(
+            "BaseState::reset is not implemented. Please implement this method in the derived class."
+        )
+
+
+class StateManager:
+    def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+        self._state_cls = state_cls
+        self._init_args = init_args if init_args is not None else ()
+        self._init_kwargs = init_kwargs if init_kwargs is not None else {}
+        self._state_cache = {}
+        self._current_context = None
+
+    def get_state(self):
+        if self._current_context is None:
+            raise ValueError("No context is set. Please set a context before retrieving the state.")
+        if self._current_context not in self._state_cache.keys():
+            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
+        return self._state_cache[self._current_context]
+
+    def set_context(self, name: str) -> None:
+        self._current_context = name
+
+    def reset(self, *args, **kwargs) -> None:
+        for name, state in list(self._state_cache.items()):
+            state.reset(*args, **kwargs)
+            self._state_cache.pop(name)
+        self._current_context = None
+
+
 class ModelHook:
     r"""
     A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
             raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
         return module
 
+    def _set_context(self, module: torch.nn.Module, name: str) -> None:
+        # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
+        for attr_name in dir(self):
+            attr = getattr(self, attr_name)
+            if isinstance(attr, StateManager):
+                attr.set_context(name)
+        return module
+
 
 class HookFunctionReference:
     def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
                 hook.reset_state(self._module_ref)
 
         if recurse:
-            for module_name, module in self._module_ref.named_modules():
+            for module_name, module in unwrap_module(self._module_ref).named_modules():
                 if module_name == "":
                     continue
+                module = unwrap_module(module)
                 if hasattr(module, "_diffusers_hook"):
                     module._diffusers_hook.reset_stateful_hooks(recurse=False)
 
@@ -223,6 +265,19 @@ class HookRegistry:
             module._diffusers_hook = cls(module)
         return module._diffusers_hook
 
+    def _set_context(self, name: Optional[str] = None) -> None:
+        for hook_name in reversed(self._hook_order):
+            hook = self.hooks[hook_name]
+            if hook._is_stateful:
+                hook._set_context(self._module_ref, name)
+
+        for module_name, module in unwrap_module(self._module_ref).named_modules():
+            if module_name == "":
+                continue
+            module = unwrap_module(module)
+            if hasattr(module, "_diffusers_hook"):
+                module._diffusers_hook._set_context(name)
+
     def __repr__(self) -> str:
         registry_repr = ""
         for i, hook_name in enumerate(self._hook_order):
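The new `BaseState`/`StateManager` pair gives stateful hooks a per-context state cache, and `ModelHook._set_context` / `HookRegistry._set_context` propagate the active context name down the module tree (presumably driven by the new guider and modular-pipeline machinery elsewhere in this release). Below is a minimal, hypothetical sketch of how a custom stateful hook might use this plumbing; the state class, hook, and context names are illustrative and not part of the diff.

```python
import torch

from diffusers.hooks.hooks import BaseState, HookRegistry, ModelHook, StateManager


class CallCounterState(BaseState):
    # Hypothetical per-context state: counts forward calls seen under each context name.
    def __init__(self):
        self.calls = 0

    def reset(self, *args, **kwargs) -> None:
        self.calls = 0


class CallCounterHook(ModelHook):
    _is_stateful = True  # lets HookRegistry route _set_context / reset_stateful_hooks to this hook

    def __init__(self):
        super().__init__()
        self.state = StateManager(CallCounterState)

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        # The state object is looked up by whichever context name was last set.
        self.state.get_state().calls += 1
        return self.fn_ref.original_forward(*args, **kwargs)

    def reset_state(self, module: torch.nn.Module):
        self.state.reset()
        return module


linear = torch.nn.Linear(4, 4)
registry = HookRegistry.check_if_exists_or_initialize(linear)
registry.register_hook(CallCounterHook(), "call_counter")

registry._set_context("cond")    # ModelHook._set_context forwards this to every StateManager attribute
linear(torch.randn(1, 4))
registry._set_context("uncond")  # switches to a separate, independently counted state object
linear(torch.randn(1, 4))
```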
diffusers/hooks/layer_skip.py
ADDED
@@ -0,0 +1,263 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
+
+import torch
+
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._common import (
+    _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _ATTENTION_CLASSES,
+    _FEEDFORWARD_CLASSES,
+    _get_submodule_from_fqn,
+)
+from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+_LAYER_SKIP_HOOK = "layer_skip_hook"
+
+
+# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
+# either remove or make it serializable
+@dataclass
+class LayerSkipConfig:
+    r"""
+    Configuration for skipping internal transformer blocks when executing a transformer model.
+
+    Args:
+        indices (`List[int]`):
+            The indices of the layer to skip. This is typically the first layer in the transformer block.
+        fqn (`str`, defaults to `"auto"`):
+            The fully qualified name identifying the stack of transformer blocks. Typically, this is
+            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+            For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+            provide the correct fqn.
+        skip_attention (`bool`, defaults to `True`):
+            Whether to skip attention blocks.
+        skip_ff (`bool`, defaults to `True`):
+            Whether to skip feed-forward blocks.
+        skip_attention_scores (`bool`, defaults to `False`):
+            Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
+            projections as the output of scaled dot product attention.
+        dropout (`float`, defaults to `1.0`):
+            The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
+            meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
+            skipped layers are fully retained, which is equivalent to not skipping any layers.
+    """
+
+    indices: List[int]
+    fqn: str = "auto"
+    skip_attention: bool = True
+    skip_attention_scores: bool = False
+    skip_ff: bool = True
+    dropout: float = 1.0
+
+    def __post_init__(self):
+        if not (0 <= self.dropout <= 1):
+            raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
+        if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
+            raise ValueError(
+                "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+            )
+
+    def to_dict(self):
+        return asdict(self)
+
+    @staticmethod
+    def from_dict(data: dict) -> "LayerSkipConfig":
+        return LayerSkipConfig(**data)
+
+
+class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func is torch.nn.functional.scaled_dot_product_attention:
+            query = kwargs.get("query", None)
+            key = kwargs.get("key", None)
+            value = kwargs.get("value", None)
+            query = query if query is not None else args[0]
+            key = key if key is not None else args[1]
+            value = value if value is not None else args[2]
+            # If the Q sequence length does not match KV sequence length, methods like
+            # Perturbed Attention Guidance cannot be used (because the caller expects
+            # the same sequence length as Q, but if we return V here, it will not match).
+            # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
+            # the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
+            if query.shape[2] == value.shape[2]:
+                return value
+        return func(*args, **kwargs)
+
+
+class AttentionProcessorSkipHook(ModelHook):
+    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
+        self.skip_processor_output_fn = skip_processor_output_fn
+        self.skip_attention_scores = skip_attention_scores
+        self.dropout = dropout
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if self.skip_attention_scores:
+            if not math.isclose(self.dropout, 1.0):
+                raise ValueError(
+                    "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+                )
+            with AttentionScoreSkipFunctionMode():
+                output = self.fn_ref.original_forward(*args, **kwargs)
+        else:
+            if math.isclose(self.dropout, 1.0):
+                output = self.skip_processor_output_fn(module, *args, **kwargs)
+            else:
+                output = self.fn_ref.original_forward(*args, **kwargs)
+                output = torch.nn.functional.dropout(output, p=self.dropout)
+        return output
+
+
+class FeedForwardSkipHook(ModelHook):
+    def __init__(self, dropout: float):
+        super().__init__()
+        self.dropout = dropout
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if math.isclose(self.dropout, 1.0):
+            output = kwargs.get("hidden_states", None)
+            if output is None:
+                output = kwargs.get("x", None)
+            if output is None and len(args) > 0:
+                output = args[0]
+        else:
+            output = self.fn_ref.original_forward(*args, **kwargs)
+            output = torch.nn.functional.dropout(output, p=self.dropout)
+        return output
+
+
+class TransformerBlockSkipHook(ModelHook):
+    def __init__(self, dropout: float):
+        super().__init__()
+        self.dropout = dropout
+
+    def initialize_hook(self, module):
+        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+        return module
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if math.isclose(self.dropout, 1.0):
+            original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+            if self._metadata.return_encoder_hidden_states_index is None:
+                output = original_hidden_states
+            else:
+                original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+                    "encoder_hidden_states", args, kwargs
+                )
+                output = (original_hidden_states, original_encoder_hidden_states)
+        else:
+            output = self.fn_ref.original_forward(*args, **kwargs)
+            output = torch.nn.functional.dropout(output, p=self.dropout)
+        return output
+
+
+def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
+    r"""
+    Apply layer skipping to internal layers of a transformer.
+
+    Args:
+        module (`torch.nn.Module`):
+            The transformer model to which the layer skip hook should be applied.
+        config (`LayerSkipConfig`):
+            The configuration for the layer skip hook.
+
+    Example:
+
+    ```python
+    >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
+
+    >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+    >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
+    >>> apply_layer_skip_hook(transformer, config)
+    ```
+    """
+    _apply_layer_skip_hook(module, config)
+
+
+def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
+    name = name or _LAYER_SKIP_HOOK
+
+    if config.skip_attention and config.skip_attention_scores:
+        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
+    if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
+        raise ValueError(
+            "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+        )
+
+    if config.fqn == "auto":
+        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+            if hasattr(module, identifier):
+                config.fqn = identifier
+                break
+        else:
+            raise ValueError(
+                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+            )
+
+    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+    if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
+        raise ValueError(
+            f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
+            f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
+        )
+    if len(config.indices) == 0:
+        raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
+
+    blocks_found = False
+    for i, block in enumerate(transformer_blocks):
+        if i not in config.indices:
+            continue
+
+        blocks_found = True
+
+        if config.skip_attention and config.skip_ff:
+            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
+            registry = HookRegistry.check_if_exists_or_initialize(block)
+            hook = TransformerBlockSkipHook(config.dropout)
+            registry.register_hook(hook, name)
+
+        elif config.skip_attention or config.skip_attention_scores:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
+                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
+                    registry.register_hook(hook, name)
+
+        if config.skip_ff:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _FEEDFORWARD_CLASSES):
+                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = FeedForwardSkipHook(config.dropout)
+                    registry.register_hook(hook, name)
+
+    if not blocks_found:
+        raise ValueError(
+            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+        )
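A hedged usage sketch of the new module follows; the checkpoint and block indices are illustrative. Note that, per the definitions above, the config field is `indices` and the public entry point is `apply_layer_skip`, while the in-file docstring example refers to `apply_layer_skip_hook` and `layer_index`.

```python
import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.layer_skip import LayerSkipConfig, apply_layer_skip

# Illustrative model; any DiT-style transformer exposing a `transformer_blocks` ModuleList should work.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Drop both attention and feed-forward outputs of blocks 10 and 20 entirely (dropout defaults to 1.0).
config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
apply_layer_skip(transformer, config)
```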
diffusers/hooks/layerwise_casting.py
CHANGED
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Type, Union
 import torch
 
 from ..utils import get_logger, is_peft_available, is_peft_version
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -27,12 +28,6 @@ logger = get_logger(__name__)  # pylint: disable=invalid-name
 # fmt: off
 _LAYERWISE_CASTING_HOOK = "layerwise_casting"
 _PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
-SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-)
-
 DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
         logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
         return
 
-    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+    if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
         logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
         apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
         return
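The only functional change here is that the module-level layer tuple moves into `diffusers.hooks._common` as `_GO_LC_SUPPORTED_PYTORCH_LAYERS` (presumably shared with the group-offloading hooks, going by the `_GO_LC_` prefix). A sketch of the equivalent import for code that previously reached into this private constant:

```python
# Both names are private internals; this only matters for code that imported the old constant directly.
# from diffusers.hooks.layerwise_casting import SUPPORTED_PYTORCH_LAYERS   # 0.34.0
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS        # 0.35.0
```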
diffusers/hooks/pyramid_attention_broadcast.py
CHANGED
@@ -18,8 +18,15 @@ from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
+from ._common import (
+    _ATTENTION_CLASSES,
+    _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+    _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
 from .hooks import HookRegistry, ModelHook
 
 
@@ -27,10 +34,6 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 _PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
 
 
 @dataclass
@@ -60,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
             The range of timesteps to skip in the cross-attention layer. The attention computations will be
             conditionally skipped if the current timestep is within the specified range.
-        spatial_attention_block_identifiers (`Tuple[str, ...]
+        spatial_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
-        temporal_attention_block_identifiers (`Tuple[str, ...]
+        temporal_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
-        cross_attention_block_identifiers (`Tuple[str, ...]
+        cross_attention_block_identifiers (`Tuple[str, ...]`):
             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
     """
 
@@ -76,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
 
-    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
-    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+    spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+    cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
 
     current_timestep_callback: Callable[[], int] = None
 
@@ -227,7 +230,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
         config.spatial_attention_block_skip_range = 2
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
             # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
             # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
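The block-identifier constants previously defined in this file are replaced by the shared `_common` identifiers, and PAB now also recognizes attention modules built on the new `AttentionModuleMixin`. A hedged usage sketch using only names visible in this hunk; the pipeline, skip range, and timestep-callback wiring are illustrative.

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks.pyramid_attention_broadcast import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)  # illustrative

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,  # assumes the pipeline exposes its current timestep
)
apply_pyramid_attention_broadcast(pipe.transformer, config)
```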
diffusers/hooks/smoothed_energy_guidance_utils.py
ADDED
@@ -0,0 +1,167 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import asdict, dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from ..utils import get_logger
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
+
+
+@dataclass
+class SmoothedEnergyGuidanceConfig:
+    r"""
+    Configuration for skipping internal transformer blocks when executing a transformer model.
+
+    Args:
+        indices (`List[int]`):
+            The indices of the layer to skip. This is typically the first layer in the transformer block.
+        fqn (`str`, defaults to `"auto"`):
+            The fully qualified name identifying the stack of transformer blocks. Typically, this is
+            `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+            For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+            provide the correct fqn.
+        _query_proj_identifiers (`List[str]`, defaults to `None`):
+            The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
+            `None`, `to_q` is used by default.
+    """
+
+    indices: List[int]
+    fqn: str = "auto"
+    _query_proj_identifiers: List[str] = None
+
+    def to_dict(self):
+        return asdict(self)
+
+    @staticmethod
+    def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
+        return SmoothedEnergyGuidanceConfig(**data)
+
+
+class SmoothedEnergyGuidanceHook(ModelHook):
+    def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
+        super().__init__()
+        self.blur_sigma = blur_sigma
+        self.blur_threshold_inf = blur_threshold_inf
+
+    def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
+        # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
+        kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
+        smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
+        return smoothed_output
+
+
+def _apply_smoothed_energy_guidance_hook(
+    module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
+) -> None:
+    name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
+
+    if config.fqn == "auto":
+        for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+            if hasattr(module, identifier):
+                config.fqn = identifier
+                break
+        else:
+            raise ValueError(
+                "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+                "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+            )
+
+    if config._query_proj_identifiers is None:
+        config._query_proj_identifiers = ["to_q"]
+
+    transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+    blocks_found = False
+    for i, block in enumerate(transformer_blocks):
+        if i not in config.indices:
+            continue
+
+        blocks_found = True
+
+        for submodule_name, submodule in block.named_modules():
+            if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
+                continue
+            for identifier in config._query_proj_identifiers:
+                query_proj = getattr(submodule, identifier, None)
+                if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
+                    continue
+                logger.debug(
+                    f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
+                )
+                registry = HookRegistry.check_if_exists_or_initialize(query_proj)
+                hook = SmoothedEnergyGuidanceHook(blur_sigma)
+                registry.register_hook(hook, name)
+
+    if not blocks_found:
+        raise ValueError(
+            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+        )
+
+
+# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
+def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
+    """
+    This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur.
+    However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this
+    implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
+    assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
+    Smoothed Energy Guidance.
+
+    SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
+    future without warning or guarantee of reproducibility.
+    """
+    assert query.ndim == 3
+
+    is_inf = sigma > sigma_threshold_inf
+    batch_size, seq_len, embed_dim = query.shape
+
+    seq_len_sqrt = int(math.sqrt(seq_len))
+    num_square_tokens = seq_len_sqrt * seq_len_sqrt
+    query_slice = query[:, :num_square_tokens, :]
+    query_slice = query_slice.permute(0, 2, 1)
+    query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
+
+    if is_inf:
+        kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
+        kernel_size_half = (kernel_size - 1) / 2
+
+        x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
+        pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+        kernel1d = pdf / pdf.sum()
+        kernel1d = kernel1d.to(query)
+        kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
+        kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
+
+        padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+        query_slice = F.pad(query_slice, padding, mode="reflect")
+        query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
+    else:
+        query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
+
+    query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
+    query_slice = query_slice.permute(0, 2, 1)
+    query[:, :num_square_tokens, :] = query_slice.clone()
+
+    return query
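These utilities are private and are presumably driven by the new `SmoothedEnergyGuidance` guider (see `diffusers/guiders/smoothed_energy_guidance.py` in the file list above). A direct, hypothetical sketch of what the helper does, reusing the `transformer` loaded in the layer-skip sketch earlier:

```python
from diffusers.hooks.smoothed_energy_guidance_utils import (
    SmoothedEnergyGuidanceConfig,
    _apply_smoothed_energy_guidance_hook,
)

# Illustrative indices; `transformer` is the DiT-style model loaded above.
config = SmoothedEnergyGuidanceConfig(indices=[2, 9], fqn="transformer_blocks")

# Registers SmoothedEnergyGuidanceHook on the `to_q` projections of the self-attention layers
# in blocks 2 and 9, so their query features are Gaussian-blurred (sigma=1.0) during forward.
_apply_smoothed_energy_guidance_hook(transformer, config, blur_sigma=1.0)
```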
diffusers/hooks/utils.py
ADDED
@@ -0,0 +1,43 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+
+
+def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
+    module_list_with_transformer_blocks = []
+    for name, submodule in module.named_modules():
+        name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
+        is_modulelist = isinstance(submodule, torch.nn.ModuleList)
+        if name_endswith_identifier and is_modulelist:
+            module_list_with_transformer_blocks.append((name, submodule))
+    return module_list_with_transformer_blocks
+
+
+def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
+    attention_layers = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, _ATTENTION_CLASSES):
+            attention_layers.append((name, submodule))
+    return attention_layers
+
+
+def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
+    feedforward_layers = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, _FEEDFORWARD_CLASSES):
+            feedforward_layers.append((name, submodule))
+    return feedforward_layers
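These helpers centralize the block, attention, and feed-forward discovery used by the hooks above. A small sketch of what they return, again assuming the `transformer` loaded in the layer-skip sketch:

```python
from diffusers.hooks.utils import (
    _get_identifiable_attention_layers_in_module,
    _get_identifiable_feedforward_layers_in_module,
    _get_identifiable_transformer_blocks_in_module,
)

# Each helper returns a list of (fully qualified name, submodule) pairs.
blocks = _get_identifiable_transformer_blocks_in_module(transformer)   # ModuleList stacks such as `transformer_blocks`
attn_layers = _get_identifiable_attention_layers_in_module(transformer)
ff_layers = _get_identifiable_feedforward_layers_in_module(transformer)
```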