diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +16 -2
- diffusers/configuration_utils.py +1 -0
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +4 -5
- diffusers/image_processor.py +186 -14
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +157 -0
- diffusers/loaders/lora.py +1415 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +631 -0
- diffusers/loaders/textual_inversion.py +459 -0
- diffusers/loaders/unet.py +735 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +12 -1
- diffusers/models/attention.py +165 -14
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +286 -1
- diffusers/models/autoencoder_asym_kl.py +14 -9
- diffusers/models/autoencoder_kl.py +3 -18
- diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/autoencoder_tiny.py +20 -24
- diffusers/models/consistency_decoder_vae.py +37 -30
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +2 -1
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +27 -19
- diffusers/models/normalization.py +2 -2
- diffusers/models/resnet.py +390 -59
- diffusers/models/transformer_2d.py +20 -3
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +9 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandi3.py +589 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/vae.py +63 -13
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +3 -1
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +65 -12
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
- diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +6 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
- diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +4 -2
- diffusers/pipelines/pipeline_utils.py +33 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
- diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
- diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/__init__.py +64 -21
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
- diffusers/schedulers/__init__.py +2 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +1 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
- diffusers/schedulers/scheduling_deis_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
- diffusers/schedulers/scheduling_euler_discrete.py +40 -13
- diffusers/schedulers/scheduling_heun_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +1 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
- diffusers/utils/__init__.py +1 -0
- diffusers/utils/constants.py +8 -7
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
- diffusers/utils/dynamic_modules_utils.py +4 -4
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/logging.py +10 -10
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/torch_utils.py +2 -2
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
- diffusers/loaders.py +0 -3336
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
diffusers/models/resnet.py
CHANGED
@@ -164,7 +164,12 @@ class Upsample2D(nn.Module):
|
|
164
164
|
else:
|
165
165
|
self.Conv2d_0 = conv
|
166
166
|
|
167
|
-
def forward(
|
167
|
+
def forward(
|
168
|
+
self,
|
169
|
+
hidden_states: torch.FloatTensor,
|
170
|
+
output_size: Optional[int] = None,
|
171
|
+
scale: float = 1.0,
|
172
|
+
) -> torch.FloatTensor:
|
168
173
|
assert hidden_states.shape[1] == self.channels
|
169
174
|
|
170
175
|
if self.use_conv_transpose:
|
@@ -256,7 +261,7 @@ class Downsample2D(nn.Module):
|
|
256
261
|
else:
|
257
262
|
self.conv = conv
|
258
263
|
|
259
|
-
def forward(self, hidden_states, scale: float = 1.0):
|
264
|
+
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
260
265
|
assert hidden_states.shape[1] == self.channels
|
261
266
|
|
262
267
|
if self.use_conv and self.padding == 0:
|
@@ -280,7 +285,7 @@ class FirUpsample2D(nn.Module):
|
|
280
285
|
"""A 2D FIR upsampling layer with an optional convolution.
|
281
286
|
|
282
287
|
Parameters:
|
283
|
-
channels (`int
|
288
|
+
channels (`int`, optional):
|
284
289
|
number of channels in the inputs and outputs.
|
285
290
|
use_conv (`bool`, default `False`):
|
286
291
|
option to use a convolution.
|
@@ -292,7 +297,7 @@ class FirUpsample2D(nn.Module):
|
|
292
297
|
|
293
298
|
def __init__(
|
294
299
|
self,
|
295
|
-
channels: int = None,
|
300
|
+
channels: Optional[int] = None,
|
296
301
|
out_channels: Optional[int] = None,
|
297
302
|
use_conv: bool = False,
|
298
303
|
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
@@ -307,12 +312,12 @@ class FirUpsample2D(nn.Module):
|
|
307
312
|
|
308
313
|
def _upsample_2d(
|
309
314
|
self,
|
310
|
-
hidden_states: torch.
|
311
|
-
weight: Optional[torch.
|
315
|
+
hidden_states: torch.FloatTensor,
|
316
|
+
weight: Optional[torch.FloatTensor] = None,
|
312
317
|
kernel: Optional[torch.FloatTensor] = None,
|
313
318
|
factor: int = 2,
|
314
319
|
gain: float = 1,
|
315
|
-
) -> torch.
|
320
|
+
) -> torch.FloatTensor:
|
316
321
|
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
317
322
|
|
318
323
|
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
@@ -320,17 +325,21 @@ class FirUpsample2D(nn.Module):
|
|
320
325
|
arbitrary order.
|
321
326
|
|
322
327
|
Args:
|
323
|
-
hidden_states
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
328
|
+
hidden_states (`torch.FloatTensor`):
|
329
|
+
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
330
|
+
weight (`torch.FloatTensor`, *optional*):
|
331
|
+
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
332
|
+
performed by `inChannels = x.shape[0] // numGroups`.
|
333
|
+
kernel (`torch.FloatTensor`, *optional*):
|
334
|
+
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
335
|
+
corresponds to nearest-neighbor upsampling.
|
336
|
+
factor (`int`, *optional*): Integer upsampling factor (default: 2).
|
337
|
+
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
|
330
338
|
|
331
339
|
Returns:
|
332
|
-
output
|
333
|
-
|
340
|
+
output (`torch.FloatTensor`):
|
341
|
+
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
342
|
+
datatype as `hidden_states`.
|
334
343
|
"""
|
335
344
|
|
336
345
|
assert isinstance(factor, int) and factor >= 1
|
@@ -373,7 +382,11 @@ class FirUpsample2D(nn.Module):
|
|
373
382
|
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
374
383
|
|
375
384
|
inverse_conv = F.conv_transpose2d(
|
376
|
-
hidden_states,
|
385
|
+
hidden_states,
|
386
|
+
weight,
|
387
|
+
stride=stride,
|
388
|
+
output_padding=output_padding,
|
389
|
+
padding=0,
|
377
390
|
)
|
378
391
|
|
379
392
|
output = upfirdn2d_native(
|
@@ -392,7 +405,7 @@ class FirUpsample2D(nn.Module):
|
|
392
405
|
|
393
406
|
return output
|
394
407
|
|
395
|
-
def forward(self, hidden_states: torch.
|
408
|
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
396
409
|
if self.use_conv:
|
397
410
|
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
398
411
|
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
@@ -418,7 +431,7 @@ class FirDownsample2D(nn.Module):
|
|
418
431
|
|
419
432
|
def __init__(
|
420
433
|
self,
|
421
|
-
channels: int = None,
|
434
|
+
channels: Optional[int] = None,
|
422
435
|
out_channels: Optional[int] = None,
|
423
436
|
use_conv: bool = False,
|
424
437
|
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
@@ -433,30 +446,35 @@ class FirDownsample2D(nn.Module):
|
|
433
446
|
|
434
447
|
def _downsample_2d(
|
435
448
|
self,
|
436
|
-
hidden_states: torch.
|
437
|
-
weight: Optional[torch.
|
449
|
+
hidden_states: torch.FloatTensor,
|
450
|
+
weight: Optional[torch.FloatTensor] = None,
|
438
451
|
kernel: Optional[torch.FloatTensor] = None,
|
439
452
|
factor: int = 2,
|
440
453
|
gain: float = 1,
|
441
|
-
) -> torch.
|
454
|
+
) -> torch.FloatTensor:
|
442
455
|
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
443
456
|
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
444
457
|
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
445
458
|
arbitrary order.
|
446
459
|
|
447
460
|
Args:
|
448
|
-
hidden_states
|
449
|
-
|
461
|
+
hidden_states (`torch.FloatTensor`):
|
462
|
+
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
463
|
+
weight (`torch.FloatTensor`, *optional*):
|
450
464
|
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
451
465
|
performed by `inChannels = x.shape[0] // numGroups`.
|
452
|
-
kernel
|
453
|
-
|
454
|
-
|
455
|
-
|
466
|
+
kernel (`torch.FloatTensor`, *optional*):
|
467
|
+
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
468
|
+
corresponds to average pooling.
|
469
|
+
factor (`int`, *optional*, default to `2`):
|
470
|
+
Integer downsampling factor.
|
471
|
+
gain (`float`, *optional*, default to `1.0`):
|
472
|
+
Scaling factor for signal magnitude.
|
456
473
|
|
457
474
|
Returns:
|
458
|
-
output
|
459
|
-
|
475
|
+
output (`torch.FloatTensor`):
|
476
|
+
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
477
|
+
datatype as `x`.
|
460
478
|
"""
|
461
479
|
|
462
480
|
assert isinstance(factor, int) and factor >= 1
|
@@ -492,7 +510,7 @@ class FirDownsample2D(nn.Module):
|
|
492
510
|
|
493
511
|
return output
|
494
512
|
|
495
|
-
def forward(self, hidden_states: torch.
|
513
|
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
496
514
|
if self.use_conv:
|
497
515
|
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
498
516
|
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
@@ -519,7 +537,14 @@ class KDownsample2D(nn.Module):
|
|
519
537
|
|
520
538
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
521
539
|
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
522
|
-
weight = inputs.new_zeros(
|
540
|
+
weight = inputs.new_zeros(
|
541
|
+
[
|
542
|
+
inputs.shape[1],
|
543
|
+
inputs.shape[1],
|
544
|
+
self.kernel.shape[0],
|
545
|
+
self.kernel.shape[1],
|
546
|
+
]
|
547
|
+
)
|
523
548
|
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
524
549
|
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
525
550
|
weight[indices, indices] = kernel
|
@@ -542,7 +567,14 @@ class KUpsample2D(nn.Module):
|
|
542
567
|
|
543
568
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
544
569
|
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
545
|
-
weight = inputs.new_zeros(
|
570
|
+
weight = inputs.new_zeros(
|
571
|
+
[
|
572
|
+
inputs.shape[1],
|
573
|
+
inputs.shape[1],
|
574
|
+
self.kernel.shape[0],
|
575
|
+
self.kernel.shape[1],
|
576
|
+
]
|
577
|
+
)
|
546
578
|
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
547
579
|
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
548
580
|
weight[indices, indices] = kernel
|
@@ -679,10 +711,20 @@ class ResnetBlock2D(nn.Module):
|
|
679
711
|
self.conv_shortcut = None
|
680
712
|
if self.use_in_shortcut:
|
681
713
|
self.conv_shortcut = conv_cls(
|
682
|
-
in_channels,
|
714
|
+
in_channels,
|
715
|
+
conv_2d_out_channels,
|
716
|
+
kernel_size=1,
|
717
|
+
stride=1,
|
718
|
+
padding=0,
|
719
|
+
bias=conv_shortcut_bias,
|
683
720
|
)
|
684
721
|
|
685
|
-
def forward(
|
722
|
+
def forward(
|
723
|
+
self,
|
724
|
+
input_tensor: torch.FloatTensor,
|
725
|
+
temb: torch.FloatTensor,
|
726
|
+
scale: float = 1.0,
|
727
|
+
) -> torch.FloatTensor:
|
686
728
|
hidden_states = input_tensor
|
687
729
|
|
688
730
|
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
@@ -778,7 +820,7 @@ class Conv1dBlock(nn.Module):
|
|
778
820
|
out_channels (`int`): Number of output channels.
|
779
821
|
kernel_size (`int` or `tuple`): Size of the convolving kernel.
|
780
822
|
n_groups (`int`, default `8`): Number of groups to separate the channels into.
|
781
|
-
activation (`str`, defaults `mish`): Name of the activation function.
|
823
|
+
activation (`str`, defaults to `mish`): Name of the activation function.
|
782
824
|
"""
|
783
825
|
|
784
826
|
def __init__(
|
@@ -853,8 +895,11 @@ class ResidualTemporalBlock1D(nn.Module):
|
|
853
895
|
|
854
896
|
|
855
897
|
def upsample_2d(
|
856
|
-
hidden_states: torch.
|
857
|
-
|
898
|
+
hidden_states: torch.FloatTensor,
|
899
|
+
kernel: Optional[torch.FloatTensor] = None,
|
900
|
+
factor: int = 2,
|
901
|
+
gain: float = 1,
|
902
|
+
) -> torch.FloatTensor:
|
858
903
|
r"""Upsample2D a batch of 2D images with the given filter.
|
859
904
|
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
860
905
|
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
@@ -862,14 +907,19 @@ def upsample_2d(
|
|
862
907
|
a: multiple of the upsampling factor.
|
863
908
|
|
864
909
|
Args:
|
865
|
-
hidden_states
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
910
|
+
hidden_states (`torch.FloatTensor`):
|
911
|
+
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
912
|
+
kernel (`torch.FloatTensor`, *optional*):
|
913
|
+
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
914
|
+
corresponds to nearest-neighbor upsampling.
|
915
|
+
factor (`int`, *optional*, default to `2`):
|
916
|
+
Integer upsampling factor.
|
917
|
+
gain (`float`, *optional*, default to `1.0`):
|
918
|
+
Scaling factor for signal magnitude (default: 1.0).
|
870
919
|
|
871
920
|
Returns:
|
872
|
-
output
|
921
|
+
output (`torch.FloatTensor`):
|
922
|
+
Tensor of the shape `[N, C, H * factor, W * factor]`
|
873
923
|
"""
|
874
924
|
assert isinstance(factor, int) and factor >= 1
|
875
925
|
if kernel is None:
|
@@ -892,8 +942,11 @@ def upsample_2d(
|
|
892
942
|
|
893
943
|
|
894
944
|
def downsample_2d(
|
895
|
-
hidden_states: torch.
|
896
|
-
|
945
|
+
hidden_states: torch.FloatTensor,
|
946
|
+
kernel: Optional[torch.FloatTensor] = None,
|
947
|
+
factor: int = 2,
|
948
|
+
gain: float = 1,
|
949
|
+
) -> torch.FloatTensor:
|
897
950
|
r"""Downsample2D a batch of 2D images with the given filter.
|
898
951
|
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
899
952
|
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
@@ -901,14 +954,19 @@ def downsample_2d(
|
|
901
954
|
shape is a multiple of the downsampling factor.
|
902
955
|
|
903
956
|
Args:
|
904
|
-
hidden_states
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
957
|
+
hidden_states (`torch.FloatTensor`)
|
958
|
+
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
959
|
+
kernel (`torch.FloatTensor`, *optional*):
|
960
|
+
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
961
|
+
corresponds to average pooling.
|
962
|
+
factor (`int`, *optional*, default to `2`):
|
963
|
+
Integer downsampling factor.
|
964
|
+
gain (`float`, *optional*, default to `1.0`):
|
965
|
+
Scaling factor for signal magnitude.
|
909
966
|
|
910
967
|
Returns:
|
911
|
-
output
|
968
|
+
output (`torch.FloatTensor`):
|
969
|
+
Tensor of the shape `[N, C, H // factor, W // factor]`
|
912
970
|
"""
|
913
971
|
|
914
972
|
assert isinstance(factor, int) and factor >= 1
|
@@ -923,13 +981,20 @@ def downsample_2d(
|
|
923
981
|
kernel = kernel * gain
|
924
982
|
pad_value = kernel.shape[0] - factor
|
925
983
|
output = upfirdn2d_native(
|
926
|
-
hidden_states,
|
984
|
+
hidden_states,
|
985
|
+
kernel.to(device=hidden_states.device),
|
986
|
+
down=factor,
|
987
|
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
927
988
|
)
|
928
989
|
return output
|
929
990
|
|
930
991
|
|
931
992
|
def upfirdn2d_native(
|
932
|
-
tensor: torch.Tensor,
|
993
|
+
tensor: torch.Tensor,
|
994
|
+
kernel: torch.Tensor,
|
995
|
+
up: int = 1,
|
996
|
+
down: int = 1,
|
997
|
+
pad: Tuple[int, int] = (0, 0),
|
933
998
|
) -> torch.Tensor:
|
934
999
|
up_x = up_y = up
|
935
1000
|
down_x = down_y = down
|
@@ -985,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
|
|
985
1050
|
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
986
1051
|
"""
|
987
1052
|
|
988
|
-
def __init__(
|
1053
|
+
def __init__(
|
1054
|
+
self,
|
1055
|
+
in_dim: int,
|
1056
|
+
out_dim: Optional[int] = None,
|
1057
|
+
dropout: float = 0.0,
|
1058
|
+
norm_num_groups: int = 32,
|
1059
|
+
):
|
989
1060
|
super().__init__()
|
990
1061
|
out_dim = out_dim or in_dim
|
991
1062
|
self.in_dim = in_dim
|
@@ -993,22 +1064,24 @@ class TemporalConvLayer(nn.Module):
|
|
993
1064
|
|
994
1065
|
# conv layers
|
995
1066
|
self.conv1 = nn.Sequential(
|
996
|
-
nn.GroupNorm(
|
1067
|
+
nn.GroupNorm(norm_num_groups, in_dim),
|
1068
|
+
nn.SiLU(),
|
1069
|
+
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
|
997
1070
|
)
|
998
1071
|
self.conv2 = nn.Sequential(
|
999
|
-
nn.GroupNorm(
|
1072
|
+
nn.GroupNorm(norm_num_groups, out_dim),
|
1000
1073
|
nn.SiLU(),
|
1001
1074
|
nn.Dropout(dropout),
|
1002
1075
|
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
1003
1076
|
)
|
1004
1077
|
self.conv3 = nn.Sequential(
|
1005
|
-
nn.GroupNorm(
|
1078
|
+
nn.GroupNorm(norm_num_groups, out_dim),
|
1006
1079
|
nn.SiLU(),
|
1007
1080
|
nn.Dropout(dropout),
|
1008
1081
|
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
1009
1082
|
)
|
1010
1083
|
self.conv4 = nn.Sequential(
|
1011
|
-
nn.GroupNorm(
|
1084
|
+
nn.GroupNorm(norm_num_groups, out_dim),
|
1012
1085
|
nn.SiLU(),
|
1013
1086
|
nn.Dropout(dropout),
|
1014
1087
|
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
@@ -1035,3 +1108,261 @@ class TemporalConvLayer(nn.Module):
|
|
1035
1108
|
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
|
1036
1109
|
)
|
1037
1110
|
return hidden_states
|
1111
|
+
|
1112
|
+
|
1113
|
+
class TemporalResnetBlock(nn.Module):
|
1114
|
+
r"""
|
1115
|
+
A Resnet block.
|
1116
|
+
|
1117
|
+
Parameters:
|
1118
|
+
in_channels (`int`): The number of channels in the input.
|
1119
|
+
out_channels (`int`, *optional*, default to be `None`):
|
1120
|
+
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
1121
|
+
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
1122
|
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
1123
|
+
"""
|
1124
|
+
|
1125
|
+
def __init__(
|
1126
|
+
self,
|
1127
|
+
in_channels: int,
|
1128
|
+
out_channels: Optional[int] = None,
|
1129
|
+
temb_channels: int = 512,
|
1130
|
+
eps: float = 1e-6,
|
1131
|
+
):
|
1132
|
+
super().__init__()
|
1133
|
+
self.in_channels = in_channels
|
1134
|
+
out_channels = in_channels if out_channels is None else out_channels
|
1135
|
+
self.out_channels = out_channels
|
1136
|
+
|
1137
|
+
kernel_size = (3, 1, 1)
|
1138
|
+
padding = [k // 2 for k in kernel_size]
|
1139
|
+
|
1140
|
+
self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
|
1141
|
+
self.conv1 = nn.Conv3d(
|
1142
|
+
in_channels,
|
1143
|
+
out_channels,
|
1144
|
+
kernel_size=kernel_size,
|
1145
|
+
stride=1,
|
1146
|
+
padding=padding,
|
1147
|
+
)
|
1148
|
+
|
1149
|
+
if temb_channels is not None:
|
1150
|
+
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
1151
|
+
else:
|
1152
|
+
self.time_emb_proj = None
|
1153
|
+
|
1154
|
+
self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
|
1155
|
+
|
1156
|
+
self.dropout = torch.nn.Dropout(0.0)
|
1157
|
+
self.conv2 = nn.Conv3d(
|
1158
|
+
out_channels,
|
1159
|
+
out_channels,
|
1160
|
+
kernel_size=kernel_size,
|
1161
|
+
stride=1,
|
1162
|
+
padding=padding,
|
1163
|
+
)
|
1164
|
+
|
1165
|
+
self.nonlinearity = get_activation("silu")
|
1166
|
+
|
1167
|
+
self.use_in_shortcut = self.in_channels != out_channels
|
1168
|
+
|
1169
|
+
self.conv_shortcut = None
|
1170
|
+
if self.use_in_shortcut:
|
1171
|
+
self.conv_shortcut = nn.Conv3d(
|
1172
|
+
in_channels,
|
1173
|
+
out_channels,
|
1174
|
+
kernel_size=1,
|
1175
|
+
stride=1,
|
1176
|
+
padding=0,
|
1177
|
+
)
|
1178
|
+
|
1179
|
+
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
1180
|
+
hidden_states = input_tensor
|
1181
|
+
|
1182
|
+
hidden_states = self.norm1(hidden_states)
|
1183
|
+
hidden_states = self.nonlinearity(hidden_states)
|
1184
|
+
hidden_states = self.conv1(hidden_states)
|
1185
|
+
|
1186
|
+
if self.time_emb_proj is not None:
|
1187
|
+
temb = self.nonlinearity(temb)
|
1188
|
+
temb = self.time_emb_proj(temb)[:, :, :, None, None]
|
1189
|
+
temb = temb.permute(0, 2, 1, 3, 4)
|
1190
|
+
hidden_states = hidden_states + temb
|
1191
|
+
|
1192
|
+
hidden_states = self.norm2(hidden_states)
|
1193
|
+
hidden_states = self.nonlinearity(hidden_states)
|
1194
|
+
hidden_states = self.dropout(hidden_states)
|
1195
|
+
hidden_states = self.conv2(hidden_states)
|
1196
|
+
|
1197
|
+
if self.conv_shortcut is not None:
|
1198
|
+
input_tensor = self.conv_shortcut(input_tensor)
|
1199
|
+
|
1200
|
+
output_tensor = input_tensor + hidden_states
|
1201
|
+
|
1202
|
+
return output_tensor
|
1203
|
+
|
1204
|
+
|
1205
|
+
# VideoResBlock
|
1206
|
+
class SpatioTemporalResBlock(nn.Module):
|
1207
|
+
r"""
|
1208
|
+
A SpatioTemporal Resnet block.
|
1209
|
+
|
1210
|
+
Parameters:
|
1211
|
+
in_channels (`int`): The number of channels in the input.
|
1212
|
+
out_channels (`int`, *optional*, default to be `None`):
|
1213
|
+
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
1214
|
+
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
1215
|
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet.
|
1216
|
+
temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
|
1217
|
+
merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
|
1218
|
+
merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
|
1219
|
+
The merge strategy to use for the temporal mixing.
|
1220
|
+
switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
|
1221
|
+
If `True`, switch the spatial and temporal mixing.
|
1222
|
+
"""
|
1223
|
+
|
1224
|
+
def __init__(
|
1225
|
+
self,
|
1226
|
+
in_channels: int,
|
1227
|
+
out_channels: Optional[int] = None,
|
1228
|
+
temb_channels: int = 512,
|
1229
|
+
eps: float = 1e-6,
|
1230
|
+
temporal_eps: Optional[float] = None,
|
1231
|
+
merge_factor: float = 0.5,
|
1232
|
+
merge_strategy="learned_with_images",
|
1233
|
+
switch_spatial_to_temporal_mix: bool = False,
|
1234
|
+
):
|
1235
|
+
super().__init__()
|
1236
|
+
|
1237
|
+
self.spatial_res_block = ResnetBlock2D(
|
1238
|
+
in_channels=in_channels,
|
1239
|
+
out_channels=out_channels,
|
1240
|
+
temb_channels=temb_channels,
|
1241
|
+
eps=eps,
|
1242
|
+
)
|
1243
|
+
|
1244
|
+
self.temporal_res_block = TemporalResnetBlock(
|
1245
|
+
in_channels=out_channels if out_channels is not None else in_channels,
|
1246
|
+
out_channels=out_channels if out_channels is not None else in_channels,
|
1247
|
+
temb_channels=temb_channels,
|
1248
|
+
eps=temporal_eps if temporal_eps is not None else eps,
|
1249
|
+
)
|
1250
|
+
|
1251
|
+
self.time_mixer = AlphaBlender(
|
1252
|
+
alpha=merge_factor,
|
1253
|
+
merge_strategy=merge_strategy,
|
1254
|
+
switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
|
1255
|
+
)
|
1256
|
+
|
1257
|
+
def forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: Optional[torch.FloatTensor] = None,
    image_only_indicator: Optional[torch.Tensor] = None,
):
    """Run the spatial resnet per frame, then the temporal resnet across frames, and blend the two.

    Args:
        hidden_states: Frames flattened into the batch dimension, i.e. ``batch * num_frames``
            leading rows, in whatever layout ``spatial_res_block`` accepts.
        temb: Optional timestep embedding. It is passed as-is to the spatial block and regrouped
            to ``(batch, num_frames, -1)`` for the temporal block.
        image_only_indicator: Required despite the ``None`` default. Its last dimension gives
            ``num_frames``, and it is forwarded unchanged to ``time_mixer``.

    Returns:
        Tensor of shape ``(batch * num_frames, channels, height, width)``.

    Raises:
        ValueError: If ``image_only_indicator`` is not provided.
    """
    if image_only_indicator is None:
        # Without it the number of frames cannot be inferred from the flattened batch.
        raise ValueError("`image_only_indicator` must be provided to infer the number of frames.")
    num_frames = image_only_indicator.shape[-1]

    hidden_states = self.spatial_res_block(hidden_states, temb)

    batch_frames, channels, height, width = hidden_states.shape
    batch_size = batch_frames // num_frames

    # (batch * frames, C, H, W) -> (batch, C, frames, H, W) so the temporal block mixes along frames.
    hidden_states = hidden_states.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
    # Keep the spatial-only result for blending. This assumes temporal_res_block does not
    # modify its input in place. NOTE(review): confirm if TemporalResnetBlock changes.
    hidden_states_mix = hidden_states

    if temb is not None:
        temb = temb.reshape(batch_size, num_frames, -1)

    hidden_states = self.temporal_res_block(hidden_states, temb)
    hidden_states = self.time_mixer(
        x_spatial=hidden_states_mix,
        x_temporal=hidden_states,
        image_only_indicator=image_only_indicator,
    )

    # Back to (batch * frames, C, H, W).
    hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
    return hidden_states
|
1288
|
+
|
1289
|
+
|
1290
|
+
class AlphaBlender(nn.Module):
    r"""
    A module to blend spatial and temporal features.

    The output is ``alpha * x_spatial + (1 - alpha) * x_temporal``. When
    ``switch_spatial_to_temporal_mix`` is set, ``alpha`` is replaced by ``1 - alpha``.
    How ``alpha`` is obtained depends on ``merge_strategy``:

    - ``"fixed"``: the constructor value, stored as a non-trainable buffer ``mix_factor``.
    - ``"learned"``: ``sigmoid(mix_factor)``, where ``mix_factor`` is a trainable parameter
      initialised to the constructor value.
    - ``"learned_with_images"``: like ``"learned"``, except that positions flagged by
      ``image_only_indicator`` use ``alpha = 1`` (spatial features only).

    Parameters:
        alpha (`float`): The initial value of the blending factor. For the learned strategies this
            is the pre-sigmoid value (a logit), not the blending weight itself.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing. One of `"learned"`, `"fixed"` or
            `"learned_with_images"`.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE

        if merge_strategy not in self.strategies:
            raise ValueError(f"merge_strategy needs to be in {self.strategies}")

        # The strategy is validated above, so only two storage modes remain.
        if merge_strategy == "fixed":
            # Constant weight: registered as a buffer so it is not trained.
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        else:
            # "learned" / "learned_with_images": trainable logit, squashed by a sigmoid in get_alpha.
            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))

    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
        """Return the weight applied to the spatial features.

        Args:
            image_only_indicator: Nonzero entries mark positions that use spatial features only.
                Required for ``"learned_with_images"`` and ignored otherwise. Expected to be
                ``(batch, frames)``, as implied by the 5-D broadcasting below.
            ndims: Rank of the tensors being blended (3 or 5). Only consulted for
                ``"learned_with_images"``.

        Returns:
            For ``"fixed"`` / ``"learned"``, a 1-element tensor. For ``"learned_with_images"``,
            a tensor reshaped to broadcast against the inputs.

        Raises:
            ValueError: If ``image_only_indicator`` is missing, or ``ndims`` is not 3 or 5, under
                ``"learned_with_images"``.
            NotImplementedError: If ``merge_strategy`` was changed to an unknown value after
                construction.
        """
        if self.merge_strategy == "fixed":
            return self.mix_factor

        if self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)

        if self.merge_strategy != "learned_with_images":
            raise NotImplementedError(f"Unknown merge strategy {self.merge_strategy}")

        if image_only_indicator is None:
            raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")

        learned_alpha = torch.sigmoid(self.mix_factor)[..., None]
        # Image-only positions take the spatial branch exclusively (alpha = 1). The constant is
        # built in the sigmoid's dtype so torch.where does not mix dtypes (e.g. fp16 weights).
        alpha = torch.where(
            image_only_indicator.bool(),
            torch.ones(1, 1, device=image_only_indicator.device, dtype=learned_alpha.dtype),
            learned_alpha,
        )

        if ndims == 5:
            # (batch, channel, frames, height, width)
            return alpha[:, None, :, None, None]
        if ndims == 3:
            # (batch*frames, height*width, channels)
            return alpha.reshape(-1)[:, None, None]
        raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Blend ``x_spatial`` and ``x_temporal`` (same shape) using the strategy's ``alpha``."""
        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        alpha = alpha.to(x_spatial.dtype)

        if self.switch_spatial_to_temporal_mix:
            alpha = 1.0 - alpha

        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
        return x
|