diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0

diffusers/models/attention.py
@@ -119,6 +119,15 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -141,6 +150,7 @@ class BasicTransformerBlock(nn.Module):
             norm_hidden_states = self.norm1(hidden_states)
 
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
         attn_output = self.attn1(
             norm_hidden_states,
             encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
@@ -171,7 +181,20 @@ class BasicTransformerBlock(nn.Module):
         if self.use_ada_layer_norm_zero:
             norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
 
-        ff_output = self.ff(norm_hidden_states)
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
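The chunked feed-forward added above splits `norm_hidden_states` along `self._chunk_dim`, runs the feed-forward on each slice, and concatenates the results, trading a little scheduling overhead for lower peak activation memory. A minimal sketch of driving it directly on a `BasicTransformerBlock` (the constructor arguments and shapes are illustrative only; the public entry point referenced in the error message is `unet.enable_forward_chunking`):

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

# Hypothetical dimensions, chosen only for this example.
block = BasicTransformerBlock(dim=320, num_attention_heads=8, attention_head_dim=40)

hidden_states = torch.randn(2, 4096, 320)  # (batch, sequence, dim)

# Chunk the feed-forward over the sequence dimension (dim=1).
# 4096 is divisible by 1024, so the divisibility check passes.
block.set_chunk_feed_forward(chunk_size=1024, dim=1)
out = block(hidden_states)

# Revert to the unchunked path.
block.set_chunk_feed_forward(chunk_size=None, dim=0)
```

Because every slice passes through the same `self.ff`, the concatenated result matches the unchunked output in eval mode (up to floating-point ordering); only the peak memory of the feed-forward changes.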

diffusers/models/attention_flax.py
@@ -152,6 +152,7 @@ class FlaxAttention(nn.Module):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
 
         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ class FlaxAttention(nn.Module):
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ class FlaxBasicTransformerBlock(nn.Module):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
@@ -280,7 +282,7 @@ class FlaxBasicTransformerBlock(nn.Module):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ class FlaxTransformer2DModel(nn.Module):
             dtype=self.dtype,
         )
 
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -378,7 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ class FlaxFeedForward(nn.Module):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states
 
@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
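All of the Flax additions above follow one pattern: an `nn.Dropout(rate=...)` submodule is created in `setup()` and applied with the `deterministic` flag threaded through `__call__`. A small standalone Flax sketch of that pattern (not diffusers code; names are illustrative), showing that the stochastic path needs a `"dropout"` PRNG key at `apply` time:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
    features: int = 32
    dropout: float = 0.1

    def setup(self):
        self.dense = nn.Dense(self.features)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, x, deterministic=True):
        x = self.dense(x)
        # Dropout is a no-op when deterministic=True (inference);
        # when False (training) it draws from an rng named "dropout".
        return self.dropout_layer(x, deterministic=deterministic)


x = jnp.ones((2, 8))
module = TinyBlock()
params = module.init(jax.random.PRNGKey(0), x)

# Inference: deterministic, no dropout rng required.
y_eval = module.apply(params, x, deterministic=True)

# Training: stochastic dropout, pass a "dropout" rng.
y_train = module.apply(
    params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)}
)
```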

diffusers/models/attention_processor.py
@@ -78,6 +78,7 @@ class Attention(nn.Module):
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
+        self.dropout = dropout
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -1117,7 +1118,9 @@ class AttnProcessor2_0:
         value = attn.to_v(encoder_hidden_states)
 
         head_dim = inner_dim // attn.heads
+
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
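The `AttnProcessor2_0` hunk only adds whitespace, but the surrounding lines show the `(batch, heads, seq, head_dim)` layout that PyTorch 2.0's fused attention expects. A standalone sketch of that reshape (shapes are illustrative and not taken from the diff):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, inner_dim, heads = 2, 77, 320, 8
head_dim = inner_dim // heads  # 40

query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# Split the channel dimension into heads and move heads next to batch:
# (batch, seq, inner) -> (batch, heads, seq, head_dim)
q = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
k = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
v = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# PyTorch 2.0 fused attention consumes exactly this layout.
out = F.scaled_dot_product_attention(q, k, v)

# Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, inner)
out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
print(out.shape)  # torch.Size([2, 77, 320])
```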

diffusers/models/autoencoder_kl.py
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, apply_forward_hook
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
@@ -38,24 +39,24 @@ class AutoencoderKLOutput(BaseOutput):
 
 
 class AutoencoderKL(ModelMixin, ConfigMixin):
-    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
-    and Max Welling.
+    r"""
+    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
 
     Parameters:
         in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
-        up_block_types (`Tuple[str]`, *optional*, defaults to :
-            obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to :
-            obj:`(64,)`): Tuple of block output channels.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
-        sample_size (`int`, *optional*, defaults to `32`): TODO
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
         scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
@@ -130,15 +131,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     def enable_tiling(self, use_tiling: bool = True):
         r"""
         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
-        the processing of larger images.
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
         """
         self.use_tiling = use_tiling
 
     def disable_tiling(self):
         r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
-        computing decoding in one step.
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
         """
         self.enable_tiling(False)
 
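The tiling switches above are plain toggles on the loaded model. A hedged usage sketch (the checkpoint id and image size are placeholders, not taken from this diff):

```python
import torch
from diffusers import AutoencoderKL

# Placeholder checkpoint; any AutoencoderKL checkpoint works the same way.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

vae.enable_tiling()   # encode/decode large inputs tile by tile
with torch.no_grad():
    big_image = torch.randn(1, 3, 1024, 1024)
    latents = vae.encode(big_image).latent_dist.sample()

vae.disable_tiling()  # back to single-pass decoding
```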
@@ -151,17 +152,89 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
 
     def disable_slicing(self):
         r"""
-        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
         decoding in one step.
         """
         self.use_slicing = False
 
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
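With these copied-in helpers, the VAE exposes the same attention-processor plumbing as the UNet. A hedged sketch of inspecting and resetting processors (the class names are the ones imported in this diff; the default-config VAE is used only for illustration):

```python
from diffusers import AutoencoderKL
from diffusers.models.attention_processor import AttnProcessor

vae = AutoencoderKL()  # default config, illustration only

# Dict keyed by module path for every attention layer that exposes a
# processor (may be empty depending on the VAE's attention implementation).
print(list(vae.attn_processors.keys()))

# Set one processor instance for all attention layers ...
vae.set_attn_processor(AttnProcessor())

# ... or fall back to the library default explicitly.
vae.set_default_attn_processor()
```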
     @apply_forward_hook
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
             return self.tiled_encode(x, return_dict=return_dict)
 
-        h = self.encoder(x)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
 
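The new branch in `encode` mirrors what sliced decoding already did: when slicing is enabled and the batch holds more than one image, each sample is encoded separately and the results are concatenated, so peak activation memory stays at batch size 1. A hedged sketch:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()   # default config, illustration only
vae.enable_slicing()    # enable batch-splitting for encode/decode

images = torch.randn(4, 3, 256, 256)
with torch.no_grad():
    # Internally encodes the 4 images one by one and concatenates the moments.
    posterior = vae.encode(images).latent_dist
    latents = posterior.sample()

vae.disable_slicing()
```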
@@ -210,14 +283,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
-        Args:
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
-        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is:
-        different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
         tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
-        look of the output, but they should be much less noticeable.
-        x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`):
-            Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+                `tuple` is returned.
         """
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
@@ -255,17 +335,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         return AutoencoderKLOutput(latent_dist=posterior)
 
     def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
-        r"""Decode a batch of images using a tiled decoder.
+        r"""
+        Decode a batch of images using a tiled decoder.
 
         Args:
-        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
-        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is:
-        different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the
-        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
-        look of the output, but they should be much less noticeable.
-        z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
-        `True`):
-            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
         """
         overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
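The overlap arithmetic at the top of both tiled methods fixes how far the sliding window advances and how much of each tile is blended with its neighbour. A small worked sketch with assumed values (the numbers below are illustrative assumptions, not read from this diff):

```python
# Assumed example values, mirroring the expressions in tiled_encode.
tile_sample_min_size = 512   # tile size in pixel space (assumption)
tile_latent_min_size = 64    # tile size in latent space (assumption)
tile_overlap_factor = 0.25   # fraction of each tile shared with its neighbour

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384: stride between pixel tiles
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16: latent rows/cols blended

print(overlap_size, blend_extent)  # 384 16
```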

diffusers/models/controlnet.py
@@ -37,6 +37,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 @dataclass
 class ControlNetOutput(BaseOutput):
+    """
+    The output of [`ControlNetModel`].
+
+    Args:
+        down_block_res_samples (`tuple[torch.Tensor]`):
+            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+            used to condition the original UNet's downsampling activations.
+        mid_down_block_re_sample (`torch.Tensor`):
+            The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+            Output can be used to condition the original UNet's middle block activation.
+    """
+
     down_block_res_samples: Tuple[torch.Tensor]
     mid_block_res_sample: torch.Tensor
 
@@ -87,12 +101,65 @@ class ControlNetConditioningEmbedding(nn.Module):
 
 
 class ControlNetModel(ModelMixin, ConfigMixin):
+    """
+    A ControlNet model.
+
+    Args:
+        in_channels (`int`, defaults to 4):
+            The number of channels in the input sample.
+        flip_sin_to_cos (`bool`, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, defaults to 0):
+            The frequency shift to apply to the time embedding.
+        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, defaults to 2):
+            The number of layers per block.
+        downsample_padding (`int`, defaults to 1):
+            The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, defaults to 1):
+            The scale factor to use for the mid block.
+        act_fn (`str`, defaults to "silu"):
+            The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+            in post-processing.
+        norm_eps (`float`, defaults to 1e-5):
+            The epsilon to use for the normalization.
+        cross_attention_dim (`int`, defaults to 1280):
+            The dimension of the cross attention features.
+        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+            The dimension of the attention heads.
+        use_linear_projection (`bool`, defaults to `False`):
+        class_embed_type (`str`, *optional*, defaults to `None`):
+            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        num_class_embeds (`int`, *optional*, defaults to 0):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+            class conditioning with `class_embed_type` equal to `None`.
+        upcast_attention (`bool`, defaults to `False`):
+        resnet_time_scale_shift (`str`, defaults to `"default"`):
+            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+            `class_embed_type="projection"`.
+        controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+            The tuple of output channel for each block in the `conditioning_embedding` layer.
+        global_pool_conditions (`bool`, defaults to `False`):
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
     def __init__(
         self,
         in_channels: int = 4,
+        conditioning_channels: int = 3,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         down_block_types: Tuple[str] = (
@@ -111,6 +178,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         num_class_embeds: Optional[int] = None,
@@ -123,6 +191,14 @@ class ControlNetModel(ModelMixin, ConfigMixin):
     ):
         super().__init__()
 
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
         # Check inputs
         if len(block_out_channels) != len(down_block_types):
             raise ValueError(
@@ -134,9 +210,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
             raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )
 
         # input
@@ -185,6 +261,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
             conditioning_embedding_channels=block_out_channels[0],
             block_out_channels=conditioning_embedding_out_channels,
+            conditioning_channels=conditioning_channels,
         )
 
         self.down_blocks = nn.ModuleList([])
@@ -196,6 +273,9 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
 
@@ -219,7 +299,8 @@ class ControlNetModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -253,7 +334,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift=resnet_time_scale_shift,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
             resnet_groups=norm_num_groups,
             use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
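The renaming above only changes what the value is called internally: configs that specify just `attention_head_dim` (which historically held the head count) keep working because `num_attention_heads` falls back to it before being broadcast per block. A tiny sketch of only that fallback and broadcast, outside any model class:

```python
from typing import Optional, Tuple, Union


def resolve_heads(
    attention_head_dim: Union[int, Tuple[int, ...]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
    num_blocks: int = 4,
):
    # Same fallback as in ControlNetModel.__init__: old configs only carry
    # `attention_head_dim`, so it is reused as the head count.
    num_attention_heads = num_attention_heads or attention_head_dim
    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * num_blocks
    return num_attention_heads


print(resolve_heads())                                      # (8, 8, 8, 8)
print(resolve_heads(num_attention_heads=(5, 10, 20, 20)))   # (5, 10, 20, 20)
```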
@@ -268,12 +349,12 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         load_weights_from_unet: bool = True,
     ):
         r"""
-        Instantiate Controlnet class from UNet2DConditionModel.
+        Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
 
         Parameters:
             unet (`UNet2DConditionModel`):
-                UNet model which weights are copied to the ControlNet. Note that all configuration options are also
-                copied where applicable.
+                The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+                where applicable.
         """
         controlnet = cls(
             in_channels=unet.config.in_channels,
@@ -290,6 +371,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
             norm_eps=unet.config.norm_eps,
             cross_attention_dim=unet.config.cross_attention_dim,
             attention_head_dim=unet.config.attention_head_dim,
+            num_attention_heads=unet.config.num_attention_heads,
             use_linear_projection=unet.config.use_linear_projection,
             class_embed_type=unet.config.class_embed_type,
             num_class_embeds=unet.config.num_class_embeds,
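Because the classmethod shown above (exposed in the released API as `ControlNetModel.from_unet`) now forwards `num_attention_heads` from the UNet config as well, a ControlNet initialised from a UNet inherits the corrected head naming automatically. A hedged usage sketch (the checkpoint id is a placeholder):

```python
from diffusers import ControlNetModel, UNet2DConditionModel

# Placeholder checkpoint id; any Stable Diffusion UNet works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Copies the UNet's weights and (where applicable) its configuration,
# including attention_head_dim / num_attention_heads.
controlnet = ControlNetModel.from_unet(unet, load_weights_from_unet=True)
```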
@@ -341,11 +423,15 @@ class ControlNetModel(ModelMixin, ConfigMixin):
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
+        Sets the attention processor to use to compute attention.
+
         Parameters:
-            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                of **all** `Attention` layers.
-            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
 
         """
         count = len(self.attn_processors.keys())
@@ -381,13 +467,13 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         r"""
         Enable sliced attention computation.
 
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
 
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
@@ -460,6 +546,37 @@ class ControlNetModel(ModelMixin, ConfigMixin):
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
+        """
+        The [`ControlNetModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor.
+            timestep (`Union[torch.Tensor, float, int]`):
+                The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder hidden states.
+            controlnet_cond (`torch.FloatTensor`):
+                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            conditioning_scale (`float`, defaults to `1.0`):
+                The scale factor for ControlNet outputs.
+            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+            cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
+                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+            guess_mode (`bool`, defaults to `False`):
+                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
+                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+            return_dict (`bool`, defaults to `True`):
+                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
+                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
+                returned where the first element is the sample tensor.
+        """
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
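Putting the documented forward signature together: the ControlNet consumes the same latents/timestep/text-embedding triple as the UNet plus a conditioning image, and returns residuals that are added to the UNet's down and mid block activations. A hedged, shape-only sketch (all sizes are illustrative Stable Diffusion 1.x-style defaults, not taken from this diff):

```python
import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel()                     # default (SD-sized) config, illustration only

sample = torch.randn(1, 4, 64, 64)                 # noisy latents
timestep = torch.tensor([10])                      # denoising step
encoder_hidden_states = torch.randn(1, 77, 1280)   # text embeddings (class default cross_attention_dim=1280)
controlnet_cond = torch.randn(1, 3, 512, 512)      # conditioning image in pixel space (8x the latent size)

with torch.no_grad():
    out = controlnet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_cond=controlnet_cond,
        conditioning_scale=1.0,
        guess_mode=False,
        return_dict=True,
    )

# Residuals are consumed by UNet2DConditionModel via
# down_block_additional_residuals / mid_block_additional_residual.
print(len(out.down_block_res_samples), out.mid_block_res_sample.shape)
```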