diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
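Alongside the per-file changes above, 0.18.x introduces several entirely new pipeline families (consistency models, Kandinsky 2.2, Shap-E, Stable Diffusion XL, LDM3D, paradigms, text-to-video img2img) plus new schedulers. A minimal sketch of trying one of the new pipelines through the generic DiffusionPipeline entry point follows; the checkpoint id and generation settings are illustrative assumptions, not part of this diff:

import torch
from diffusers import DiffusionPipeline

# Assumed checkpoint id for illustration; any repository shipping one of the new
# 0.18 pipeline classes (e.g. StableDiffusionXLPipeline, ShapEPipeline) loads the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# The concrete pipeline class is resolved from the checkpoint's model_index.json.
image = pipe(prompt="an astronaut riding a horse on the moon").images[0]
image.save("sdxl_sample.png")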
diffusers/models/prior_transformer.py
CHANGED
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -8,6 +8,7 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
 from .attention import BasicTransformerBlock
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 
@@ -15,6 +16,8 @@ from .modeling_utils import ModelMixin
 @dataclass
 class PriorTransformerOutput(BaseOutput):
     """
+    The output of [`PriorTransformer`].
+
     Args:
         predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
@@ -25,27 +28,39 @@ class PriorTransformerOutput(BaseOutput):
 
 class PriorTransformer(ModelMixin, ConfigMixin):
     """
-
-    transformer predicts the image embeddings through a denoising diffusion process.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the models (such as downloading or saving, etc.)
-
-    For more details, see the original paper: https://arxiv.org/abs/2204.06125
+    A Prior Transformer model.
 
     Parameters:
         num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
         attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
         num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
-        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the
-
-
-            length of the prompt after it has been tokenized.
+        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
+        num_embeddings (`int`, *optional*, defaults to 77):
+            The number of embeddings of the model input `hidden_states`
         additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
-            projected hidden_states
+            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
             additional_embeddings`.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-
+        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
+            The activation function to use to create timestep embeddings.
+        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
+            passing to Transformer blocks. Set it to `None` if normalization is not needed.
+        embedding_proj_norm_type (`str`, *optional*, defaults to None):
+            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
+            needed.
+        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
+            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
+            `encoder_hidden_states` is `None`.
+        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
+            Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
+            product between the text embedding and image embedding as proposed in the unclip paper
+            https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
+        time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
+            If None, will be set to `num_attention_heads * attention_head_dim`
+        embedding_proj_dim (`int`, *optional*, default to None):
+            The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
+        clip_embed_dim (`int`, *optional*, default to None):
+            The dimension of the output. If None, will be set to `embedding_dim`.
     """
 
     @register_to_config
@@ -58,6 +73,14 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         num_embeddings=77,
         additional_embeddings=4,
         dropout: float = 0.0,
+        time_embed_act_fn: str = "silu",
+        norm_in_type: Optional[str] = None,  # layer
+        embedding_proj_norm_type: Optional[str] = None,  # layer
+        encoder_hid_proj_type: Optional[str] = "linear",  # linear
+        added_emb_type: Optional[str] = "prd",  # prd
+        time_embed_dim: Optional[int] = None,
+        embedding_proj_dim: Optional[int] = None,
+        clip_embed_dim: Optional[int] = None,
     ):
         super().__init__()
         self.num_attention_heads = num_attention_heads
@@ -65,17 +88,41 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         inner_dim = num_attention_heads * attention_head_dim
         self.additional_embeddings = additional_embeddings
 
+        time_embed_dim = time_embed_dim or inner_dim
+        embedding_proj_dim = embedding_proj_dim or embedding_dim
+        clip_embed_dim = clip_embed_dim or embedding_dim
+
         self.time_proj = Timesteps(inner_dim, True, 0)
-        self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
+        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
 
         self.proj_in = nn.Linear(embedding_dim, inner_dim)
 
-
-
+        if embedding_proj_norm_type is None:
+            self.embedding_proj_norm = None
+        elif embedding_proj_norm_type == "layer":
+            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
+        else:
+            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
+
+        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
+
+        if encoder_hid_proj_type is None:
+            self.encoder_hidden_states_proj = None
+        elif encoder_hid_proj_type == "linear":
+            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
+        else:
+            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
 
         self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
 
-
+        if added_emb_type == "prd":
+            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
+        elif added_emb_type is None:
+            self.prd_embedding = None
+        else:
+            raise ValueError(
+                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
+            )
 
         self.transformer_blocks = nn.ModuleList(
             [
@@ -91,8 +138,16 @@ class PriorTransformer(ModelMixin, ConfigMixin):
             ]
         )
 
+        if norm_in_type == "layer":
+            self.norm_in = nn.LayerNorm(inner_dim)
+        elif norm_in_type is None:
+            self.norm_in = None
+        else:
+            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
+
         self.norm_out = nn.LayerNorm(inner_dim)
-
+
+        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
 
         causal_attention_mask = torch.full(
             [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
@@ -101,23 +156,92 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         causal_attention_mask = causal_attention_mask[None, ...]
         self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
 
-        self.clip_mean = nn.Parameter(torch.zeros(1,
-        self.clip_std = nn.Parameter(torch.zeros(1,
+        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
+        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
+
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
 
     def forward(
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
         proj_embedding: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
         return_dict: bool = True,
     ):
         """
+        The [`PriorTransformer`] forward method.
+
         Args:
             hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
-
-            timestep (`torch.
+                The currently predicted image embeddings.
+            timestep (`torch.LongTensor`):
                 Current denoising step.
             proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                 Projected embedding vector the denoising process is conditioned on.
@@ -126,13 +250,13 @@ class PriorTransformer(ModelMixin, ConfigMixin):
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [
+                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
                 tuple.
 
         Returns:
             [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
-
-
+                If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
+                tuple is returned where the first element is the sample tensor.
         """
         batch_size = hidden_states.shape[0]
 
@@ -152,23 +276,61 @@ class PriorTransformer(ModelMixin, ConfigMixin):
         timesteps_projected = timesteps_projected.to(dtype=self.dtype)
         time_embeddings = self.time_embedding(timesteps_projected)
 
+        if self.embedding_proj_norm is not None:
+            proj_embedding = self.embedding_proj_norm(proj_embedding)
+
         proj_embeddings = self.embedding_proj(proj_embedding)
-
+        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
+            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
+        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
+            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
+
         hidden_states = self.proj_in(hidden_states)
-
+
         positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
 
+        additional_embeds = []
+        additional_embeddings_len = 0
+
+        if encoder_hidden_states is not None:
+            additional_embeds.append(encoder_hidden_states)
+            additional_embeddings_len += encoder_hidden_states.shape[1]
+
+        if len(proj_embeddings.shape) == 2:
+            proj_embeddings = proj_embeddings[:, None, :]
+
+        if len(hidden_states.shape) == 2:
+            hidden_states = hidden_states[:, None, :]
+
+        additional_embeds = additional_embeds + [
+            proj_embeddings,
+            time_embeddings[:, None, :],
+            hidden_states,
+        ]
+
+        if self.prd_embedding is not None:
+            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
+            additional_embeds.append(prd_embedding)
+
         hidden_states = torch.cat(
-
-            encoder_hidden_states,
-            proj_embeddings[:, None, :],
-            time_embeddings[:, None, :],
-            hidden_states[:, None, :],
-            prd_embedding,
-            ],
+            additional_embeds,
             dim=1,
         )
 
+        # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
+        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
+        if positional_embeddings.shape[1] < hidden_states.shape[1]:
+            positional_embeddings = F.pad(
+                positional_embeddings,
+                (
+                    0,
+                    0,
+                    additional_embeddings_len,
+                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
+                ),
+                value=0.0,
+            )
+
         hidden_states = hidden_states + positional_embeddings
 
         if attention_mask is not None:
@@ -177,11 +339,19 @@ class PriorTransformer(ModelMixin, ConfigMixin):
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
             attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
 
+        if self.norm_in is not None:
+            hidden_states = self.norm_in(hidden_states)
+
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, attention_mask=attention_mask)
 
         hidden_states = self.norm_out(hidden_states)
-
+
+        if self.prd_embedding is not None:
+            hidden_states = hidden_states[:, -1]
+        else:
+            hidden_states = hidden_states[:, additional_embeddings_len:]
+
         predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
 
         if not return_dict:
diffusers/models/resnet.py
CHANGED
@@ -95,9 +95,9 @@ class Downsample1D(nn.Module):
         assert self.channels == self.out_channels
         self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
 
-    def forward(self,
-        assert
-        return self.conv(
+    def forward(self, inputs):
+        assert inputs.shape[1] == self.channels
+        return self.conv(inputs)
 
 
 class Upsample2D(nn.Module):
@@ -431,13 +431,13 @@ class KDownsample2D(nn.Module):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self,
-
-        weight =
-        indices = torch.arange(
-        kernel = self.kernel.to(weight)[None, :].expand(
+    def forward(self, inputs):
+        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv2d(
+        return F.conv2d(inputs, weight, stride=2)
 
 
 class KUpsample2D(nn.Module):
@@ -448,13 +448,13 @@ class KUpsample2D(nn.Module):
         self.pad = kernel_1d.shape[1] // 2 - 1
         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
 
-    def forward(self,
-
-        weight =
-        indices = torch.arange(
-        kernel = self.kernel.to(weight)[None, :].expand(
+    def forward(self, inputs):
+        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        indices = torch.arange(inputs.shape[1], device=inputs.device)
+        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
-        return F.conv_transpose2d(
+        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
 
 
 class ResnetBlock2D(nn.Module):
@@ -664,13 +664,13 @@ class Conv1dBlock(nn.Module):
         self.group_norm = nn.GroupNorm(n_groups, out_channels)
         self.mish = nn.Mish()
 
-    def forward(self,
-
-
-
-
-
-        return
+    def forward(self, inputs):
+        intermediate_repr = self.conv1d(inputs)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        intermediate_repr = self.group_norm(intermediate_repr)
+        intermediate_repr = rearrange_dims(intermediate_repr)
+        output = self.mish(intermediate_repr)
+        return output
 
 
 # unet_rl.py
@@ -687,10 +687,10 @@ class ResidualTemporalBlock1D(nn.Module):
             nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
         )
 
-    def forward(self,
+    def forward(self, inputs, t):
         """
         Args:
-
+            inputs : [ batch_size x inp_channels x horizon ]
             t : [ batch_size x embed_dim ]
 
         returns:
@@ -698,9 +698,9 @@ class ResidualTemporalBlock1D(nn.Module):
         """
         t = self.time_emb_act(t)
         t = self.time_emb(t)
-        out = self.conv_in(
+        out = self.conv_in(inputs) + rearrange_dims(t)
         out = self.conv_out(out)
-        return out + self.residual_conv(
+        return out + self.residual_conv(inputs)
 
 
 def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
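The resnet.py hunks above are mostly parameter renames to `inputs`, plus the conv → group-norm → Mish body of Conv1dBlock.forward. A self-contained sketch of that pattern follows; the class name is illustrative, and diffusers' rearrange_dims reshaping around the norm is skipped because nn.GroupNorm computes the same statistics on the plain 3-D tensor:

import torch
from torch import nn

class TinyConv1dBlock(nn.Module):
    # Illustrative stand-in for the Conv1dBlock forward shown in the diff.
    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        intermediate_repr = self.conv1d(inputs)
        intermediate_repr = self.group_norm(intermediate_repr)
        return self.mish(intermediate_repr)

x = torch.randn(2, 16, 32)                   # (batch, channels, length)
print(TinyConv1dBlock(16, 32, 5)(x).shape)   # torch.Size([2, 32, 32])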
diffusers/models/transformer_2d.py
CHANGED
@@ -29,10 +29,12 @@ from .modeling_utils import ModelMixin
 @dataclass
 class Transformer2DModelOutput(BaseOutput):
     """
+    The output of [`Transformer2DModel`].
+
     Args:
         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
-
-            for the unnoised latent pixels.
+            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+            distributions for the unnoised latent pixels.
     """
 
     sample: torch.FloatTensor
@@ -40,40 +42,30 @@ class Transformer2DModelOutput(BaseOutput):
 
 class Transformer2DModel(ModelMixin, ConfigMixin):
     """
-    Transformer model for image-like data.
-    embeddings) inputs.
-
-    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
-    transformer action. Finally, reshape to image.
-
-    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
-    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
-    classes of unnoised image.
-
-    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
-    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+    A 2D Transformer model for image-like data.
 
     Parameters:
         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
         in_channels (`int`, *optional*):
-
+            The number of channels in the input and output (specify if the input is **continuous**).
         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
-        sample_size (`int`, *optional*):
-
-            `ImagePositionalEmbeddings`.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+            This is fixed during training since it is used to learn a number of position embeddings.
         num_vector_embeds (`int`, *optional*):
-
+            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
             Includes the class for the masked latent pixel.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to
-        num_embeds_ada_norm ( `int`, *optional*):
-            The number of diffusion steps used during training.
-
-
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+        num_embeds_ada_norm ( `int`, *optional*):
+            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+            added to the hidden states.
+
+            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
         attention_bias (`bool`, *optional*):
-            Configure if the TransformerBlocks
+            Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """
 
     @register_to_config
@@ -223,31 +215,34 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         return_dict: bool = True,
     ):
         """
+        The [`Transformer2DModel`] forward method.
+
         Args:
-            hidden_states (
-
-                hidden_states
+            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                Input `hidden_states`.
             encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.LongTensor`, *optional*):
-                Optional timestep to be applied as an embedding in AdaLayerNorm
+                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
-                Optional class labels to be applied as an embedding in
-
-            encoder_attention_mask ( `torch.Tensor`, *optional*
-                Cross-attention mask
-
-                = keep,
-
+                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerZeroNorm`.
+            encoder_attention_mask ( `torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                    * Mask `(batch, sequence_length)` True = keep, False = discard.
+                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                 above. This bias will be added to the cross-attention scores.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
 
         Returns:
-            [`~models.transformer_2d.Transformer2DModelOutput`]
-
-                returning a tuple, the first element is the sample tensor.
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
         """
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.