diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@
 import importlib
 import inspect
 import os
+from array import array
 from collections import OrderedDict
 from pathlib import Path
 from typing import List, Optional, Union
@@ -25,8 +26,8 @@ import safetensors
 import torch
 from huggingface_hub.utils import EntryNotFoundError

-from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
+    GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
     WEIGHTS_INDEX_NAME,
@@ -34,6 +35,8 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_gguf_available,
+    is_torch_available,
     is_torch_version,
     logging,
 )
@@ -140,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     file_extension = os.path.basename(checkpoint_file).split(".")[-1]
     if file_extension == SAFETENSORS_FILE_EXTENSION:
         return safetensors.torch.load_file(checkpoint_file, device="cpu")
+    elif file_extension == GGUF_FILE_EXTENSION:
+        return load_gguf_checkpoint(checkpoint_file)
     else:
         weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
         return torch.load(
@@ -176,11 +181,12 @@ def load_model_dict_into_meta(
     hf_quantizer=None,
     keep_in_fp32_modules=None,
 ) -> List[str]:
+    if device is not None and not isinstance(device, (str, torch.device)):
+        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
     if hf_quantizer is None:
         device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES

     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
@@ -211,17 +217,18 @@ def load_model_dict_into_meta(
                 set_module_kwargs["dtype"] = dtype

         # bnb params are flattened.
+        # gguf quants have a different shape based on the type of quantization applied
         if empty_state_dict[param_name].shape != param.shape:
             if (
-                is_quant_method_bnb
+                is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
-                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
+            else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
-                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                 )

         if is_quantized and (
@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
         index_file = None

     return index_file
+
+
+def _gguf_parse_value(_value, data_type):
+    if not isinstance(data_type, list):
+        data_type = [data_type]
+    if len(data_type) == 1:
+        data_type = data_type[0]
+        array_data_type = None
+    else:
+        if data_type[0] != 9:
+            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
+        data_type, array_data_type = data_type
+
+    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
+        _value = int(_value[0])
+    elif data_type in [6, 12]:
+        _value = float(_value[0])
+    elif data_type in [7]:
+        _value = bool(_value[0])
+    elif data_type in [8]:
+        _value = array("B", list(_value)).tobytes().decode()
+    elif data_type in [9]:
+        _value = _gguf_parse_value(_value, array_data_type)
+    return _value
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
+    """
+    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
+    attributes.
+
+    Args:
+        gguf_checkpoint_path (`str`):
+            The path the to GGUF file to load
+        return_tensors (`bool`, defaults to `True`):
+            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
+            metadata in memory.
+    """
+
+    if is_gguf_available() and is_torch_available():
+        import gguf
+        from gguf import GGUFReader
+
+        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
+    else:
+        logger.error(
+            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+        )
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+    reader = GGUFReader(gguf_checkpoint_path)
+
+    parsed_parameters = {}
+    for tensor in reader.tensors:
+        name = tensor.name
+        quant_type = tensor.tensor_type
+
+        # if the tensor is a torch supported dtype do not use GGUFParameter
+        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
+        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
+            _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
+            raise ValueError(
+                (
+                    f"{name} has a quantization type: {str(quant_type)} which is unsupported."
+                    "\n\nCurrently the following quantization types are supported: \n\n"
+                    f"{_supported_quants_str}"
+                    "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
+                )
+            )
+
+        weights = torch.from_numpy(tensor.data.copy())
+        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
+
+    return parsed_parameters
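
The hunk above adds GGUF parsing to `model_loading_utils.py`. A minimal sketch of how the new helper could be exercised directly, assuming `gguf>=0.10.0` and PyTorch are installed; the checkpoint filename is a placeholder, not something shipped with diffusers:

# Sketch only: "flux1-dev-Q4_K_S.gguf" is a placeholder path.
from diffusers.models.model_loading_utils import load_gguf_checkpoint

state_dict = load_gguf_checkpoint("flux1-dev-Q4_K_S.gguf")
for name, tensor in list(state_dict.items())[:5]:
    # Quantized entries come back wrapped as GGUFParameter (constructed with their quant type);
    # F16/F32 entries are returned as plain torch tensors.
    print(name, type(tensor).__name__, tuple(tensor.shape))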
@@ -530,7 +530,7 @@ class FlaxModelMixin(PushToHubMixin):

         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:


 def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
-    try:
-        return next(parameter.parameters()).dtype
-    except StopIteration:
-        try:
-            return next(parameter.buffers()).dtype
-        except StopIteration:
-            # For torch.nn.DataParallel compatibility in PyTorch 1.5
-
-            def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
-                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-                return tuples
-
-            gen = parameter._named_members(get_members_fn=find_tensor_attributes)
-            first_tuple = next(gen)
-            return first_tuple[1].dtype
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    last_dtype = None
+    for param in parameter.parameters():
+        last_dtype = param.dtype
+        if param.is_floating_point():
+            return param.dtype
+
+    for buffer in parameter.buffers():
+        last_dtype = buffer.dtype
+        if buffer.is_floating_point():
+            return buffer.dtype
+
+    if last_dtype is not None:
+        # if no floating dtype was found return whatever the first dtype is
+        return last_dtype
+
+    # For nn.DataParallel compatibility in PyTorch > 1.5
+    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+        return tuples
+
+    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+    last_tuple = None
+    for tuple in gen:
+        last_tuple = tuple
+        if tuple[1].is_floating_point():
+            return tuple[1].dtype
+
+    if last_tuple is not None:
+        # fallback to the last dtype
+        return last_tuple[1].dtype


 class ModelMixin(torch.nn.Module, PushToHubMixin):
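
The rewritten helper prefers a floating dtype over whichever tensor happens to be yielded first. A rough sketch of the behavioural difference (the module and buffer names are made up for illustration):

import torch
import torch.nn as nn

from diffusers.models.modeling_utils import get_parameter_dtype

class BufferOnly(nn.Module):
    def __init__(self):
        super().__init__()
        # integer buffer registered first, floating buffer second
        self.register_buffer("token_ids", torch.zeros(4, dtype=torch.int64))
        self.register_buffer("scale", torch.ones(4, dtype=torch.float16))

# 0.31.0 returned the first buffer's dtype (int64); the new version keeps scanning
# until it finds a floating dtype and only falls back to the last dtype seen.
print(get_parameter_dtype(BufferOnly()))  # torch.float16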
@@ -208,6 +226,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_npu_flash_attention(False)

+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_xla_flash_attention method
+        # gets the message
+        def fn_recursive_set_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_xla_flash_attention"):
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+
+            for child in module.children():
+                fn_recursive_set_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_flash_attention(module)
+
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+        r"""
+        Enable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(True, partition_spec)
+
+    def disable_xla_flash_attention(self):
+        r"""
+        Disable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
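
The new toggles simply fan the flag out to every child module that exposes `set_use_xla_flash_attention` (in practice the attention processors). A hedged usage sketch, assuming a torch_xla/TPU environment; the checkpoint id is only illustrative:

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.bfloat16
)

# Enable the pallas flash-attention kernel on every submodule that exposes the hook,
# run inference under torch_xla, then switch back to the default attention path.
unet.enable_xla_flash_attention()
# ... denoising loop on the XLA device ...
unet.disable_xla_flash_attention()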
@@ -338,7 +385,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -673,8 +720,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if hf_quantizer is not None:
             if device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
                 )
+
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)

@@ -771,6 +819,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     revision=revision,
                     subfolder=subfolder or "",
                 )
+                # TODO: https://github.com/huggingface/diffusers/issues/10013
                 if hf_quantizer is not None:
                     model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                     logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
@@ -829,14 +878,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             if device_map is None and not is_sharded:
                 # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
                 # It would error out during the `validate_environment()` call above in the absence of cuda.
-                is_quant_method_bnb = (
-                    getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-                )
                 if hf_quantizer is None:
                     param_device = "cpu"
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-                elif is_quant_method_bnb:
-                    param_device = torch.cuda.current_device()
+                else:
+                    param_device = torch.device(torch.cuda.current_device())
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)

@@ -1010,14 +1056,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     dtype_present_in_args = True
                     break

-        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
-        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+        if getattr(self, "is_quantized", False):
             if dtype_present_in_args:
                 raise ValueError(
-                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
-                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                    "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
+                    "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
                 )

+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
                 raise ValueError(
                     "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
@@ -22,10 +22,7 @@ import torch.nn.functional as F

 from ..utils import is_torch_version
 from .activations import get_activation
-from .embeddings import (
-    CombinedTimestepLabelEmbeddings,
-    PixArtAlphaCombinedTimestepSizeEmbeddings,
-)
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings


 class AdaLayerNorm(nn.Module):
@@ -266,6 +263,7 @@ class AdaLayerNormSingle(nn.Module):
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
+        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
         embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
         return self.linear(self.silu(embedded_timestep)), embedded_timestep

@@ -358,20 +356,21 @@ class LuminaLayerNormContinuous(nn.Module):
         out_dim: Optional[int] = None,
     ):
         super().__init__()
+
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
         else:
             raise ValueError(f"unknown norm_type {norm_type}")
-        # linear_2
+
+        self.linear_2 = None
         if out_dim is not None:
-            self.linear_2 = nn.Linear(
-                embedding_dim,
-                out_dim,
-                bias=bias,
-            )
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

     def forward(
         self,
@@ -486,20 +485,24 @@ else:


 class RMSNorm(nn.Module):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
         super().__init__()

         self.eps = eps
+        self.elementwise_affine = elementwise_affine

         if isinstance(dim, numbers.Integral):
             dim = (dim,)

         self.dim = torch.Size(dim)

+        self.weight = None
+        self.bias = None
+
         if elementwise_affine:
             self.weight = nn.Parameter(torch.ones(dim))
-        else:
-            self.weight = None
+            if bias:
+                self.bias = nn.Parameter(torch.zeros(dim))

     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
@@ -511,12 +514,44 @@ class RMSNorm(nn.Module):
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
             hidden_states = hidden_states * self.weight
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
         else:
             hidden_states = hidden_states.to(input_dtype)

         return hidden_states


+# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
+# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
 class GlobalResponseNorm(nn.Module):
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
@@ -528,3 +563,33 @@ class GlobalResponseNorm(nn.Module):
         gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
         return self.gamma * (x * nx) + self.beta + x
+
+
+class LpNorm(nn.Module):
+    def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
+        super().__init__()
+
+        self.p = p
+        self.dim = dim
+        self.eps = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
+
+
+def get_normalization(
+    norm_type: str = "batch_norm",
+    num_features: Optional[int] = None,
+    eps: float = 1e-5,
+    elementwise_affine: bool = True,
+    bias: bool = True,
+) -> nn.Module:
+    if norm_type == "rms_norm":
+        norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "layer_norm":
+        norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "batch_norm":
+        norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
+    else:
+        raise ValueError(f"{norm_type=} is not supported.")
+    return norm
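
`get_normalization` is a small factory used by the new model definitions, and `RMSNorm` now optionally carries a bias. A quick sketch of what the factory returns for each `norm_type`:

import torch

from diffusers.models.normalization import RMSNorm, get_normalization

# "rms_norm" maps onto the updated RMSNorm, here with the new bias term enabled.
norm = get_normalization("rms_norm", num_features=64, eps=1e-6, bias=True)
assert isinstance(norm, RMSNorm) and norm.bias is not None

x = torch.randn(2, 16, 64)
print(norm(x).shape)  # torch.Size([2, 16, 64])

# "layer_norm" and "batch_norm" return the corresponding torch.nn modules.
print(type(get_normalization("layer_norm", num_features=64)).__name__)  # LayerNorm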
@@ -11,10 +11,15 @@ if is_torch_available():
     from .lumina_nextdit2d import LuminaNextDiT2DModel
     from .pixart_transformer_2d import PixArtTransformer2DModel
     from .prior_transformer import PriorTransformer
+    from .sana_transformer import SanaTransformer2DModel
     from .stable_audio_transformer import StableAudioDiTModel
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_allegro import AllegroTransformer3DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
+    from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+    from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_mochi import MochiTransformer3DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
@@ -466,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):

         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -497,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -170,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             Whether to flip the sin to cos in the time embedding.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
+        ofs_embed_dim (`int`, defaults to `512`):
+            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
         text_embed_dim (`int`, defaults to `4096`):
             Input dimension of text embeddings from the text encoder.
         num_layers (`int`, defaults to `30`):
@@ -177,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         dropout (`float`, defaults to `0.0`):
             The dropout probability to use.
         attention_bias (`bool`, defaults to `True`):
-            Whether or not to use bias in the attention projection layers.
+            Whether to use bias in the attention projection layers.
         sample_width (`int`, defaults to `90`):
             The width of the input latents.
         sample_height (`int`, defaults to `60`):
@@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         timestep_activation_fn (`str`, defaults to `"silu"`):
             Activation function to use when generating the timestep embeddings.
         norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether or not to use elementwise affine in normalization layers.
+            Whether to use elementwise affine in normalization layers.
         norm_eps (`float`, defaults to `1e-5`):
             The epsilon value to use in normalization layers.
         spatial_interpolation_scale (`float`, defaults to `1.875`):
@@ -219,6 +221,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         time_embed_dim: int = 512,
+        ofs_embed_dim: Optional[int] = None,
         text_embed_dim: int = 4096,
         num_layers: int = 30,
         dropout: float = 0.0,
@@ -227,6 +230,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         sample_height: int = 60,
         sample_frames: int = 49,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         temporal_compression_ratio: int = 4,
         max_text_seq_length: int = 226,
         activation_fn: str = "gelu-approximate",
@@ -237,6 +241,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
+        patch_bias: bool = True,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -251,10 +256,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 1. Patch embedding
         self.patch_embed = CogVideoXPatchEmbed(
             patch_size=patch_size,
+            patch_size_t=patch_size_t,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            bias=True,
+            bias=patch_bias,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -267,10 +273,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         )
         self.embedding_dropout = nn.Dropout(dropout)

-        # 2. Time embeddings
+        # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
+
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

+        self.ofs_proj = None
+        self.ofs_embedding = None
+        if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+            self.ofs_embedding = TimestepEmbedding(
+                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+            )  # same as time embeddings, for ofs
+
         # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
@@ -298,7 +313,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             norm_eps=norm_eps,
             chunk_dim=1,
         )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+        if patch_size_t is None:
+            # For CogVideox 1.0
+            output_dim = patch_size * patch_size * out_channels
+        else:
+            # For CogVideoX 1.5
+            output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+        self.proj_out = nn.Linear(inner_dim, output_dim)

         self.gradient_checkpointing = False

@@ -411,6 +434,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
@@ -442,6 +466,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)

+        if self.ofs_embedding is not None:
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
+
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
@@ -452,7 +482,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -491,12 +521,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         hidden_states = self.proj_out(hidden_states)

         # 5. Unpatchify
-        # Note: we use `-1` instead of `channels`:
-        # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
-        # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        p_t = self.config.patch_size_t
+
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
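
With `patch_size_t` set, each token also carries `p_t` frames, so the unpatchify step has to restore a temporal patch dimension as well. A small shape sketch of the two branches using only torch (the sizes are illustrative, not tied to a particular checkpoint):

import torch

batch_size, num_frames, height, width = 1, 13, 60, 90
out_channels, p, p_t = 16, 2, 2

# CogVideoX 1.0 branch (patch_size_t is None): each token holds p*p spatial patches.
tokens = torch.randn(batch_size, num_frames * (height // p) * (width // p), p * p * out_channels)
hidden = tokens.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
video = hidden.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
print(video.shape)  # torch.Size([1, 13, 16, 60, 90])

# CogVideoX 1.5 branch: frames are grouped by p_t, and the temporal patches are
# unfolded back out, so the frame axis comes back as groups * p_t.
groups = (num_frames + p_t - 1) // p_t
tokens = torch.randn(batch_size, groups * (height // p) * (width // p), p * p * p_t * out_channels)
hidden = tokens.reshape(batch_size, groups, height // p, width // p, -1, p_t, p, p)
video = hidden.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
print(video.shape)  # torch.Size([1, 14, 16, 60, 90])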
@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):

         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
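
The same `self.training` to `torch.is_grad_enabled()` change appears in several transformer models above. The effect is that activation checkpointing now follows the autograd state rather than train/eval mode; a tiny standalone sketch of the guard:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
gradient_checkpointing = True

# New guard: recompute-in-backward is used whenever autograd is recording,
# even in eval() mode, and is skipped under torch.no_grad() even in train() mode.
if torch.is_grad_enabled() and gradient_checkpointing:
    y = checkpoint(block, x, use_reentrant=False)
else:
    y = block(x)
print(y.requires_grad)  # True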