diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/models/unet_2d_condition_flax.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import flax
 import flax.linen as nn
@@ -35,9 +35,11 @@ from .unet_2d_blocks_flax import (
 @flax.struct.dataclass
 class FlaxUNet2DConditionOutput(BaseOutput):
     """
+    The output of [`FlaxUNet2DConditionModel`].
+
     Args:
         sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
-            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
     sample: jnp.ndarray
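
For context, the output dataclass above is what a Flax forward pass returns. A minimal access sketch (`unet`, `params`, `noisy_latents`, `t`, and `text_embeddings` are illustrative placeholders, not names from this diff):

```python
# Hedged sketch: named access to the model output via the dataclass above.
out = unet.apply({"params": params}, noisy_latents, t, text_embeddings)
noise_pred = out.sample  # jnp.ndarray with the documented (batch, channels, height, width) shape
```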
@@ -46,17 +48,17 @@ class FlaxUNet2DConditionOutput(BaseOutput):
 @flax_register_to_config
 class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     r"""
-    FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
-    timestep and returns sample shaped output.
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
 
-    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the models (such as downloading or saving, etc.)
+    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
+    implemented for all models (such as downloading or saving).
 
-    Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
-    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
+    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
     general usage and behavior.
 
-    Finally, this model supports inherent JAX features such as:
+    Inherent JAX features such as the following are supported:
     - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
     - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
     - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
@@ -69,18 +71,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4):
             The number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
-            The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
-            "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
-            The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D",
-            "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
+            The tuple of downsample blocks to use.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
+            The tuple of upsample blocks to use.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
             The number of layers per block.
         attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
             The dimension of the attention heads.
+        num_attention_heads (`int` or `Tuple[int]`, *optional*):
+            The number of attention heads.
         cross_attention_dim (`int`, *optional*, defaults to 768):
             The dimension of the cross attention features.
         dropout (`float`, *optional*, defaults to 0):
@@ -89,8 +91,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-            enable memory efficient attention https://arxiv.org/abs/2112.05682
-
+            Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
     """
 
     sample_size: int = 32
@@ -107,6 +108,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
     attention_head_dim: Union[int, Tuple[int]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
     dropout: float = 0.0
     use_linear_projection: bool = False
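
Taken together, the fields above fully determine module construction. A hedged sketch that simply restates the documented defaults, with the memory-efficient attention option switched on (keyword values are illustrative, not part of this diff):

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Construction sketch: kwargs mirror the config fields documented above.
unet = FlaxUNet2DConditionModel(
    sample_size=32,
    block_out_channels=(320, 640, 1280, 1280),
    layers_per_block=2,
    attention_head_dim=8,                 # still the supported head-count knob in this release
    cross_attention_dim=1280,
    use_memory_efficient_attention=True,  # the option documented above
    dtype=jnp.bfloat16,
)
```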
@@ -131,6 +133,19 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
 
+        if self.num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = self.num_attention_heads or self.attention_head_dim
+
         # input
         self.conv_in = nn.Conv(
             block_out_channels[0],
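
The net effect of the fallback above is easiest to see in isolation: a legacy config that only stores `attention_head_dim` has that value (despite its name, a head count) broadcast to one entry per down block. A minimal sketch with illustrative values:

```python
# Sketch of the fallback-and-broadcast behavior above (values are illustrative).
attention_head_dim = 8      # legacy config key; despite the name it holds a head count
num_attention_heads = None  # the new key is rejected until the rename lands in v0.19

num_attention_heads = num_attention_heads or attention_head_dim
down_block_types = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
if isinstance(num_attention_heads, int):
    num_attention_heads = (num_attention_heads,) * len(down_block_types)

print(num_attention_heads)  # (8, 8, 8, 8) -> one head count per down block
```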
@@ -150,9 +165,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         if isinstance(only_cross_attention, bool):
             only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
 
-        attention_head_dim = self.attention_head_dim
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
 
         # down
         down_blocks = []
@@ -168,7 +182,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                     out_channels=output_channel,
                     dropout=self.dropout,
                     num_layers=self.layers_per_block,
-                    attn_num_head_channels=attention_head_dim[i],
+                    num_attention_heads=num_attention_heads[i],
                     add_downsample=not is_final_block,
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
@@ -192,7 +206,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         self.mid_block = FlaxUNetMidBlock2DCrossAttn(
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
-            attn_num_head_channels=attention_head_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
             use_linear_projection=self.use_linear_projection,
             use_memory_efficient_attention=self.use_memory_efficient_attention,
             dtype=self.dtype,
@@ -201,7 +215,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         # up
         up_blocks = []
         reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
         only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(self.up_block_types):
@@ -217,7 +231,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                     out_channels=output_channel,
                     prev_output_channel=prev_output_channel,
                     num_layers=self.layers_per_block + 1,
-                    attn_num_head_channels=reversed_attention_head_dim[i],
+                    num_attention_heads=reversed_num_attention_heads[i],
                     add_upsample=not is_final_block,
                     dropout=self.dropout,
                     use_linear_projection=self.use_linear_projection,
diffusers/models/unet_3d_blocks.py

@@ -29,7 +29,7 @@ def get_down_block(
     add_downsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    num_attention_heads,
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
@@ -66,7 +66,7 @@
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -86,7 +86,7 @@ def get_up_block(
     add_upsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    num_attention_heads,
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
@@ -122,7 +122,7 @@
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
@@ -144,7 +144,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
@@ -154,7 +154,7 @@
         super().__init__()
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
         # there is always at least one resnet
@@ -185,8 +185,8 @@
         for _ in range(num_layers):
             attentions.append(
                 Transformer2DModel(
-                    in_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=in_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -197,8 +197,8 @@
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    in_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=in_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -250,10 +250,11 @@
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            ).sample
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)
 
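
The `.sample` → `return_dict=False` / `[0]` rewrite in the hunk above recurs in every block that follows. Both call styles yield the same tensor; the tuple form just skips constructing the output dataclass on each invocation. A hedged equivalence sketch, where `attn`, `hidden_states`, `ehs`, and `kw` are placeholders for the objects in scope above:

```python
# Old style: build the output dataclass, then read its `.sample` field.
out_old = attn(hidden_states, encoder_hidden_states=ehs, cross_attention_kwargs=kw).sample

# New style: ask for a plain tuple and index it; no dataclass allocation per call.
out_new = attn(hidden_states, encoder_hidden_states=ehs, cross_attention_kwargs=kw, return_dict=False)[0]
# out_old and out_new hold the same tensor.
```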
@@ -273,7 +274,7 @@ class CrossAttnDownBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         downsample_padding=1,
@@ -290,7 +291,7 @@
         temp_convs = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -317,8 +318,8 @@
             )
             attentions.append(
                 Transformer2DModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -330,8 +331,8 @@
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -377,10 +378,11 @@
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            ).sample
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
 
             output_states += (hidden_states,)
 
@@ -486,7 +488,7 @@ class CrossAttnUpBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -502,7 +504,7 @@
         temp_attentions = []
 
         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -531,8 +533,8 @@
             )
             attentions.append(
                 Transformer2DModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -544,8 +546,8 @@
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    out_channels // attn_num_head_channels,
-                    attn_num_head_channels,
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -590,10 +592,11 @@
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            ).sample
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
 
             if self.upsamplers is not None:
                 for upsampler in self.upsamplers:
diffusers/models/unet_3d_condition.py

@@ -43,9 +43,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 @dataclass
 class UNet3DConditionOutput(BaseOutput):
     """
+    The output of [`UNet3DConditionModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
-            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
     sample: torch.FloatTensor
@@ -53,11 +55,11 @@ class UNet3DConditionOutput(BaseOutput):
 
 class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-    UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
-    and returns sample shaped output.
+    A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the models (such as downloading or saving, etc.)
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
 
     Parameters:
         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
@@ -66,7 +68,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
             The tuple of downsample blocks to use.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
@@ -75,10 +77,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
-            If `None`, it will skip the normalization and activation layers in post-processing
+            If `None`, normalization and activation layers are skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*): The number of attention heads.
     """
 
     _supports_gradient_checkpointing = False
@@ -105,11 +108,25 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1024,
         attention_head_dim: Union[int, Tuple[int]] = 64,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
     ):
         super().__init__()
 
         self.sample_size = sample_size
 
+        if num_attention_heads is not None:
+            raise NotImplementedError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -121,9 +138,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
             raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )
 
         # input
@@ -156,8 +173,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
 
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
 
         # down
         output_channel = block_out_channels[0]
@@ -177,7 +194,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim[i],
+                num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=False,
             )
@@ -191,7 +208,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             resnet_act_fn=act_fn,
             output_scale_factor=mid_block_scale_factor,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim[-1],
+            num_attention_heads=num_attention_heads[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=False,
         )
@@ -201,7 +218,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
-        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
 
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -230,7 +247,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=reversed_attention_head_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=False,
             )
             self.up_blocks.append(up_block)
@@ -281,13 +298,13 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         r"""
         Enable sliced attention computation.
 
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
 
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
@@ -345,11 +362,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
+        Sets the attention processor to use to compute attention.
+
         Parameters:
-            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                of **all** `Attention` layers.
-            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
 
         """
         count = len(self.attn_processors.keys())
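
A hedged sketch of both accepted forms, using `AttnProcessor2_0` from `diffusers.models.attention_processor` as the example processor:

```python
from diffusers.models.attention_processor import AttnProcessor2_0

# Single instance: set as the processor for all Attention layers at once.
unet.set_attn_processor(AttnProcessor2_0())

# Dict form: one processor per attention-module path, as recommended for trainable processors.
unet.set_attn_processor({name: AttnProcessor2_0() for name in unet.attn_processors})
```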
@@ -373,6 +394,46 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def enable_forward_chunking(self, chunk_size=None, dim=0):
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
@@ -398,21 +459,24 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         return_dict: bool = True,
     ) -> Union[UNet3DConditionOutput, Tuple]:
         r"""
+        The [`UNet3DConditionModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple.
+                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+                tuple.
             cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
 
         Returns:
-            [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`:
-                [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
-                returning a tuple, the first element is the sample tensor.
+            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be at least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -467,8 +531,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         sample = self.conv_in(sample)
 
         sample = self.transformer_in(
-            sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-        ).sample
+            sample,
+            num_frames=num_frames,
+            cross_attention_kwargs=cross_attention_kwargs,
+            return_dict=False,
+        )[0]
 
         # 3. down
         down_block_res_samples = (sample,)