diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/models/attention.py
CHANGED
@@ -119,6 +119,15 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -141,6 +150,7 @@ class BasicTransformerBlock(nn.Module):
         norm_hidden_states = self.norm1(hidden_states)

         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
         attn_output = self.attn1(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
@@ -171,7 +181,20 @@ class BasicTransformerBlock(nn.Module):
         if self.use_ada_layer_norm_zero:
             norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

-        ff_output = self.ff(norm_hidden_states)
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)

         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
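The new `set_chunk_feed_forward` hook trades peak memory for extra kernel launches by running the feed-forward over slices of one dimension and concatenating the results. A minimal standalone sketch of the same chunking idea, using a stand-in `ff` module rather than the diffusers classes themselves:

import torch
import torch.nn as nn

# Stand-in feed-forward; in diffusers this role is played by BasicTransformerBlock.ff.
ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

hidden_states = torch.randn(2, 128, 64)  # (batch, seq_len, dim)
chunk_size, chunk_dim = 32, 1            # chunk along the sequence dimension

assert hidden_states.shape[chunk_dim] % chunk_size == 0
num_chunks = hidden_states.shape[chunk_dim] // chunk_size

# Same math as the unchunked call, but only one slice is resident in the FF at a time.
chunked = torch.cat(
    [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)
assert torch.allclose(chunked, ff(hidden_states), atol=1e-6)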
diffusers/models/attention_flax.py
CHANGED
@@ -152,6 +152,7 @@ class FlaxAttention(nn.Module):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ class FlaxAttention(nn.Module):

         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ class FlaxBasicTransformerBlock(nn.Module):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
@@ -280,7 +282,7 @@ class FlaxBasicTransformerBlock(nn.Module):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual

-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ class FlaxTransformer2DModel(nn.Module):
             dtype=self.dtype,
         )

+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -378,7 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
         hidden_states = self.proj_out(hidden_states)

         hidden_states = hidden_states + residual
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)


 class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ class FlaxFeedForward(nn.Module):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states

@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)

     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
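All of the Flax attention blocks above now end in a `flax.linen.Dropout` gated by the `deterministic` flag. A small flax.linen sketch of that pattern (the `TinyBlock` module is hypothetical, not part of diffusers), showing that the dropout rng is only required when `deterministic=False`:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyBlock(nn.Module):
    rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic=True):
        x = nn.Dense(32)(x)
        # Dropout is a no-op when deterministic=True (inference); during training an
        # rng stream named "dropout" must be supplied via `rngs`.
        return nn.Dropout(rate=self.rate)(x, deterministic=deterministic)

x = jnp.ones((2, 8))
variables = TinyBlock().init(jax.random.PRNGKey(0), x)  # deterministic path needs no dropout rng
train_out = TinyBlock().apply(
    variables, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)}
)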
diffusers/models/attention_processor.py
CHANGED
@@ -78,6 +78,7 @@ class Attention(nn.Module):
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
+        self.dropout = dropout

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -1117,7 +1118,9 @@ class AttnProcessor2_0:
         value = attn.to_v(encoder_hidden_states)

         head_dim = inner_dim // attn.heads
+
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
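For context, the reshapes surrounding the added blank lines are the head-splitting that `AttnProcessor2_0` performs before handing the projections to PyTorch 2.0's fused attention. A minimal standalone sketch of that reshape pattern, with illustrative shapes only:

import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 77, 8, 64
inner_dim = heads * head_dim

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim), as in AttnProcessor2_0
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Fused scaled-dot-product attention (requires torch >= 2.0)
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch_size, -1, inner_dim)  # back to (batch, seq, inner_dim)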
diffusers/models/autoencoder_kl.py
CHANGED
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, apply_forward_hook
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder

@@ -38,24 +39,24 @@ class AutoencoderKLOutput(BaseOutput):


 class AutoencoderKL(ModelMixin, ConfigMixin):
-    r"""
-    and
+    r"""
+    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

-    This model inherits from [`ModelMixin`]. Check the superclass documentation for
-
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).

     Parameters:
         in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
         out_channels (int, *optional*, defaults to 3): Number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
-        sample_size (`int`, *optional*, defaults to `32`):
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
         scaling_factor (`float`, *optional*, defaults to 0.18215):
             The component-wise standard deviation of the trained latent space computed using the first batch of the
             training set. This is used to scale the latent space to have unit variance when training the diffusion
@@ -130,15 +131,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful
-
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
         """
         self.use_tiling = use_tiling

     def disable_tiling(self):
         r"""
-        Disable tiled VAE decoding. If `
-
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
         """
         self.enable_tiling(False)

@@ -151,17 +152,89 @@ class AutoencoderKL(ModelMixin, ConfigMixin):

     def disable_slicing(self):
         r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
         decoding in one step.
         """
         self.use_slicing = False

+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
             return self.tiled_encode(x, return_dict=return_dict)

-        h = self.encoder(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)

@@ -210,14 +283,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.

-        Args:
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
-        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
-        different from non-tiled encoding
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
         tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
-
-
-
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+                `tuple` is returned.
         """
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
@@ -255,17 +335,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         return AutoencoderKLOutput(latent_dist=posterior)

     def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        r"""
+        r"""
+        Decode a batch of images using a tiled decoder.

         Args:
-
-
-
-
-
-
-
-
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
         """
         overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
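With this change `AutoencoderKL` exposes the same attention-processor hooks as the UNet, and `encode` now honors `enable_slicing` in addition to tiling. A hedged usage sketch of the new surface area (the checkpoint name is only an example, not implied by the diff):

import torch
from diffusers import AutoencoderKL

# Example checkpoint; any repo with a compatible "vae" subfolder should work.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

vae.set_default_attn_processor()  # new: drop any custom attention processors
vae.enable_slicing()              # new in encode(): batch elements are encoded one at a time
vae.enable_tiling()               # encode/decode large images tile by tile

with torch.no_grad():
    images = torch.randn(2, 3, 512, 512)
    posterior = vae.encode(images).latent_dist
    latents = posterior.sample() * vae.config.scaling_factor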
diffusers/models/controlnet.py
CHANGED
@@ -37,6 +37,20 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 @dataclass
 class ControlNetOutput(BaseOutput):
+    """
+    The output of [`ControlNetModel`].
+
+    Args:
+        down_block_res_samples (`tuple[torch.Tensor]`):
+            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+            used to condition the original UNet's downsampling activations.
+        mid_down_block_re_sample (`torch.Tensor`):
+            The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+            Output can be used to condition the original UNet's middle block activation.
+    """
+
     down_block_res_samples: Tuple[torch.Tensor]
     mid_block_res_sample: torch.Tensor

@@ -87,12 +101,65 @@ class ControlNetConditioningEmbedding(nn.Module):


 class ControlNetModel(ModelMixin, ConfigMixin):
+    """
+    A ControlNet model.
+
+    Args:
+        in_channels (`int`, defaults to 4):
+            The number of channels in the input sample.
+        flip_sin_to_cos (`bool`, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, defaults to 0):
+            The frequency shift to apply to the time embedding.
+        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, defaults to 2):
+            The number of layers per block.
+        downsample_padding (`int`, defaults to 1):
+            The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, defaults to 1):
+            The scale factor to use for the mid block.
+        act_fn (`str`, defaults to "silu"):
+            The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+            in post-processing.
+        norm_eps (`float`, defaults to 1e-5):
+            The epsilon to use for the normalization.
+        cross_attention_dim (`int`, defaults to 1280):
+            The dimension of the cross attention features.
+        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+            The dimension of the attention heads.
+        use_linear_projection (`bool`, defaults to `False`):
+        class_embed_type (`str`, *optional*, defaults to `None`):
+            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        num_class_embeds (`int`, *optional*, defaults to 0):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+            class conditioning with `class_embed_type` equal to `None`.
+        upcast_attention (`bool`, defaults to `False`):
+        resnet_time_scale_shift (`str`, defaults to `"default"`):
+            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+            `class_embed_type="projection"`.
+        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+            The tuple of output channel for each block in the `conditioning_embedding` layer.
+        global_pool_conditions (`bool`, defaults to `False`):
+    """
+
     _supports_gradient_checkpointing = True

     @register_to_config
     def __init__(
         self,
         in_channels: int = 4,
+        conditioning_channels: int = 3,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         down_block_types: Tuple[str] = (
@@ -111,6 +178,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
@@ -123,6 +191,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()

+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
         # Check inputs
         if len(block_out_channels) != len(down_block_types):
             raise ValueError(
@@ -134,9 +210,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )

-        if not isinstance(
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
             raise ValueError(
-                f"Must provide the same number of `
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )

         # input
@@ -185,6 +261,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
             conditioning_embedding_channels=block_out_channels[0],
             block_out_channels=conditioning_embedding_out_channels,
+            conditioning_channels=conditioning_channels,
         )

         self.down_blocks = nn.ModuleList([])
@@ -196,6 +273,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]

@@ -219,7 +299,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-
+                num_attention_heads=num_attention_heads[i],
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -253,7 +334,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift=resnet_time_scale_shift,
             cross_attention_dim=cross_attention_dim,
-
+            num_attention_heads=num_attention_heads[-1],
             resnet_groups=norm_num_groups,
             use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
@@ -268,12 +349,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         load_weights_from_unet: bool = True,
     ):
         r"""
-        Instantiate
+        Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].

         Parameters:
             unet (`UNet2DConditionModel`):
-                UNet model
-
+                The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+                where applicable.
         """
         controlnet = cls(
             in_channels=unet.config.in_channels,
@@ -290,6 +371,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             norm_eps=unet.config.norm_eps,
             cross_attention_dim=unet.config.cross_attention_dim,
             attention_head_dim=unet.config.attention_head_dim,
+            num_attention_heads=unet.config.num_attention_heads,
             use_linear_projection=unet.config.use_linear_projection,
             class_embed_type=unet.config.class_embed_type,
             num_class_embeds=unet.config.num_class_embeds,
@@ -341,11 +423,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
+        Sets the attention processor to use to compute attention.
+
         Parameters:
-
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-
-
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.

         """
         count = len(self.attn_processors.keys())
@@ -381,13 +467,13 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         r"""
         Enable sliced attention computation.

-        When this option is enabled, the attention module
-
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`,
-                `"max"`, maximum amount of memory
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
@@ -460,6 +546,37 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
+        """
+        The [`ControlNetModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor.
+            timestep (`Union[torch.Tensor, float, int]`):
+                The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder hidden states.
+            controlnet_cond (`torch.FloatTensor`):
+                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            conditioning_scale (`float`, defaults to `1.0`):
+                The scale factor for ControlNet outputs.
+            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+            cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
+                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+            guess_mode (`bool`, defaults to `False`):
+                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+            return_dict (`bool`, defaults to `True`):
+                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+                returned where the first element is the sample tensor.
+        """
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order

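The `from_unet` docstring above describes copying a UNet's configuration (and optionally its weights) into a fresh ControlNet, and the new forward docstring spells out the residuals returned in `ControlNetOutput`. A hedged sketch of that flow; the checkpoint name and tensor shapes are only examples chosen to match a Stable Diffusion 1.x UNet:

import torch
from diffusers import ControlNetModel, UNet2DConditionModel

# Example base model; any UNet2DConditionModel checkpoint should work.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Initialize a ControlNet from the UNet's config and (by default) its weights.
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)

sample = torch.randn(1, 4, 64, 64)                 # noisy latent
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)    # text-encoder states for this config
controlnet_cond = torch.randn(1, 3, 512, 512)      # conditioning image at pixel resolution (8x the latent size)

with torch.no_grad():
    out = controlnet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_cond=controlnet_cond,
        conditioning_scale=1.0,
        return_dict=True,
    )
down_residuals, mid_residual = out.down_block_res_samples, out.mid_block_res_sample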