diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -94,11 +94,13 @@ class CogVideoXCausalConv3d(nn.Module):
94
94
 
95
95
  time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
96
96
 
97
- self.pad_mode = pad_mode
98
- time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
99
- height_pad = height_kernel_size // 2
100
- width_pad = width_kernel_size // 2
97
+ # TODO(aryan): configure calculation based on stride and dilation in the future.
98
+ # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
99
+ time_pad = time_kernel_size - 1
100
+ height_pad = (height_kernel_size - 1) // 2
101
+ width_pad = (width_kernel_size - 1) // 2
101
102
 
103
+ self.pad_mode = pad_mode
102
104
  self.height_pad = height_pad
103
105
  self.width_pad = width_pad
104
106
  self.time_pad = time_pad
@@ -107,7 +109,7 @@ class CogVideoXCausalConv3d(nn.Module):
107
109
  self.temporal_dim = 2
108
110
  self.time_kernel_size = time_kernel_size
109
111
 
110
- stride = (stride, 1, 1)
112
+ stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
111
113
  dilation = (dilation, 1, 1)
112
114
  self.conv = CogVideoXSafeConv3d(
113
115
  in_channels=in_channels,
@@ -120,18 +122,24 @@ class CogVideoXCausalConv3d(nn.Module):
120
122
  def fake_context_parallel_forward(
121
123
  self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
122
124
  ) -> torch.Tensor:
123
- kernel_size = self.time_kernel_size
124
- if kernel_size > 1:
125
- cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
126
- inputs = torch.cat(cached_inputs + [inputs], dim=2)
125
+ if self.pad_mode == "replicate":
126
+ inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
127
+ else:
128
+ kernel_size = self.time_kernel_size
129
+ if kernel_size > 1:
130
+ cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
131
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
127
132
  return inputs
128
133
 
129
134
  def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
130
135
  inputs = self.fake_context_parallel_forward(inputs, conv_cache)
131
- conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
132
136
 
133
- padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
134
- inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
137
+ if self.pad_mode == "replicate":
138
+ conv_cache = None
139
+ else:
140
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
141
+ conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
142
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
135
143
 
136
144
  output = self.conv(inputs)
137
145
  return output, conv_cache
@@ -412,7 +420,7 @@ class CogVideoXDownBlock3D(nn.Module):
412
420
  for i, resnet in enumerate(self.resnets):
413
421
  conv_cache_key = f"resnet_{i}"
414
422
 
415
- if self.training and self.gradient_checkpointing:
423
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
416
424
 
417
425
  def create_custom_forward(module):
418
426
  def create_forward(*inputs):
@@ -425,7 +433,7 @@ class CogVideoXDownBlock3D(nn.Module):
425
433
  hidden_states,
426
434
  temb,
427
435
  zq,
428
- conv_cache=conv_cache.get(conv_cache_key),
436
+ conv_cache.get(conv_cache_key),
429
437
  )
430
438
  else:
431
439
  hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -514,7 +522,7 @@ class CogVideoXMidBlock3D(nn.Module):
514
522
  for i, resnet in enumerate(self.resnets):
515
523
  conv_cache_key = f"resnet_{i}"
516
524
 
517
- if self.training and self.gradient_checkpointing:
525
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
518
526
 
519
527
  def create_custom_forward(module):
520
528
  def create_forward(*inputs):
@@ -523,7 +531,7 @@ class CogVideoXMidBlock3D(nn.Module):
523
531
  return create_forward
524
532
 
525
533
  hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
526
- create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
534
+ create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
527
535
  )
528
536
  else:
529
537
  hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -628,7 +636,7 @@ class CogVideoXUpBlock3D(nn.Module):
628
636
  for i, resnet in enumerate(self.resnets):
629
637
  conv_cache_key = f"resnet_{i}"
630
638
 
631
- if self.training and self.gradient_checkpointing:
639
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
632
640
 
633
641
  def create_custom_forward(module):
634
642
  def create_forward(*inputs):
@@ -641,7 +649,7 @@ class CogVideoXUpBlock3D(nn.Module):
641
649
  hidden_states,
642
650
  temb,
643
651
  zq,
644
- conv_cache=conv_cache.get(conv_cache_key),
652
+ conv_cache.get(conv_cache_key),
645
653
  )
646
654
  else:
647
655
  hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -765,7 +773,7 @@ class CogVideoXEncoder3D(nn.Module):
765
773
 
766
774
  hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
767
775
 
768
- if self.training and self.gradient_checkpointing:
776
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
769
777
 
770
778
  def create_custom_forward(module):
771
779
  def custom_forward(*inputs):
@@ -781,7 +789,7 @@ class CogVideoXEncoder3D(nn.Module):
781
789
  hidden_states,
782
790
  temb,
783
791
  None,
784
- conv_cache=conv_cache.get(conv_cache_key),
792
+ conv_cache.get(conv_cache_key),
785
793
  )
786
794
 
787
795
  # 2. Mid
@@ -790,14 +798,14 @@ class CogVideoXEncoder3D(nn.Module):
790
798
  hidden_states,
791
799
  temb,
792
800
  None,
793
- conv_cache=conv_cache.get("mid_block"),
801
+ conv_cache.get("mid_block"),
794
802
  )
795
803
  else:
796
804
  # 1. Down
797
805
  for i, down_block in enumerate(self.down_blocks):
798
806
  conv_cache_key = f"down_block_{i}"
799
807
  hidden_states, new_conv_cache[conv_cache_key] = down_block(
800
- hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
808
+ hidden_states, temb, None, conv_cache.get(conv_cache_key)
801
809
  )
802
810
 
803
811
  # 2. Mid
@@ -931,7 +939,7 @@ class CogVideoXDecoder3D(nn.Module):
931
939
 
932
940
  hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
933
941
 
934
- if self.training and self.gradient_checkpointing:
942
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
935
943
 
936
944
  def create_custom_forward(module):
937
945
  def custom_forward(*inputs):
@@ -945,7 +953,7 @@ class CogVideoXDecoder3D(nn.Module):
945
953
  hidden_states,
946
954
  temb,
947
955
  sample,
948
- conv_cache=conv_cache.get("mid_block"),
956
+ conv_cache.get("mid_block"),
949
957
  )
950
958
 
951
959
  # 2. Up
@@ -956,7 +964,7 @@ class CogVideoXDecoder3D(nn.Module):
956
964
  hidden_states,
957
965
  temb,
958
966
  sample,
959
- conv_cache=conv_cache.get(conv_cache_key),
967
+ conv_cache.get(conv_cache_key),
960
968
  )
961
969
  else:
962
970
  # 1. Mid
@@ -1049,6 +1057,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1049
1057
  force_upcast: float = True,
1050
1058
  use_quant_conv: bool = False,
1051
1059
  use_post_quant_conv: bool = False,
1060
+ invert_scale_latents: bool = False,
1052
1061
  ):
1053
1062
  super().__init__()
1054
1063
 
@@ -1467,7 +1476,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1467
1476
  z = posterior.sample(generator=generator)
1468
1477
  else:
1469
1478
  z = posterior.mode()
1470
- dec = self.decode(z)
1479
+ dec = self.decode(z).sample
1471
1480
  if not return_dict:
1472
1481
  return (dec,)
1473
- return dec
1482
+ return DecoderOutput(sample=dec)