diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0

diffusers/models/normalization.py

@@ -22,10 +22,7 @@ import torch.nn.functional as F
 
 from ..utils import is_torch_version
 from .activations import get_activation
-from .embeddings import (
-    CombinedTimestepLabelEmbeddings,
-    PixArtAlphaCombinedTimestepSizeEmbeddings,
-)
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
 
 class AdaLayerNorm(nn.Module):
@@ -97,6 +94,40 @@ class FP32LayerNorm(nn.LayerNorm):
         ).to(origin_dtype)
 
 
+class SD35AdaLayerNormZeroX(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+            9, dim=1
+        )
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
+
 class AdaLayerNormZero(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
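
The new SD35AdaLayerNormZeroX block can be exercised on its own. A minimal sketch (assuming diffusers 0.32.0 is installed; tensor sizes below are made up for illustration) showing the 9-way modulation split and the seven returned tensors:

    import torch
    from diffusers.models.normalization import SD35AdaLayerNormZeroX

    norm = SD35AdaLayerNormZeroX(embedding_dim=64)

    hidden_states = torch.randn(2, 16, 64)  # (batch, tokens, dim) -- illustrative shapes
    emb = torch.randn(2, 64)                # conditioning embedding

    # The linear layer projects emb to 9 * embedding_dim and chunks it into
    # shift/scale/gate triples for two attention streams plus the MLP.
    out = norm(hidden_states, emb)
    print(len(out))        # 7
    print(out[0].shape)    # torch.Size([2, 16, 64]) -- first modulated stream
    print(out[5].shape)    # torch.Size([2, 16, 64]) -- second modulated stream
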
@@ -232,6 +263,7 @@ class AdaLayerNormSingle(nn.Module):
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
+        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
         embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
         return self.linear(self.silu(embedded_timestep)), embedded_timestep
 
@@ -324,20 +356,21 @@ class LuminaLayerNormContinuous(nn.Module):
         out_dim: Optional[int] = None,
     ):
         super().__init__()
+
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
         else:
             raise ValueError(f"unknown norm_type {norm_type}")
-        # linear_2
+
+        self.linear_2 = None
         if out_dim is not None:
-            self.linear_2 = nn.Linear(
-                embedding_dim,
-                out_dim,
-                bias=bias,
-            )
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
 
     def forward(
         self,
@@ -355,6 +388,51 @@ class LuminaLayerNormContinuous(nn.Module):
         return x
 
 
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, dim: int):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+        self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            c_shift_msa,
+            c_scale_msa,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+        normed_x = self.norm_x(x)
+        normed_context = self.norm_c(context)
+        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
+
+
 class CogVideoXLayerNormZero(nn.Module):
     def __init__(
         self,
@@ -407,20 +485,24 @@ else:
 
 
 class RMSNorm(nn.Module):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
         super().__init__()
 
         self.eps = eps
+        self.elementwise_affine = elementwise_affine
 
         if isinstance(dim, numbers.Integral):
            dim = (dim,)
 
         self.dim = torch.Size(dim)
 
+        self.weight = None
+        self.bias = None
+
         if elementwise_affine:
             self.weight = nn.Parameter(torch.ones(dim))
-        else:
-            self.weight = None
+            if bias:
+                self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
@@ -432,12 +514,44 @@ class RMSNorm(nn.Module):
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
             hidden_states = hidden_states * self.weight
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
         else:
             hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
 
 
+# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
+# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
 class GlobalResponseNorm(nn.Module):
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
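
RMSNorm now takes an optional learnable bias in addition to the scale. A quick sketch of the new parameter (assuming diffusers 0.32.0 is installed; dimensions are arbitrary):

    import torch
    from diffusers.models.normalization import RMSNorm

    # bias=False reproduces the previous behaviour; bias=True adds a learnable offset.
    norm = RMSNorm(dim=64, eps=1e-6, elementwise_affine=True, bias=True)

    x = torch.randn(2, 10, 64)
    print(norm(x).shape)           # torch.Size([2, 10, 64])
    print(norm.bias is not None)   # True
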
@@ -449,3 +563,33 @@ class GlobalResponseNorm(nn.Module):
         gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
         return self.gamma * (x * nx) + self.beta + x
+
+
+class LpNorm(nn.Module):
+    def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
+        super().__init__()
+
+        self.p = p
+        self.dim = dim
+        self.eps = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
+
+
+def get_normalization(
+    norm_type: str = "batch_norm",
+    num_features: Optional[int] = None,
+    eps: float = 1e-5,
+    elementwise_affine: bool = True,
+    bias: bool = True,
+) -> nn.Module:
+    if norm_type == "rms_norm":
+        norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "layer_norm":
+        norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "batch_norm":
+        norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
+    else:
+        raise ValueError(f"{norm_type=} is not supported.")
+    return norm
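
The new get_normalization helper is a single entry point over the three norm flavours. A small sketch of how it might be called (feature sizes are arbitrary; note the layer_norm branch passes bias= through to nn.LayerNorm, which needs a recent PyTorch):

    import torch
    from diffusers.models.normalization import LpNorm, get_normalization

    rms = get_normalization("rms_norm", num_features=32, eps=1e-6, bias=True)
    bn = get_normalization("batch_norm", num_features=32)       # plain nn.BatchNorm2d

    x = torch.randn(4, 8, 32)
    print(rms(x).shape)                                         # torch.Size([4, 8, 32])
    print(bn(torch.randn(4, 32, 16, 16)).shape)                 # torch.Size([4, 32, 16, 16])

    # LpNorm simply wraps F.normalize over a chosen dimension.
    unit = LpNorm(p=2, dim=-1)
    print(unit(x).norm(dim=-1)[0, :3])                          # ~1.0 everywhere
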

diffusers/models/transformers/__init__.py

@@ -11,9 +11,15 @@ if is_torch_available():
     from .lumina_nextdit2d import LuminaNextDiT2DModel
     from .pixart_transformer_2d import PixArtTransformer2DModel
     from .prior_transformer import PriorTransformer
+    from .sana_transformer import SanaTransformer2DModel
     from .stable_audio_transformer import StableAudioDiTModel
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_allegro import AllegroTransformer3DModel
+    from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
+    from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+    from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_mochi import MochiTransformer3DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel

diffusers/models/transformers/auraflow_transformer_2d.py

@@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
     """
 
+    _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -465,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
 
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
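
Several models in this release change the checkpointing guard from "self.training and self.gradient_checkpointing" to "torch.is_grad_enabled() and self.gradient_checkpointing", so activation checkpointing also engages when gradients flow through a module kept in eval mode. A toy stand-in (not diffusers code) that mimics the new predicate:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class ToyBlockStack(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
            self.gradient_checkpointing = True

        def forward(self, x):
            for block in self.blocks:
                # New-style guard: depends on grad mode, not on self.training.
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    x = checkpoint(block, x, use_reentrant=False)
                else:
                    x = block(x)
            return x

    model = ToyBlockStack().eval()              # module stays in eval mode
    x = torch.randn(2, 8, requires_grad=True)
    model(x).sum().backward()                   # checkpointing still active: grads are enabled
    with torch.no_grad():
        model(x)                                # plain forward: the guard is False under no_grad
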
@@ -496,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

diffusers/models/transformers/cogvideox_transformer_3d.py

@@ -19,7 +19,8 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -152,7 +153,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
 
@@ -169,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
             Whether to flip the sin to cos in the time embedding.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
+        ofs_embed_dim (`int`, defaults to `512`):
+            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
         text_embed_dim (`int`, defaults to `4096`):
             Input dimension of text embeddings from the text encoder.
         num_layers (`int`, defaults to `30`):
@@ -176,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         dropout (`float`, defaults to `0.0`):
             The dropout probability to use.
         attention_bias (`bool`, defaults to `True`):
-            Whether or not to use bias in the attention projection layers.
+            Whether to use bias in the attention projection layers.
         sample_width (`int`, defaults to `90`):
             The width of the input latents.
         sample_height (`int`, defaults to `60`):
@@ -197,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep_activation_fn (`str`, defaults to `"silu"`):
             Activation function to use when generating the timestep embeddings.
         norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether or not to use elementwise affine in normalization layers.
+            Whether to use elementwise affine in normalization layers.
         norm_eps (`float`, defaults to `1e-5`):
             The epsilon value to use in normalization layers.
         spatial_interpolation_scale (`float`, defaults to `1.875`):
@@ -218,6 +221,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         time_embed_dim: int = 512,
+        ofs_embed_dim: Optional[int] = None,
         text_embed_dim: int = 4096,
         num_layers: int = 30,
         dropout: float = 0.0,
@@ -226,6 +230,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         sample_height: int = 60,
         sample_frames: int = 49,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         temporal_compression_ratio: int = 4,
         max_text_seq_length: int = 226,
         activation_fn: str = "gelu-approximate",
@@ -236,6 +241,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
+        patch_bias: bool = True,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -250,10 +256,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         # 1. Patch embedding
         self.patch_embed = CogVideoXPatchEmbed(
             patch_size=patch_size,
+            patch_size_t=patch_size_t,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            bias=True,
+            bias=patch_bias,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -266,10 +273,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         )
         self.embedding_dropout = nn.Dropout(dropout)
 
-        # 2. Time embeddings
+        # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
+
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
+        self.ofs_proj = None
+        self.ofs_embedding = None
+        if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+            self.ofs_embedding = TimestepEmbedding(
+                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+            )  # same as time embeddings, for ofs
+
         # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
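
The new ofs conditioning path reuses the existing sinusoidal timestep machinery. A standalone sketch of what ofs_proj/ofs_embedding compute (assuming diffusers 0.32.0; the ofs value below is a placeholder, not a real CogVideoX input):

    import torch
    from diffusers.models.embeddings import TimestepEmbedding, Timesteps

    ofs_embed_dim = 512
    ofs_proj = Timesteps(ofs_embed_dim, True, 0)                     # sinusoidal projection
    ofs_embedding = TimestepEmbedding(ofs_embed_dim, ofs_embed_dim)  # two-layer MLP, SiLU activation

    ofs = torch.tensor([2.0])                 # placeholder scalar conditioning, one per batch item
    ofs_emb = ofs_embedding(ofs_proj(ofs))
    print(ofs_emb.shape)                      # torch.Size([1, 512]); added onto the time embedding
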
@@ -297,7 +313,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
             norm_eps=norm_eps,
             chunk_dim=1,
         )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+        if patch_size_t is None:
+            # For CogVideox 1.0
+            output_dim = patch_size * patch_size * out_channels
+        else:
+            # For CogVideoX 1.5
+            output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+        self.proj_out = nn.Linear(inner_dim, output_dim)
 
         self.gradient_checkpointing = False
 
@@ -410,9 +434,26 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_frames, channels, height, width = hidden_states.shape
 
         # 1. Time embedding
@@ -425,6 +466,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
 
+        if self.ofs_embedding is not None:
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
+
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
@@ -435,7 +482,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -474,12 +521,21 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
-        # Note: we use `-1` instead of `channels`:
-        # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
-        # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        p_t = self.config.patch_size_t
+
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
 
         if not return_dict:
             return (output,)
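
The new p_t branch packs p_t frames into each token, so unpatchify has to undo a temporal patch as well. A shape check with made-up sizes (these are not the real CogVideoX latent dimensions):

    import torch

    batch_size, num_frames, height, width = 1, 13, 16, 24   # illustrative latent sizes
    p, p_t, out_channels = 2, 2, 4

    nf_t = (num_frames + p_t - 1) // p_t                     # ceil-divide frames by the temporal patch
    seq_len = nf_t * (height // p) * (width // p)
    hidden_states = torch.randn(batch_size, seq_len, p * p * p_t * out_channels)

    output = hidden_states.reshape(batch_size, nf_t, height // p, width // p, -1, p_t, p, p)
    output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
    print(output.shape)   # torch.Size([1, 14, 4, 16, 24]) -- frames come back padded to a multiple of p_t
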

diffusers/models/transformers/dit_transformer_2d.py

@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):

diffusers/models/transformers/latte_transformer_3d.py

@@ -156,9 +156,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
 
         # define temporal positional embedding
         temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
-            inner_dim, torch.arange(0, video_length).unsqueeze(1)
+            inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
         )  # 1152 hidden size
-        self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+        self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
 
         self.gradient_checkpointing = False
 
@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,
@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed
 
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,

diffusers/models/transformers/pixart_transformer_2d.py

@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
@@ -378,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
 
         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):