diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +9 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +1 -1
- diffusers/loaders/single_file_model.py +5 -0
- diffusers/loaders/single_file_utils.py +242 -2
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +5 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +128 -17
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +344 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
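The bulk of this release is Stable Diffusion 3 support: the new SD3Transformer2DModel, the stable_diffusion_3 pipelines, and the FlowMatchEulerDiscreteScheduler listed above. As a quick orientation before the per-file hunks, here is a minimal usage sketch; it assumes access to the gated stabilityai/stable-diffusion-3-medium-diffusers checkpoint (license accepted, `huggingface-cli login` done) and a CUDA GPU.

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# 28 steps and a guidance scale of 7.0 are the values commonly shown for SD3 medium.
image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3.png")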
diffusers/models/__init__.py
CHANGED
@@ -31,17 +31,19 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
+    _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
-    _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
+    _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
     _import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -53,7 +55,6 @@ if is_torch_available():
     _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
     _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
-    _import_structure["vq_model"] = ["VQModel"]

 if is_flax_available():
     _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
@@ -70,6 +71,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             AutoencoderKLTemporalDecoder,
             AutoencoderTiny,
             ConsistencyDecoderVAE,
+            VQModel,
         )
         from .controlnet import ControlNetModel
         from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
@@ -81,6 +83,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             HunyuanDiT2DModel,
             PixArtTransformer2DModel,
             PriorTransformer,
+            SD3Transformer2DModel,
             T5FilmDecoder,
             Transformer2DModel,
             TransformerTemporalModel,
@@ -98,7 +101,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             UNetSpatioTemporalConditionModel,
             UVit2DModel,
         )
-        from .vq_model import VQModel

     if is_flax_available():
         from .controlnet_flax import FlaxControlNetModel
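For downstream code, the practical effect of this restructuring is that VQModel now lives in the autoencoders subpackage (and DualTransformer2DModel under transformers), while the top-level re-exports keep working; SD3Transformer2DModel becomes importable for the first time. A small sketch of the import paths under 0.29.0:

# Public entry points keep working exactly as before.
from diffusers import VQModel
from diffusers.models import VQModel, SD3Transformer2DModel  # SD3Transformer2DModel is new in 0.29.0

# The concrete module has moved under the autoencoders subpackage
# (it previously lived at diffusers.models.vq_model).
from diffusers.models.autoencoders.vq_model import VQModel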
diffusers/models/attention.py
CHANGED
@@ -20,7 +20,7 @@ from torch import nn
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
-from .attention_processor import Attention
+from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm

@@ -85,6 +85,130 @@ class GatedSelfAttentionDense(nn.Module):
         return x


+@maybe_allow_in_graph
+class JointTransformerBlock(nn.Module):
+    r"""
+    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+            processing of `context` conditions.
+    """
+
+    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+        super().__init__()
+
+        self.context_pre_only = context_pre_only
+        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if context_norm_type == "ada_norm_continous":
+            self.norm1_context = AdaLayerNormContinuous(
+                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+            )
+        elif context_norm_type == "ada_norm_zero":
+            self.norm1_context = AdaLayerNormZero(dim)
+        else:
+            raise ValueError(
+                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+            )
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            added_kv_proj_dim=dim,
+            dim_head=attention_head_dim // num_attention_heads,
+            heads=num_attention_heads,
+            out_dim=attention_head_dim,
+            context_pre_only=context_pre_only,
+            bias=True,
+            processor=processor,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+        if not context_pre_only:
+            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+        else:
+            self.norm2_context = None
+            self.ff_context = None
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+    ):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+        if self.context_pre_only:
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+        else:
+            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                encoder_hidden_states, emb=temb
+            )
+
+        # Attention.
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
+        if self.context_pre_only:
+            encoder_hidden_states = None
+        else:
+            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+            encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+            if self._chunk_size is not None:
+                # "feed_forward_chunk_size" can be used to save memory
+                context_ff_output = _chunked_feed_forward(
+                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                )
+            else:
+                context_ff_output = self.ff_context(norm_encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return encoder_hidden_states, hidden_states
+
+
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
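To make the data flow of the new JointTransformerBlock concrete, here is a minimal sketch of a forward pass with dummy tensors. The sizes are arbitrary; the one non-obvious convention, mirrored from how SD3Transformer2DModel builds these blocks, is that the full inner dimension is passed as `attention_head_dim`, and that the block returns the context stream first.

import torch
from diffusers.models.attention import JointTransformerBlock

dim, heads, batch, img_tokens, txt_tokens = 64, 4, 2, 16, 8

# SD3 passes the full inner dimension as attention_head_dim; the per-head size is
# derived inside the block as attention_head_dim // num_attention_heads.
block = JointTransformerBlock(
    dim=dim, num_attention_heads=heads, attention_head_dim=dim, context_pre_only=False
)

hidden_states = torch.randn(batch, img_tokens, dim)          # image (latent) tokens
encoder_hidden_states = torch.randn(batch, txt_tokens, dim)  # text/context tokens
temb = torch.randn(batch, dim)                               # pooled timestep embedding

# Note the return order: context stream first, then the image stream.
encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
print(encoder_hidden_states.shape, hidden_states.shape)  # (2, 8, 64) (2, 16, 64)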
diffusers/models/attention_processor.py
CHANGED
@@ -116,6 +116,7 @@ class Attention(nn.Module):
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        context_pre_only=None,
     ):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
@@ -130,6 +131,7 @@ class Attention(nn.Module):
         self.dropout = dropout
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
+        self.context_pre_only = context_pre_only

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -207,11 +209,16 @@ class Attention(nn.Module):
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            if self.context_pre_only is not None:
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

+        if self.context_pre_only is not None and not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+
         # set attention processor
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
@@ -539,7 +546,10 @@ class Attention(nn.Module):
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+        quiet_attn_parameters = {"ip_adapter_masks"}
+        unused_kwargs = [
+            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+        ]
         if len(unused_kwargs) > 0:
             logger.warning(
                 f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
@@ -1072,6 +1082,164 @@ class AttnAddedKVProcessor2_0:
         return hidden_states


+class JointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # attention
+        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states = hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
+class FusedJointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # `context` projections.
+        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+        split_size = encoder_qkv.shape[-1] // 3
+        (
+            encoder_hidden_states_query_proj,
+            encoder_hidden_states_key_proj,
+            encoder_hidden_states_value_proj,
+        ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        # attention
+        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states = hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
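Two behavioural points stand out in this file: `ip_adapter_masks` is now on a quiet-list, so passing it via `cross_attention_kwargs` no longer triggers the "not expected ... will be ignored" warning when the active processor does not consume it, and the new joint processors return a (hidden_states, encoder_hidden_states) tuple rather than a single tensor. A minimal sketch of the second point, wiring an Attention module directly to JointAttnProcessor2_0 (sizes are arbitrary):

import torch
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0

attn = Attention(
    query_dim=64,
    added_kv_proj_dim=64,    # enables add_q/k/v_proj for the context stream
    dim_head=16,
    heads=4,
    out_dim=64,
    context_pre_only=False,  # also creates to_add_out for the context output
    bias=True,
    processor=JointAttnProcessor2_0(),
)

sample = torch.randn(2, 16, 64)   # image tokens
context = torch.randn(2, 8, 64)   # text tokens

# Joint processors return both streams instead of a single tensor.
sample_out, context_out = attn(sample, encoder_hidden_states=context)
print(sample_out.shape, context_out.shape)  # (2, 16, 64) (2, 8, 64)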
diffusers/models/autoencoders/autoencoder_asym_kl.py
CHANGED
@@ -176,7 +176,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z, sample, mask).sample
+        dec = self.decode(z, generator, sample, mask).sample

         if not return_dict:
             return (dec,)
diffusers/models/autoencoders/autoencoder_kl.py
CHANGED
@@ -81,9 +81,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        shift_factor: Optional[float] = None,
         latents_mean: Optional[Tuple[float]] = None,
         latents_std: Optional[Tuple[float]] = None,
         force_upcast: float = True,
+        use_quant_conv: bool = True,
+        use_post_quant_conv: bool = True,
     ):
         super().__init__()

@@ -110,8 +113,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             act_fn=act_fn,
         )

-        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
-        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
+        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

         self.use_slicing = False
         self.use_tiling = False
@@ -260,7 +263,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         else:
             h = self.encoder(x)

-        moments = self.quant_conv(h)
+        if self.quant_conv is not None:
+            moments = self.quant_conv(h)
+        else:
+            moments = h
+
         posterior = DiagonalGaussianDistribution(moments)

         if not return_dict:
@@ -272,7 +279,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)

-        z = self.post_quant_conv(z)
+        if self.post_quant_conv is not None:
+            z = self.post_quant_conv(z)
+
         dec = self.decoder(z)

         if not return_dict:
@@ -281,7 +290,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         return DecoderOutput(sample=dec)

     @apply_forward_hook
-    def decode(
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         """
         Decode a batch of images.

@@ -300,7 +311,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
-            decoded = self._decode(z
+            decoded = self._decode(z).sample

         if not return_dict:
             return (decoded,)
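The new AutoencoderKL arguments exist to describe the SD3 VAE, which uses 16 latent channels, drops the quant/post-quant convolutions, and shifts latents before scaling. A minimal sketch of a config in that style; the numeric values below are illustrative, and the real checkpoint ships its own config.json:

from diffusers import AutoencoderKL

vae = AutoencoderKL(
    latent_channels=16,
    scaling_factor=1.5305,
    shift_factor=0.0609,        # new in 0.29.0
    use_quant_conv=False,       # new in 0.29.0: encoder output feeds the latent distribution directly
    use_post_quant_conv=False,  # new in 0.29.0: decoder consumes latents directly
)

# Pipelines that set shift_factor typically map latents roughly as
#   z = (posterior.sample() - shift_factor) * scaling_factor   # into latent space
#   z = z / scaling_factor + shift_factor                      # back out before decode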
diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
CHANGED
@@ -323,11 +323,13 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain
+                Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain
+                tuple.

         Returns:
             The latent representations of the encoded images. If `return_dict` is True, a
-            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is
+            [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is
+            returned.
         """
         h = self.encoder(x)
         moments = self.quant_conv(h)
diffusers/models/autoencoders/consistency_decoder_vae.py
CHANGED
@@ -284,13 +284,13 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
-                tuple.
+                Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
+                instead of a plain tuple.

         Returns:
             The latent representations of the encoded images. If `return_dict` is True, a
-            [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
-            is returned.
+            [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
+            plain `tuple` is returned.
         """
         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
             return self.tiled_encode(x, return_dict=return_dict)
@@ -382,13 +382,13 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
-                plain tuple.
+                Whether or not to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
+                instead of a plain tuple.

         Returns:
-            [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
-                If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
-                otherwise a plain `tuple` is returned.
+            [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
+                is returned, otherwise a plain `tuple` is returned.
         """
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)