diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl
- diffusers/__init__.py +16 -2
- diffusers/configuration_utils.py +1 -0
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +4 -5
- diffusers/image_processor.py +186 -14
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +157 -0
- diffusers/loaders/lora.py +1415 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +631 -0
- diffusers/loaders/textual_inversion.py +459 -0
- diffusers/loaders/unet.py +735 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +12 -1
- diffusers/models/attention.py +165 -14
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +286 -1
- diffusers/models/autoencoder_asym_kl.py +14 -9
- diffusers/models/autoencoder_kl.py +3 -18
- diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/autoencoder_tiny.py +20 -24
- diffusers/models/consistency_decoder_vae.py +37 -30
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +2 -1
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +27 -19
- diffusers/models/normalization.py +2 -2
- diffusers/models/resnet.py +390 -59
- diffusers/models/transformer_2d.py +20 -3
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +9 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandi3.py +589 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/vae.py +63 -13
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +3 -1
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +65 -12
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
- diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +6 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
- diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +4 -2
- diffusers/pipelines/pipeline_utils.py +33 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
- diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
- diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/__init__.py +64 -21
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
- diffusers/schedulers/__init__.py +2 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +1 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
- diffusers/schedulers/scheduling_deis_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
- diffusers/schedulers/scheduling_euler_discrete.py +40 -13
- diffusers/schedulers/scheduling_heun_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +1 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
- diffusers/utils/__init__.py +1 -0
- diffusers/utils/constants.py +8 -7
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
- diffusers/utils/dynamic_modules_utils.py +4 -4
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/logging.py +10 -10
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/torch_utils.py +2 -2
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
- diffusers/loaders.py +0 -3336
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformer_2d.py

@@ -20,7 +20,7 @@ from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
 from .embeddings import CaptionProjection, PatchEmbed
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
+                    create_custom_forward(block),
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    use_reentrant=False,
+                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
diffusers/models/transformer_temporal.py

@@ -19,8 +19,10 @@ from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
-from .attention import BasicTransformerBlock
+from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
+from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
+from .resnet import AlphaBlender
 
 
 @dataclass
@@ -195,3 +197,183 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
             return (output,)
 
         return TransformerTemporalModelOutput(sample=output)
+
+
+class TransformerSpatioTemporalModel(nn.Module):
+    """
+    A Transformer model for video-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        out_channels (`int`, *optional*):
+            The number of channels in the output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+    """
+
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: int = 320,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        cross_attention_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+
+        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = inner_dim
+
+        # 2. Define input layers
+        self.in_channels = in_channels
+        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        # 3. Define transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                for d in range(num_layers)
+            ]
+        )
+
+        time_mix_inner_dim = inner_dim
+        self.temporal_transformer_blocks = nn.ModuleList(
+            [
+                TemporalBasicTransformerBlock(
+                    inner_dim,
+                    time_mix_inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        time_embed_dim = in_channels * 4
+        self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
+        self.time_proj = Timesteps(in_channels, True, 0)
+        self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+
+        # 4. Define output layers
+        self.out_channels = in_channels if out_channels is None else out_channels
+        # TODO: should use out_channels for continuous projections
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input hidden_states.
+            num_frames (`int`):
+                The number of frames to be processed per batch. This is used to reshape the hidden states.
+            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
+                A tensor indicating whether the input contains only images. 1 indicates that the input contains only
+                images, 0 indicates that the input contains video frames.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                returned, otherwise a `tuple` where the first element is the sample tensor.
+        """
+        # 1. Input
+        batch_frames, _, height, width = hidden_states.shape
+        num_frames = image_only_indicator.shape[-1]
+        batch_size = batch_frames // num_frames
+
+        time_context = encoder_hidden_states
+        time_context_first_timestep = time_context[None, :].reshape(
+            batch_size, num_frames, -1, time_context.shape[-1]
+        )[:, 0]
+        time_context = time_context_first_timestep[None, :].broadcast_to(
+            height * width, batch_size, 1, time_context.shape[-1]
+        )
+        time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
+
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
+        hidden_states = self.proj_in(hidden_states)
+
+        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
+        num_frames_emb = num_frames_emb.repeat(batch_size, 1)
+        num_frames_emb = num_frames_emb.reshape(-1)
+        t_emb = self.time_proj(num_frames_emb)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=hidden_states.dtype)
+
+        emb = self.time_pos_embed(t_emb)
+        emb = emb[:, None, :]
+
+        # 2. Blocks
+        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    block,
+                    hidden_states,
+                    None,
+                    encoder_hidden_states,
+                    None,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+
+            hidden_states_mix = hidden_states
+            hidden_states_mix = hidden_states_mix + emb
+
+            hidden_states_mix = temporal_block(
+                hidden_states_mix,
+                num_frames=num_frames,
+                encoder_hidden_states=time_context,
+            )
+            hidden_states = self.time_mixer(
+                x_spatial=hidden_states,
+                x_temporal=hidden_states_mix,
+                image_only_indicator=image_only_indicator,
+            )
+
+        # 3. Output
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+        output = hidden_states + residual
+
+        if not return_dict:
+            return (output,)
+
+        return TransformerTemporalModelOutput(sample=output)
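Not from the wheel itself; a small sketch of how the new `TransformerSpatioTemporalModel` could be driven stand-alone, with hypothetical toy sizes (the Stable Video Diffusion UNet in `unet_spatio_temporal_condition.py` constructs and wires it internally):

```python
import torch
from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel

# Hypothetical toy configuration; the real SVD UNet uses much larger settings.
model = TransformerSpatioTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,   # inner_dim = 8 * 40 = 320 = in_channels
    in_channels=320,
    num_layers=1,
    cross_attention_dim=1024,
)

batch_size, num_frames = 1, 2
# Frames are flattened into the batch dimension: (batch_size * num_frames, C, H, W)
hidden_states = torch.randn(batch_size * num_frames, 320, 8, 8)
# One conditioning token per frame, e.g. CLIP image embeddings in SVD
encoder_hidden_states = torch.randn(batch_size * num_frames, 1, 1024)
# 0 = video frame, 1 = still image; its last dimension also supplies num_frames
image_only_indicator = torch.zeros(batch_size, num_frames)

out = model(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    image_only_indicator=image_only_indicator,
).sample
print(out.shape)  # torch.Size([2, 320, 8, 8])
```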
diffusers/models/unet_2d_blocks_flax.py

@@ -45,6 +45,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -125,6 +126,7 @@ class FlaxDownBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -190,6 +192,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -275,6 +278,7 @@ class FlaxUpBlock2D(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -339,6 +343,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
+
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
diffusers/models/unet_2d_condition.py

@@ -1022,6 +1022,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
         # 2. pre-process
         sample = self.conv_in(sample)
 
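The `ip_image_proj` branch above is the UNet-side hook for the new IP-Adapter support (see `diffusers/loaders/ip_adapter.py` in the file list). A hedged, pipeline-level sketch of how `image_embeds` would typically reach it; the model id, weight file name, and image URL are illustrative assumptions:

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Loads the image-projection layers so reference images are routed through the "ip_image_proj" branch
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_image = load_image("https://example.com/reference.png")  # placeholder URL, supply your own image
image = pipe(
    prompt="a photo in the style of the reference image",
    ip_adapter_image=ip_image,  # encoded to image_embeds and passed via added_cond_kwargs
    num_inference_steps=25,
).images[0]
```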
diffusers/models/unet_2d_condition_flax.py

@@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     sample_size: int = 32
     in_channels: int = 4
     out_channels: int = 4
-    down_block_types: Tuple[str] = (
+    down_block_types: Tuple[str, ...] = (
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "CrossAttnDownBlock2D",
         "DownBlock2D",
     )
-    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
     only_cross_attention: Union[bool, Tuple[bool]] = False
-    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
+    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: Union[int, Tuple[int]] = 8
-    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
+    attention_head_dim: Union[int, Tuple[int, ...]] = 8
+    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
     cross_attention_dim: int = 1280
     dropout: float = 0.0
     use_linear_projection: bool = False
@@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     freq_shift: int = 0
     use_memory_efficient_attention: bool = False
     split_head_dim: bool = False
-    transformer_layers_per_block: Union[int, Tuple[int]] = 1
+    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
     addition_embed_type: Optional[str] = None
     addition_time_embed_dim: Optional[int] = None
     addition_embed_type_num_heads: int = 64
@@ -158,7 +158,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         }
         return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
 
-    def setup(self):
+    def setup(self) -> None:
         block_out_channels = self.block_out_channels
         time_embed_dim = block_out_channels[0] * 4
 
@@ -320,15 +320,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
     def __call__(
         self,
-        sample,
-        timesteps,
-        encoder_hidden_states,
+        sample: jnp.ndarray,
+        timesteps: Union[jnp.ndarray, float, int],
+        encoder_hidden_states: jnp.ndarray,
         added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
-        down_block_additional_residuals=None,
-        mid_block_additional_residual=None,
+        down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
+        mid_block_additional_residual: Optional[jnp.ndarray] = None,
         return_dict: bool = True,
         train: bool = False,
-    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
+    ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]:
         r"""
         Args:
             sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor