diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
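A trivial sketch, assuming one of the two wheels is installed: confirm which version is active before reading the per-file diffs below.

```python
import diffusers

print(diffusers.__version__)  # "0.17.1" before the upgrade, "0.18.2" after
```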
diffusers/models/unet_2d_condition_flax.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import flax
 import flax.linen as nn
@@ -35,9 +35,11 @@ from .unet_2d_blocks_flax import (
 @flax.struct.dataclass
 class FlaxUNet2DConditionOutput(BaseOutput):
     """
+    The output of [`FlaxUNet2DConditionModel`].
+
     Args:
         sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
-
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
     sample: jnp.ndarray
@@ -46,17 +48,17 @@ class FlaxUNet2DConditionOutput(BaseOutput):
 @flax_register_to_config
 class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     r"""
-
-
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
 
-    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for
-
+    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
+    implemented for all models (such as downloading or saving).
 
-
-    subclass. Use it as a regular Flax
+    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
     general usage and behavior.
 
-
+    Inherent JAX features such as the following are supported:
     - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
     - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
     - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
@@ -69,18 +71,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4):
             The number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to `("
-            The tuple of downsample blocks to use.
-
-
-            The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D",
-            "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
+            The tuple of downsample blocks to use.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
+            The tuple of upsample blocks to use.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
         attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
+        num_attention_heads (`int` or `Tuple[int]`, *optional*):
+            The number of attention heads.
         cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
         dropout (`float`, *optional*, defaults to 0):
@@ -89,8 +91,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
            Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-
-
+            Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
     """
 
     sample_size: int = 32
@@ -107,6 +108,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
     attention_head_dim: Union[int, Tuple[int]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     use_linear_projection: bool = False
@@ -131,6 +133,19 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
 
+        if self.num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = self.num_attention_heads or self.attention_head_dim
+
         # input
         self.conv_in = nn.Conv(
             block_out_channels[0],
@@ -150,9 +165,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         if isinstance(only_cross_attention, bool):
             only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
 
-
-
-            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
 
         # down
         down_blocks = []
@@ -168,7 +182,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                 out_channels=output_channel,
                 dropout=self.dropout,
                 num_layers=self.layers_per_block,
-
+                num_attention_heads=num_attention_heads[i],
                 add_downsample=not is_final_block,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -192,7 +206,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         self.mid_block = FlaxUNetMidBlock2DCrossAttn(
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
-
+            num_attention_heads=num_attention_heads[-1],
             use_linear_projection=self.use_linear_projection,
             use_memory_efficient_attention=self.use_memory_efficient_attention,
             dtype=self.dtype,
@@ -201,7 +215,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         # up
         up_blocks = []
         reversed_block_out_channels = list(reversed(block_out_channels))
-
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
         only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(self.up_block_types):
@@ -217,7 +231,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 num_layers=self.layers_per_block + 1,
-
+                num_attention_heads=reversed_num_attention_heads[i],
                 add_upsample=not is_final_block,
                 dropout=self.dropout,
                 use_linear_projection=self.use_linear_projection,
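A minimal sketch in plain Python (the helper name `resolve_num_attention_heads` is hypothetical) of the naming correction the hunks above add to `FlaxUNet2DConditionModel.setup()`: passing `num_attention_heads` is rejected for now, the long-standing `attention_head_dim` config value is reused under the corrected name, and the result is broadcast to one entry per down block.

```python
from typing import Optional, Tuple, Union


def resolve_num_attention_heads(
    attention_head_dim: Union[int, Tuple[int, ...]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
    num_down_blocks: int = 4,
) -> Tuple[int, ...]:
    if num_attention_heads is not None:
        # Mirrors the ValueError in the diff: the argument is reserved until diffusers v0.19.
        raise ValueError("Passing `num_attention_heads` is not yet supported; use `attention_head_dim`.")
    heads = num_attention_heads or attention_head_dim
    if isinstance(heads, int):
        heads = (heads,) * num_down_blocks
    return heads


print(resolve_num_attention_heads(attention_head_dim=8))                # (8, 8, 8, 8)
print(resolve_num_attention_heads(attention_head_dim=(5, 10, 20, 20)))  # (5, 10, 20, 20)
```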
diffusers/models/unet_3d_blocks.py

@@ -29,7 +29,7 @@ def get_down_block(
     add_downsample,
     resnet_eps,
     resnet_act_fn,
-
+    num_attention_heads,
     resnet_groups=None,
     cross_attention_dim=None,
     downsample_padding=None,
@@ -66,7 +66,7 @@ def get_down_block(
         resnet_groups=resnet_groups,
         downsample_padding=downsample_padding,
         cross_attention_dim=cross_attention_dim,
-
+        num_attention_heads=num_attention_heads,
         dual_cross_attention=dual_cross_attention,
         use_linear_projection=use_linear_projection,
         only_cross_attention=only_cross_attention,
@@ -86,7 +86,7 @@ def get_up_block(
     add_upsample,
     resnet_eps,
     resnet_act_fn,
-
+    num_attention_heads,
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
@@ -122,7 +122,7 @@ def get_up_block(
         resnet_act_fn=resnet_act_fn,
         resnet_groups=resnet_groups,
         cross_attention_dim=cross_attention_dim,
-
+        num_attention_heads=num_attention_heads,
         dual_cross_attention=dual_cross_attention,
         use_linear_projection=use_linear_projection,
         only_cross_attention=only_cross_attention,
@@ -144,7 +144,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-
+        num_attention_heads=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
@@ -154,7 +154,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         super().__init__()
 
         self.has_cross_attention = True
-        self.
+        self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
         # there is always at least one resnet
@@ -185,8 +185,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         for _ in range(num_layers):
             attentions.append(
                 Transformer2DModel(
-                    in_channels //
-
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=in_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -197,8 +197,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    in_channels //
-
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=in_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -250,10 +250,11 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            )
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)
 
@@ -273,7 +274,7 @@ class CrossAttnDownBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         downsample_padding=1,
@@ -290,7 +291,7 @@ class CrossAttnDownBlock3D(nn.Module):
         temp_convs = []
 
         self.has_cross_attention = True
-        self.
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
@@ -317,8 +318,8 @@ class CrossAttnDownBlock3D(nn.Module):
             )
             attentions.append(
                 Transformer2DModel(
-                    out_channels //
-
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -330,8 +331,8 @@ class CrossAttnDownBlock3D(nn.Module):
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    out_channels //
-
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -377,10 +378,11 @@ class CrossAttnDownBlock3D(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            )
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
 
             output_states += (hidden_states,)
 
@@ -486,7 +488,7 @@ class CrossAttnUpBlock3D(nn.Module):
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-
+        num_attention_heads=1,
         cross_attention_dim=1280,
         output_scale_factor=1.0,
         add_upsample=True,
@@ -502,7 +504,7 @@ class CrossAttnUpBlock3D(nn.Module):
         temp_attentions = []
 
         self.has_cross_attention = True
-        self.
+        self.num_attention_heads = num_attention_heads
 
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -531,8 +533,8 @@ class CrossAttnUpBlock3D(nn.Module):
             )
             attentions.append(
                 Transformer2DModel(
-                    out_channels //
-
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -544,8 +546,8 @@ class CrossAttnUpBlock3D(nn.Module):
             )
             temp_attentions.append(
                 TransformerTemporalModel(
-                    out_channels //
-
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
                     in_channels=out_channels,
                     num_layers=1,
                     cross_attention_dim=cross_attention_dim,
@@ -590,10 +592,11 @@ class CrossAttnUpBlock3D(nn.Module):
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            )
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
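The repeated `return_dict=False` / `[0]` edits above reflect that the transformer sub-models return either an output dataclass or a plain tuple; the 3D blocks now request the tuple and take element 0 (the sample tensor). A small hedged demonstration with a tiny standalone `Transformer2DModel` (assumes torch and diffusers >= 0.18 are installed; the sizes are arbitrary):

```python
import torch
from diffusers.models import Transformer2DModel

# Tiny spatial transformer, just to show the two return forms.
model = Transformer2DModel(num_attention_heads=2, attention_head_dim=16, in_channels=32, num_layers=1)
model.eval()

x = torch.randn(1, 32, 8, 8)
with torch.no_grad():
    out_obj = model(x)                      # output dataclass with a `.sample` field
    out_tup = model(x, return_dict=False)   # plain tuple; element 0 is the sample tensor
print(out_obj.sample.shape, out_tup[0].shape)  # both torch.Size([1, 32, 8, 8])
```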
diffusers/models/unet_3d_condition.py

@@ -43,9 +43,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 @dataclass
 class UNet3DConditionOutput(BaseOutput):
     """
+    The output of [`UNet3DConditionModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
-
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
     sample: torch.FloatTensor
@@ -53,11 +55,11 @@ class UNet3DConditionOutput(BaseOutput):
 
 class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-
-
+    A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
 
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for
-
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
 
     Parameters:
         sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
@@ -66,7 +68,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
             The tuple of downsample blocks to use.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
@@ -75,10 +77,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
-            If `None`,
+            If `None`, normalization and activation layers is skipped in post-processing.
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*): The number of attention heads.
     """
 
     _supports_gradient_checkpointing = False
@@ -105,11 +108,25 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1024,
         attention_head_dim: Union[int, Tuple[int]] = 64,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
     ):
         super().__init__()
 
         self.sample_size = sample_size
 
+        if num_attention_heads is not None:
+            raise NotImplementedError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
         # Check inputs
         if len(down_block_types) != len(up_block_types):
             raise ValueError(
@@ -121,9 +138,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
             raise ValueError(
-                f"Must provide the same number of `
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )
 
         # input
@@ -156,8 +173,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
 
-        if isinstance(
-
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
 
         # down
         output_channel = block_out_channels[0]
@@ -177,7 +194,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-
+                num_attention_heads=num_attention_heads[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=False,
             )
@@ -191,7 +208,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             resnet_act_fn=act_fn,
             output_scale_factor=mid_block_scale_factor,
             cross_attention_dim=cross_attention_dim,
-
+            num_attention_heads=num_attention_heads[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=False,
         )
@@ -201,7 +218,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
-
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
 
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
@@ -230,7 +247,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-
+                num_attention_heads=reversed_num_attention_heads[i],
                 dual_cross_attention=False,
             )
             self.up_blocks.append(up_block)
@@ -281,13 +298,13 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         r"""
         Enable sliced attention computation.
 
-        When this option is enabled, the attention module
-
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
 
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
-                When `"auto"`,
-                `"max"`, maximum amount of memory
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
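The `def` line for the method documented above is not captured in this diff; in diffusers 0.18 the corresponding model-level method is `set_attention_slice`. A hedged usage sketch, assuming torch and diffusers >= 0.18 are installed and using the public damo-vilab text-to-video checkpoint (its UNet is a `UNet3DConditionModel`):

```python
import torch
from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", subfolder="unet", torch_dtype=torch.float16
)
# "auto" halves the input to the attention heads, "max" runs one slice at a time,
# and an int selects the number of slices, per the docstring above.
unet.set_attention_slice("auto")
```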
@@ -345,11 +362,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
         r"""
+        Sets the attention processor to use to compute attention.
+
         Parameters:
-
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-
-
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
 
         """
         count = len(self.attn_processors.keys())
@@ -373,6 +394,46 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def enable_forward_chunking(self, chunk_size=None, dim=0):
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
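A hedged usage sketch for the new forward-chunking switch added above. It assumes a CUDA GPU, torch and diffusers >= 0.18, and the public damo-vilab text-to-video checkpoint; any pipeline whose `unet` is a `UNet3DConditionModel` exposes the same two calls.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Chunk the feed-forward layers over the sequence dimension to lower peak memory.
pipe.unet.enable_forward_chunking(chunk_size=1, dim=1)
video_frames = pipe("a panda playing guitar", num_frames=16).frames
pipe.unet.disable_forward_chunking()
```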
@@ -398,21 +459,24 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         return_dict: bool = True,
     ) -> Union[UNet3DConditionOutput, Tuple]:
         r"""
+        The [`UNet3DConditionModel`] forward method.
+
         Args:
-            sample (`torch.FloatTensor`):
-
-
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [
+                Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+                tuple.
             cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `
-                `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
 
         Returns:
-            [`~models.
-
-
+            [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
@@ -467,8 +531,11 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         sample = self.conv_in(sample)
 
         sample = self.transformer_in(
-            sample,
-
+            sample,
+            num_frames=num_frames,
+            cross_attention_kwargs=cross_attention_kwargs,
+            return_dict=False,
+        )[0]
 
         # 3. down
         down_block_res_samples = (sample,)