diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -94,11 +94,13 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
94
94
|
|
95
95
|
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
96
96
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
97
|
+
# TODO(aryan): configure calculation based on stride and dilation in the future.
|
98
|
+
# Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
|
99
|
+
time_pad = time_kernel_size - 1
|
100
|
+
height_pad = (height_kernel_size - 1) // 2
|
101
|
+
width_pad = (width_kernel_size - 1) // 2
|
101
102
|
|
103
|
+
self.pad_mode = pad_mode
|
102
104
|
self.height_pad = height_pad
|
103
105
|
self.width_pad = width_pad
|
104
106
|
self.time_pad = time_pad
|
@@ -107,7 +109,7 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
107
109
|
self.temporal_dim = 2
|
108
110
|
self.time_kernel_size = time_kernel_size
|
109
111
|
|
110
|
-
stride = (stride, 1, 1)
|
112
|
+
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
|
111
113
|
dilation = (dilation, 1, 1)
|
112
114
|
self.conv = CogVideoXSafeConv3d(
|
113
115
|
in_channels=in_channels,
|
@@ -120,18 +122,24 @@ class CogVideoXCausalConv3d(nn.Module):
|
|
120
122
|
def fake_context_parallel_forward(
|
121
123
|
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
|
122
124
|
) -> torch.Tensor:
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
125
|
+
if self.pad_mode == "replicate":
|
126
|
+
inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
|
127
|
+
else:
|
128
|
+
kernel_size = self.time_kernel_size
|
129
|
+
if kernel_size > 1:
|
130
|
+
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
131
|
+
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
127
132
|
return inputs
|
128
133
|
|
129
134
|
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
|
130
135
|
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
|
131
|
-
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
132
136
|
|
133
|
-
|
134
|
-
|
137
|
+
if self.pad_mode == "replicate":
|
138
|
+
conv_cache = None
|
139
|
+
else:
|
140
|
+
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
141
|
+
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
142
|
+
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
135
143
|
|
136
144
|
output = self.conv(inputs)
|
137
145
|
return output, conv_cache
|
@@ -412,7 +420,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
|
412
420
|
for i, resnet in enumerate(self.resnets):
|
413
421
|
conv_cache_key = f"resnet_{i}"
|
414
422
|
|
415
|
-
if
|
423
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
416
424
|
|
417
425
|
def create_custom_forward(module):
|
418
426
|
def create_forward(*inputs):
|
@@ -425,7 +433,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
|
425
433
|
hidden_states,
|
426
434
|
temb,
|
427
435
|
zq,
|
428
|
-
conv_cache
|
436
|
+
conv_cache.get(conv_cache_key),
|
429
437
|
)
|
430
438
|
else:
|
431
439
|
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
@@ -514,7 +522,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
|
514
522
|
for i, resnet in enumerate(self.resnets):
|
515
523
|
conv_cache_key = f"resnet_{i}"
|
516
524
|
|
517
|
-
if
|
525
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
518
526
|
|
519
527
|
def create_custom_forward(module):
|
520
528
|
def create_forward(*inputs):
|
@@ -523,7 +531,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
|
523
531
|
return create_forward
|
524
532
|
|
525
533
|
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
526
|
-
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache
|
534
|
+
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
527
535
|
)
|
528
536
|
else:
|
529
537
|
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
@@ -628,7 +636,7 @@ class CogVideoXUpBlock3D(nn.Module):
|
|
628
636
|
for i, resnet in enumerate(self.resnets):
|
629
637
|
conv_cache_key = f"resnet_{i}"
|
630
638
|
|
631
|
-
if
|
639
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
632
640
|
|
633
641
|
def create_custom_forward(module):
|
634
642
|
def create_forward(*inputs):
|
@@ -641,7 +649,7 @@ class CogVideoXUpBlock3D(nn.Module):
|
|
641
649
|
hidden_states,
|
642
650
|
temb,
|
643
651
|
zq,
|
644
|
-
conv_cache
|
652
|
+
conv_cache.get(conv_cache_key),
|
645
653
|
)
|
646
654
|
else:
|
647
655
|
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
@@ -765,7 +773,7 @@ class CogVideoXEncoder3D(nn.Module):
|
|
765
773
|
|
766
774
|
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
767
775
|
|
768
|
-
if
|
776
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
769
777
|
|
770
778
|
def create_custom_forward(module):
|
771
779
|
def custom_forward(*inputs):
|
@@ -781,7 +789,7 @@ class CogVideoXEncoder3D(nn.Module):
|
|
781
789
|
hidden_states,
|
782
790
|
temb,
|
783
791
|
None,
|
784
|
-
conv_cache
|
792
|
+
conv_cache.get(conv_cache_key),
|
785
793
|
)
|
786
794
|
|
787
795
|
# 2. Mid
|
@@ -790,14 +798,14 @@ class CogVideoXEncoder3D(nn.Module):
|
|
790
798
|
hidden_states,
|
791
799
|
temb,
|
792
800
|
None,
|
793
|
-
conv_cache
|
801
|
+
conv_cache.get("mid_block"),
|
794
802
|
)
|
795
803
|
else:
|
796
804
|
# 1. Down
|
797
805
|
for i, down_block in enumerate(self.down_blocks):
|
798
806
|
conv_cache_key = f"down_block_{i}"
|
799
807
|
hidden_states, new_conv_cache[conv_cache_key] = down_block(
|
800
|
-
hidden_states, temb, None, conv_cache
|
808
|
+
hidden_states, temb, None, conv_cache.get(conv_cache_key)
|
801
809
|
)
|
802
810
|
|
803
811
|
# 2. Mid
|
@@ -931,7 +939,7 @@ class CogVideoXDecoder3D(nn.Module):
|
|
931
939
|
|
932
940
|
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
933
941
|
|
934
|
-
if
|
942
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
935
943
|
|
936
944
|
def create_custom_forward(module):
|
937
945
|
def custom_forward(*inputs):
|
@@ -945,7 +953,7 @@ class CogVideoXDecoder3D(nn.Module):
|
|
945
953
|
hidden_states,
|
946
954
|
temb,
|
947
955
|
sample,
|
948
|
-
conv_cache
|
956
|
+
conv_cache.get("mid_block"),
|
949
957
|
)
|
950
958
|
|
951
959
|
# 2. Up
|
@@ -956,7 +964,7 @@ class CogVideoXDecoder3D(nn.Module):
|
|
956
964
|
hidden_states,
|
957
965
|
temb,
|
958
966
|
sample,
|
959
|
-
conv_cache
|
967
|
+
conv_cache.get(conv_cache_key),
|
960
968
|
)
|
961
969
|
else:
|
962
970
|
# 1. Mid
|
@@ -1049,6 +1057,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
1049
1057
|
force_upcast: float = True,
|
1050
1058
|
use_quant_conv: bool = False,
|
1051
1059
|
use_post_quant_conv: bool = False,
|
1060
|
+
invert_scale_latents: bool = False,
|
1052
1061
|
):
|
1053
1062
|
super().__init__()
|
1054
1063
|
|
@@ -1467,7 +1476,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
1467
1476
|
z = posterior.sample(generator=generator)
|
1468
1477
|
else:
|
1469
1478
|
z = posterior.mode()
|
1470
|
-
dec = self.decode(z)
|
1479
|
+
dec = self.decode(z).sample
|
1471
1480
|
if not return_dict:
|
1472
1481
|
return (dec,)
|
1473
|
-
return dec
|
1482
|
+
return DecoderOutput(sample=dec)
|