diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
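
Most of the new surface in 0.35.0 is additive: the QwenImage, SkyReels-V2 and Flux Kontext pipelines, the `diffusers.modular_pipelines` package, the guiders module, and the attention-dispatch layer. A minimal sketch of exercising one of the new pipelines, assuming it follows the usual `DiffusionPipeline.from_pretrained` convention; the checkpoint id and call arguments below are illustrative placeholders, not taken from this diff:

```python
# Hedged sketch: loading one of the pipelines added in 0.35.0 via the generic
# DiffusionPipeline entry point. "Qwen/Qwen-Image" is a placeholder repo id.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(prompt="a watercolor fox in a snowy forest", num_inference_steps=30).images[0]
image.save("fox.png")
```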
diffusers/models/transformers/transformer_cosmos.py

@@ -20,6 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils import is_torchvision_available
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -186,9 +187,15 @@ class CosmosAttnProcessor2_0:
         key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
 
         # 4. Prepare for GQA
-        query_idx = torch.tensor(query.size(3), device=query.device)
-        key_idx = torch.tensor(key.size(3), device=key.device)
-        value_idx = torch.tensor(value.size(3), device=value.device)
+        if torch.onnx.is_in_onnx_export():
+            query_idx = torch.tensor(query.size(3), device=query.device)
+            key_idx = torch.tensor(key.size(3), device=key.device)
+            value_idx = torch.tensor(value.size(3), device=value.device)
+
+        else:
+            query_idx = query.size(3)
+            key_idx = key.size(3)
+            value_idx = value.size(3)
         key = key.repeat_interleave(query_idx // key_idx, dim=3)
         value = value.repeat_interleave(query_idx // value_idx, dim=3)
 
@@ -377,7 +384,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
         return (emb / norm).type_as(hidden_states)
 
 
-class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
+class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
 
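
The Cosmos hunks above only change how the grouped-query attention (GQA) head ratio is computed: plain Python sizes in the normal path, with the `torch.tensor` wrapping kept only under ONNX export so the ratio stays traceable in the exported graph. A self-contained sketch of the same `repeat_interleave` expansion with made-up shapes (heads sit on dim 1 here, unlike the module itself):

```python
# Standalone illustration of GQA key/value head expansion; shapes are invented.
import torch

batch, seq, head_dim = 2, 16, 64
num_q_heads, num_kv_heads = 8, 2

query = torch.randn(batch, num_q_heads, seq, head_dim)
key = torch.randn(batch, num_kv_heads, seq, head_dim)
value = torch.randn(batch, num_kv_heads, seq, head_dim)

# Outside torch.onnx export a plain int is enough; wrapping the sizes in
# torch.tensor is only needed to keep the ratio symbolic during export.
ratio = query.size(1) // key.size(1)
key = key.repeat_interleave(ratio, dim=1)
value = value.repeat_interleave(ratio, dim=1)
assert key.shape == query.shape == value.shape
```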
diffusers/models/transformers/transformer_flux.py

@@ -12,28 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from typing import Any, Dict, Optional, Tuple, Union
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    apply_rotary_emb,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -42,6 +42,307 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
+
+    encoder_query = encoder_key = encoder_value = None
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+    return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+    encoder_query = encoder_key = encoder_value = (None,)
+    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+    return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    if attn.fused_projections:
+        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class FluxAttnProcessor:
+    _attention_backend = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+    def __call__(
+        self,
+        attn: "FluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        if attn.added_kv_proj_dim is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+        hidden_states = dispatch_attention_fn(
+            query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+            )
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FluxIPAdapterAttnProcessor(torch.nn.Module):
+    """Flux Attention processor for IP-Adapter."""
+
+    _attention_backend = None
+
+    def __init__(
+        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+    ):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+
+    def __call__(
+        self,
+        attn: "FluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        ip_hidden_states: Optional[List[torch.Tensor]] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+        ip_query = query
+
+        if encoder_hidden_states is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+            )
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            # IP-adapter
+            ip_attn_output = torch.zeros_like(hidden_states)
+
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
+
+                current_ip_hidden_states = dispatch_attention_fn(
+                    ip_query,
+                    ip_key,
+                    ip_value,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False,
+                    backend=self._attention_backend,
+                )
+                current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
+                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                ip_attn_output += scale * current_ip_hidden_states
+
+            return hidden_states, encoder_hidden_states, ip_attn_output
+        else:
+            return hidden_states
+
+
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = FluxAttnProcessor
+    _available_processors = [
+        FluxAttnProcessor,
+        FluxIPAdapterAttnProcessor,
+    ]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        context_pre_only: Optional[bool] = None,
+        pre_only: bool = False,
+        elementwise_affine: bool = True,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.head_dim = dim_head
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.added_proj_bias = added_proj_bias
+
+        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.pre_only:
+            self.to_out = torch.nn.ModuleList([])
+            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(torch.nn.Dropout(dropout))
+
+        if added_kv_proj_dim is not None:
+            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
     def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
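
The block above introduces `FluxAttention`, a self-contained attention module with its own default processor, in place of the generic `Attention` plus `FluxAttnProcessor2_0` pairing. A rough usage sketch, assuming the class stays importable from the module this hunk edits and using the customary Flux head layout (24 heads of 128 dims):

```python
# Hedged sketch of instantiating the new FluxAttention module directly.
# The import path mirrors the file this diff touches; treat it as an assumption.
import torch
from diffusers.models.transformers.transformer_flux import FluxAttention, FluxAttnProcessor

attn = FluxAttention(
    query_dim=3072,
    dim_head=128,
    heads=24,
    out_dim=3072,
    bias=True,
    processor=FluxAttnProcessor(),
    eps=1e-6,
    pre_only=True,
)

hidden_states = torch.randn(1, 4096, 3072)
out = attn(hidden_states=hidden_states)  # pre_only=True: raw attention output, no out-projection
print(out.shape)  # torch.Size([1, 4096, 3072])
```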
@@ -54,6 +355,8 @@ class FluxSingleTransformerBlock(nn.Module):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         if is_torch_npu_available():
+            from ..attention_processor import FluxAttnProcessor2_0_NPU
+
             deprecation_message = (
                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                 "should be set explicitly using the `set_attn_processor` method."
@@ -61,17 +364,15 @@ class FluxSingleTransformerBlock(nn.Module):
             deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
-            processor = FluxAttnProcessor2_0()
+            processor = FluxAttnProcessor()
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
             processor=processor,
-            qk_norm="rms_norm",
             eps=1e-6,
             pre_only=True,
         )
@@ -79,10 +380,14 @@ class FluxSingleTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        text_seq_len = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +405,8 @@ class FluxSingleTransformerBlock(nn.Module):
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
 
-        return hidden_states
+        encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+        return encoder_hidden_states, hidden_states
 
 
 @maybe_allow_in_graph
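
Taken together, the hunks above change the contract of `FluxSingleTransformerBlock.forward`: the text stream is passed in explicitly and the block returns an `(encoder_hidden_states, hidden_states)` pair instead of one concatenated tensor. A hedged sketch of the new calling convention with illustrative shapes:

```python
# Hedged sketch of the 0.35 FluxSingleTransformerBlock calling convention.
import torch
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

block = FluxSingleTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)

hidden_states = torch.randn(1, 1024, 3072)         # image tokens
encoder_hidden_states = torch.randn(1, 512, 3072)  # text tokens
temb = torch.randn(1, 3072)

# Text and image streams go in separately and come back separately.
encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
```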
@@ -113,17 +419,15 @@ class FluxTransformerBlock(nn.Module):
         self.norm1 = AdaLayerNormZero(dim)
         self.norm1_context = AdaLayerNormZero(dim)
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=FluxAttnProcessor2_0(),
-            qk_norm=qk_norm,
+            processor=FluxAttnProcessor(),
             eps=eps,
         )
 
@@ -147,6 +451,7 @@ class FluxTransformerBlock(nn.Module):
             encoder_hidden_states, emb=temb
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
@@ -175,7 +480,6 @@ class FluxTransformerBlock(nn.Module):
             hidden_states = hidden_states + ip_attn_output
 
         # Process attention outputs for the `encoder_hidden_states`.
-
         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
         encoder_hidden_states = encoder_hidden_states + context_attn_output
 
@@ -190,8 +494,45 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states
 
 
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class FluxTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Flux.
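
`FluxPosEmbed` now lives in this file; it converts per-axis position ids into concatenated cosine and sine rotary tables through `get_1d_rotary_pos_embed`. A sketch of calling it directly, with the customary Flux settings `theta=10000`, `axes_dim=[16, 56, 56]` assumed rather than read from this hunk:

```python
# Hedged sketch: building Flux rotary tables with the FluxPosEmbed copy added above.
import torch
from diffusers.models.transformers.transformer_flux import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# One (t, y, x) id triple per latent token, e.g. a 32x32 latent grid.
ids = torch.zeros(32 * 32, 3)
ids[:, 1] = torch.arange(32).repeat_interleave(32)
ids[:, 2] = torch.arange(32).repeat(32)

freqs_cos, freqs_sin = pos_embed(ids)
print(freqs_cos.shape)  # torch.Size([1024, 128]) since sum(axes_dim) == 128
```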
@@ -227,6 +568,7 @@ class FluxTransformer2DModel(
     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
@@ -286,106 +628,6 @@ class FluxTransformer2DModel(
 
         self.gradient_checkpointing = False
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
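
The removal above drops the per-model copies of `attn_processors`, `set_attn_processor`, `fuse_qkv_projections` and `unfuse_qkv_projections`; with `AttentionMixin` added to the base-class list in the signature hunk earlier, callers presumably keep the same entry points through the mixin. A sketch under that assumption:

```python
# Hedged sketch: the caller-facing API is assumed unchanged, now inherited from
# AttentionMixin rather than defined on FluxTransformer2DModel itself.
from diffusers import FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

print(len(model.attn_processors))     # processors keyed by module path
model.set_attn_processor(FluxAttnProcessor())
model.fuse_qkv_projections()          # still experimental, per the removed docstrings
```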
@@ -484,6 +726,7 @@ class FluxTransformer2DModel(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )
 
             else:
@@ -506,20 +749,22 @@ class FluxTransformer2DModel(
                     )
                 else:
                     hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )
 
             else:
-                hidden_states = block(
+                encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
@@ -529,12 +774,7 @@ class FluxTransformer2DModel(
             if controlnet_single_block_samples is not None:
                 interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                 interval_control = int(np.ceil(interval_control))
-                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
-                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
-                    + controlnet_single_block_samples[index_block // interval_control]
-                )
-
-        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
 
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
diffusers/models/transformers/transformer_hunyuan_video.py

@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         "HunyuanVideoPatchEmbed",
         "HunyuanVideoTokenRefiner",
     ]
+    _repeated_blocks = [
+        "HunyuanVideoTransformerBlock",
+        "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoPatchEmbed",
+        "HunyuanVideoTokenRefiner",
+    ]
 
     @register_to_config
     def __init__(