diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
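Among the files listed above, the largest new API surface is the diffusers/quantizers package (auto.py, base.py, bitsandbytes/, quantization_config.py), which wires bitsandbytes quantization into model loading. A minimal usage sketch, assuming the 0.31.0 top-level BitsAndBytesConfig export and the quantization_config argument to from_pretrained (these names are inferred from the release, not shown in this diff, and require the bitsandbytes package):

    import torch
    from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

    # Load a transformer in 4-bit with bf16 compute; exact kwargs may differ per the release notes.
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )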
diffusers/models/attention.py
CHANGED
@@ -22,7 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
 logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
         processing of `context` conditions.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        use_dual_attention: bool = False,
+    ):
         super().__init__()
 
+        self.use_dual_attention = use_dual_attention
         self.context_pre_only = context_pre_only
         context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
 
-        self.norm1 = AdaLayerNormZero(dim)
+        if use_dual_attention:
+            self.norm1 = SD35AdaLayerNormZeroX(dim)
+        else:
+            self.norm1 = AdaLayerNormZero(dim)
 
         if context_norm_type == "ada_norm_continous":
             self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ class JointTransformerBlock(nn.Module):
             raise ValueError(
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
+
         if hasattr(F, "scaled_dot_product_attention"):
             processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
             )
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -134,8 +148,25 @@ class JointTransformerBlock(nn.Module):
             context_pre_only=context_pre_only,
             bias=True,
             processor=processor,
+            qk_norm=qk_norm,
+            eps=1e-6,
         )
 
+        if use_dual_attention:
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=None,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                out_dim=dim,
+                bias=True,
+                processor=processor,
+                qk_norm=qk_norm,
+                eps=1e-6,
+            )
+        else:
+            self.attn2 = None
+
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
@@ -159,7 +190,12 @@ class JointTransformerBlock(nn.Module):
     def forward(
         self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         if self.context_pre_only:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -177,6 +213,11 @@ class JointTransformerBlock(nn.Module):
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         if self._chunk_size is not None:
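The use_dual_attention path above adds a second, context-free self-attention branch (the SD3.5-style dual attention): SD35AdaLayerNormZeroX produces the extra norm_hidden_states2 and gate_msa2, and attn2 runs on them alongside the joint attention. A minimal construction sketch, assuming the constructor and forward signature shown in these hunks and that forward returns the updated (encoder_hidden_states, hidden_states) pair; sizes are arbitrary:

    import torch
    from diffusers.models.attention import JointTransformerBlock

    block = JointTransformerBlock(
        dim=128,
        num_attention_heads=8,
        attention_head_dim=16,
        context_pre_only=False,
        qk_norm="rms_norm",
        use_dual_attention=True,  # enables norm1 = SD35AdaLayerNormZeroX plus the gated attn2 branch
    )
    hidden_states = torch.randn(1, 64, 128)          # image tokens
    encoder_hidden_states = torch.randn(1, 16, 128)  # text tokens
    temb = torch.randn(1, 128)                       # conditioning embedding
    encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)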
@@ -972,15 +1013,32 @@ class FreeNoiseTransformerBlock(nn.Module):
         return frame_indices
 
     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
-        if weighting_scheme == "pyramid":
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
             if num_frames % 2 == 0:
                 # num_frames = 4 => [1, 2, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
                 weights = weights + weights[::-1]
             else:
                 # num_frames = 5 => [1, 2, 3, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
-                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * mid
+                weights = weights + list(range(mid, 0, -1))
         else:
             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
 
@@ -1087,8 +1145,26 @@
                 accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
                 num_times_accumulated[:, frame_start:frame_end] += weights
 
-        hidden_states = torch.where(
-            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        #     hidden_states = torch.where(
+        #         num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        #     )
+        #
+        # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
+        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+        # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
+        # looked into this deeply because other memory optimizations led to more pronounced reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
         ).to(dtype)
 
         # 3. Feed-forward
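For reference, the per-frame blending weights produced by the pyramid and new flat schemes look like this; the snippet below is a standalone mirror of those two branches from the hunk above, for illustration only (the delayed_reverse_sawtooth scheme additionally suppresses the leading frames of each window with 0.01 weights before a reverse ramp):

    def frame_weights(num_frames: int, scheme: str = "pyramid") -> list:
        # Mirror of the "flat" and "pyramid" branches of FreeNoiseTransformerBlock._get_frame_weights.
        if scheme == "flat":
            return [1.0] * num_frames
        if scheme == "pyramid":
            if num_frames % 2 == 0:
                mid = num_frames // 2
                half = list(range(1, mid + 1))
                return half + half[::-1]      # e.g. 4 -> [1, 2, 2, 1]
            mid = (num_frames + 1) // 2
            half = list(range(1, mid))
            return half + [mid] + half[::-1]  # e.g. 5 -> [1, 2, 3, 2, 1]
        raise ValueError(f"Unsupported scheme: {scheme}")

    print(frame_weights(6))          # [1, 2, 3, 3, 2, 1]
    print(frame_weights(5, "flat"))  # [1.0, 1.0, 1.0, 1.0, 1.0]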
diffusers/models/attention_processor.py
CHANGED
@@ -122,6 +122,7 @@ class Attention(nn.Module):
         out_dim: int = None,
         context_pre_only=None,
         pre_only=False,
+        elementwise_affine: bool = True,
     ):
         super().__init__()
 
@@ -179,8 +180,8 @@
             self.norm_q = None
             self.norm_k = None
         elif qk_norm == "layer_norm":
-            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
-            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         elif qk_norm == "fp32_layer_norm":
             self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
@@ -192,7 +193,7 @@
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
         else:
-            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
 
         if cross_attention_norm is None:
             self.norm_cross = None
@@ -249,6 +250,10 @@
             elif qk_norm == "rms_norm":
                 self.norm_added_q = RMSNorm(dim_head, eps=eps)
                 self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                raise ValueError(
+                    f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+                )
         else:
             self.norm_added_q = None
             self.norm_added_k = None
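The new elementwise_affine flag only touches the "layer_norm" branch of the QK-normalization setup above, letting callers build Q/K LayerNorms without learnable scale and shift. A hedged construction sketch (argument names follow the hunks above; all other defaults are assumed unchanged):

    import torch
    from diffusers.models.attention_processor import Attention

    attn = Attention(
        query_dim=256,
        heads=4,
        dim_head=64,
        qk_norm="layer_norm",
        elementwise_affine=False,  # new in 0.31.0: Q/K LayerNorm without learnable affine parameters
    )
    out = attn(torch.randn(2, 77, 256))  # plain self-attention forward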
@@ -1049,61 +1054,72 @@ class JointAttnProcessor2_0:
     ) -> torch.FloatTensor:
         residual = hidden_states
 
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size = hidden_states.shape[0]
 
         # `sample` projections.
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        # attention
-        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
-        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
-        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
+
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # Split the attention outputs.
-        hidden_states, encoder_hidden_states = (
-            hidden_states[:, : residual.shape[1]],
-            hidden_states[:, residual.shape[1] :],
-        )
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
-        if not attn.context_pre_only:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
-        return hidden_states, encoder_hidden_states
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
 class PAGJointAttnProcessor2_0:
@@ -1695,52 +1711,32 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states
 
 
-
-def apply_rope(xq, xk, freqs_cis):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
-
-class FluxSingleAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
+class FluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("
+            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.
-        encoder_hidden_states:
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
+    ) -> torch.FloatTensor:
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
 
+        # `sample` projections.
         query = attn.to_q(hidden_states)
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
@@ -1749,33 +1745,68 @@ class FluxSingleAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
-        #
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
         if image_rotary_emb is not None:
-
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
 
-
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
 
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        if
-        hidden_states =
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
 
-        return hidden_states
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
-class FluxAttnProcessor2_0:
+class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
 
     def __call__(
         self,
@@ -1785,21 +1816,12 @@ class FluxAttnProcessor2_0:
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
 
         # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1813,59 +1835,62 @@ class FluxAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
         # `context` projections.
-
-
-
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
 
-
-
-
-
-
-
-
-
-
-
-        if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-        if attn.norm_added_k is not None:
-            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
 
-
-
-
-
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
         if image_rotary_emb is not None:
-
-
-
-
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        encoder_hidden_states
-        hidden_states
-
-
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
 
-
-        hidden_states =
-
-
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
-
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
 class CogVideoXAttnProcessor2_0:
@@ -4247,6 +4272,17 @@ class LoRAAttnAddedKVProcessor:
     pass
 
 
+class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+        super().__init__()
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,