diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
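
The listing above covers several new model families (Allegro, HunyuanVideo, LTX, Mochi, Sana), ControlNet-union and Flux control/fill/redux pipelines, and new GGUF and torchao quantization backends. As a quick post-upgrade sanity check, the sketch below probes the top-level diffusers namespace for a handful of the new classes. The class names are inferred from the new module paths in the list (they are assumptions, not taken from the diff itself), so the script only reports whether each name is exported rather than importing it directly.

import diffusers

# Probe the upgraded package for classes 0.32.x is expected to export. The names
# are guesses derived from the new files listed above; missing ones are reported,
# not treated as errors.
print("diffusers", diffusers.__version__)

candidates = [
    "AllegroPipeline",
    "HunyuanVideoPipeline",
    "LTXPipeline",
    "MochiPipeline",
    "SanaPipeline",
    "GGUFQuantizationConfig",
    "TorchAoConfig",
]
for name in candidates:
    status = "available" if hasattr(diffusers, name) else "not exported"
    print(f"{name:24s} {status}")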

diffusers/models/transformers/transformer_mochi.py (new file)
@@ -0,0 +1,499 @@
+ # Copyright 2024 The Genmo team and The HuggingFace Team.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...loaders import PeftAdapterMixin
+ from ...loaders.single_file_model import FromOriginalModelMixin
+ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..attention import FeedForward
+ from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
+ from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import AdaLayerNormContinuous, RMSNorm
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class MochiModulatedRMSNorm(nn.Module):
+     def __init__(self, eps: float):
+         super().__init__()
+
+         self.eps = eps
+         self.norm = RMSNorm(0, eps, False)
+
+     def forward(self, hidden_states, scale=None):
+         hidden_states_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+
+         hidden_states = self.norm(hidden_states)
+
+         if scale is not None:
+             hidden_states = hidden_states * scale
+
+         hidden_states = hidden_states.to(hidden_states_dtype)
+
+         return hidden_states
+
+
+ class MochiLayerNormContinuous(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         conditioning_embedding_dim: int,
+         eps=1e-5,
+         bias=True,
+     ):
+         super().__init__()
+
+         # AdaLN
+         self.silu = nn.SiLU()
+         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+         self.norm = MochiModulatedRMSNorm(eps=eps)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         conditioning_embedding: torch.Tensor,
+     ) -> torch.Tensor:
+         input_dtype = x.dtype
+
+         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+         scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+         x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
+
+         return x.to(input_dtype)
+
+
+ class MochiRMSNormZero(nn.Module):
+     r"""
+     Adaptive RMS Norm used in Mochi.
+
+     Parameters:
+         embedding_dim (`int`): The size of each embedding vector.
+     """
+
+     def __init__(
+         self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
+     ) -> None:
+         super().__init__()
+
+         self.silu = nn.SiLU()
+         self.linear = nn.Linear(embedding_dim, hidden_dim)
+         self.norm = RMSNorm(0, eps, False)
+
+     def forward(
+         self, hidden_states: torch.Tensor, emb: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         hidden_states_dtype = hidden_states.dtype
+
+         emb = self.linear(self.silu(emb))
+         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
+         hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))
+         hidden_states = hidden_states.to(hidden_states_dtype)
+
+         return hidden_states, gate_msa, scale_mlp, gate_mlp
+
+
+ @maybe_allow_in_graph
+ class MochiTransformerBlock(nn.Module):
+     r"""
+     Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+     Args:
+         dim (`int`):
+             The number of channels in the input and output.
+         num_attention_heads (`int`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`):
+             The number of channels in each head.
+         qk_norm (`str`, defaults to `"rms_norm"`):
+             The normalization layer to use.
+         activation_fn (`str`, defaults to `"swiglu"`):
+             Activation function to use in feed-forward.
+         context_pre_only (`bool`, defaults to `False`):
+             Whether or not to process context-related conditions with additional layers.
+         eps (`float`, defaults to `1e-6`):
+             Epsilon value for normalization layers.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         pooled_projection_dim: int,
+         qk_norm: str = "rms_norm",
+         activation_fn: str = "swiglu",
+         context_pre_only: bool = False,
+         eps: float = 1e-6,
+     ) -> None:
+         super().__init__()
+
+         self.context_pre_only = context_pre_only
+         self.ff_inner_dim = (4 * dim * 2) // 3
+         self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3
+
+         self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)
+
+         if not context_pre_only:
+             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
+         else:
+             self.norm1_context = MochiLayerNormContinuous(
+                 embedding_dim=pooled_projection_dim,
+                 conditioning_embedding_dim=dim,
+                 eps=eps,
+             )
+
+         self.attn1 = MochiAttention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             bias=False,
+             added_kv_proj_dim=pooled_projection_dim,
+             added_proj_bias=False,
+             out_dim=dim,
+             out_context_dim=pooled_projection_dim,
+             context_pre_only=context_pre_only,
+             processor=MochiAttnProcessor2_0(),
+             eps=1e-5,
+         )
+
+         # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
+         self.norm2 = MochiModulatedRMSNorm(eps=eps)
+         self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
+
+         self.norm3 = MochiModulatedRMSNorm(eps)
+         self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
+
+         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
+         self.ff_context = None
+         if not context_pre_only:
+             self.ff_context = FeedForward(
+                 pooled_projection_dim,
+                 inner_dim=self.ff_context_inner_dim,
+                 activation_fn=activation_fn,
+                 bias=False,
+             )
+
+         self.norm4 = MochiModulatedRMSNorm(eps=eps)
+         self.norm4_context = MochiModulatedRMSNorm(eps=eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+
+         if not self.context_pre_only:
+             norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
+                 encoder_hidden_states, temb
+             )
+         else:
+             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+
+         attn_hidden_states, context_attn_hidden_states = self.attn1(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_encoder_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+             attention_mask=encoder_attention_mask,
+         )
+
+         hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
+         norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
+         ff_output = self.ff(norm_hidden_states)
+         hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
+
+         if not self.context_pre_only:
+             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
+                 context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
+             )
+             norm_encoder_hidden_states = self.norm3_context(
+                 encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
+             )
+             context_ff_output = self.ff_context(norm_encoder_hidden_states)
+             encoder_hidden_states = encoder_hidden_states + self.norm4_context(
+                 context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
+             )
+
+         return hidden_states, encoder_hidden_states
+
+
+ class MochiRoPE(nn.Module):
+     r"""
+     RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+     Args:
+         base_height (`int`, defaults to `192`):
+             Base height used to compute interpolation scale for rotary positional embeddings.
+         base_width (`int`, defaults to `192`):
+             Base width used to compute interpolation scale for rotary positional embeddings.
+     """
+
+     def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
+         super().__init__()
+
+         self.target_area = base_height * base_width
+
+     def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
+         edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
+         return (edges[:-1] + edges[1:]) / 2
+
+     def _get_positions(
+         self,
+         num_frames: int,
+         height: int,
+         width: int,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> torch.Tensor:
+         scale = (self.target_area / (height * width)) ** 0.5
+
+         t = torch.arange(num_frames, device=device, dtype=dtype)
+         h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
+         w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
+
+         grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
+
+         positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
+         return positions
+
+     def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
+         with torch.autocast(freqs.device.type, torch.float32):
+             # Always run ROPE freqs computation in FP32
+             freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
+
+         freqs_cos = torch.cos(freqs)
+         freqs_sin = torch.sin(freqs)
+         return freqs_cos, freqs_sin
+
+     def forward(
+         self,
+         pos_frequencies: torch.Tensor,
+         num_frames: int,
+         height: int,
+         width: int,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         pos = self._get_positions(num_frames, height, width, device, dtype)
+         rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
+         return rope_cos, rope_sin
+
+
+ @maybe_allow_in_graph
+ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+     r"""
+     A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+     Args:
+         patch_size (`int`, defaults to `2`):
+             The size of the patches to use in the patch embedding layer.
+         num_attention_heads (`int`, defaults to `24`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of channels in each head.
+         num_layers (`int`, defaults to `48`):
+             The number of layers of Transformer blocks to use.
+         in_channels (`int`, defaults to `12`):
+             The number of channels in the input.
+         out_channels (`int`, *optional*, defaults to `None`):
+             The number of channels in the output.
+         qk_norm (`str`, defaults to `"rms_norm"`):
+             The normalization layer to use.
+         text_embed_dim (`int`, defaults to `4096`):
+             Input dimension of text embeddings from the text encoder.
+         time_embed_dim (`int`, defaults to `256`):
+             Output dimension of timestep embeddings.
+         activation_fn (`str`, defaults to `"swiglu"`):
+             Activation function to use in feed-forward.
+         max_sequence_length (`int`, defaults to `256`):
+             The maximum sequence length of text embeddings supported.
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["MochiTransformerBlock"]
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 2,
+         num_attention_heads: int = 24,
+         attention_head_dim: int = 128,
+         num_layers: int = 48,
+         pooled_projection_dim: int = 1536,
+         in_channels: int = 12,
+         out_channels: Optional[int] = None,
+         qk_norm: str = "rms_norm",
+         text_embed_dim: int = 4096,
+         time_embed_dim: int = 256,
+         activation_fn: str = "swiglu",
+         max_sequence_length: int = 256,
+     ) -> None:
+         super().__init__()
+
+         inner_dim = num_attention_heads * attention_head_dim
+         out_channels = out_channels or in_channels
+
+         self.patch_embed = PatchEmbed(
+             patch_size=patch_size,
+             in_channels=in_channels,
+             embed_dim=inner_dim,
+             pos_embed_type=None,
+         )
+
+         self.time_embed = MochiCombinedTimestepCaptionEmbedding(
+             embedding_dim=inner_dim,
+             pooled_projection_dim=pooled_projection_dim,
+             text_embed_dim=text_embed_dim,
+             time_embed_dim=time_embed_dim,
+             num_attention_heads=8,
+         )
+
+         self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
+         self.rope = MochiRoPE()
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 MochiTransformerBlock(
+                     dim=inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                     pooled_projection_dim=pooled_projection_dim,
+                     qk_norm=qk_norm,
+                     activation_fn=activation_fn,
+                     context_pre_only=i == num_layers - 1,
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+
+         self.norm_out = AdaLayerNormContinuous(
+             inner_dim,
+             inner_dim,
+             elementwise_affine=False,
+             eps=1e-6,
+             norm_type="layer_norm",
+         )
+         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+         self.gradient_checkpointing = False
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         encoder_attention_mask: torch.Tensor,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> torch.Tensor:
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p = self.config.patch_size
+
+         post_patch_height = height // p
+         post_patch_width = width // p
+
+         temb, encoder_hidden_states = self.time_embed(
+             timestep,
+             encoder_hidden_states,
+             encoder_attention_mask,
+             hidden_dtype=hidden_states.dtype,
+         )
+
+         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+         hidden_states = self.patch_embed(hidden_states)
+         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
+
+         image_rotary_emb = self.rope(
+             self.pos_frequencies,
+             num_frames,
+             post_patch_height,
+             post_patch_width,
+             device=hidden_states.device,
+             dtype=torch.float32,
+         )
+
+         for i, block in enumerate(self.transformer_blocks):
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     encoder_attention_mask,
+                     image_rotary_emb,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     temb=temb,
+                     encoder_attention_mask=encoder_attention_mask,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+         hidden_states = self.norm_out(hidden_states, temb)
+         hidden_states = self.proj_out(hidden_states)
+
+         hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
+         hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
+         output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (output,)
+         return Transformer2DModelOutput(sample=output)
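
The forward pass of the new MochiTransformer3DModel above folds frames into the batch before patch embedding, flattens every frame's patches into one token sequence, and finally unpatchifies the projected tokens back into a video tensor. The sketch below replays only that shape bookkeeping on dummy tensors so the layout is easy to follow; random tensors stand in for the outputs of patch_embed and proj_out, and the dimensions are chosen purely for illustration.

import torch

# Dummy sizes for illustration only (the real checkpoint uses the defaults above).
batch_size, num_channels, num_frames, height, width = 1, 12, 3, 60, 106
p = 2                                   # config.patch_size
inner_dim = 24 * 128                    # num_attention_heads * attention_head_dim
post_patch_height, post_patch_width = height // p, width // p
tokens_per_frame = post_patch_height * post_patch_width

latents = torch.randn(batch_size, num_channels, num_frames, height, width)

# patchify: fold frames into the batch dim, then (after patch_embed) flatten all
# frames' patches into a single token sequence
x = latents.permute(0, 2, 1, 3, 4).flatten(0, 1)                        # (B*T, C, H, W)
x = torch.randn(batch_size * num_frames, tokens_per_frame, inner_dim)   # stand-in for patch_embed output
x = x.unflatten(0, (batch_size, -1)).flatten(1, 2)                      # (B, T*tokens, inner_dim)
print(x.shape)                                                          # torch.Size([1, 4770, 3072])

# unpatchify: proj_out emits p*p*out_channels values per token, which are folded
# back into a (B, C, T, H, W) video tensor
out_channels = num_channels
projected = torch.randn(batch_size, num_frames * tokens_per_frame, p * p * out_channels)
y = projected.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
y = y.permute(0, 6, 1, 2, 4, 3, 5)
output = y.reshape(batch_size, -1, num_frames, height, width)
print(output.shape)                                                     # torch.Size([1, 12, 3, 60, 106])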

diffusers/models/transformers/transformer_sd3.py
@@ -11,20 +11,25 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
-
  from typing import Any, Dict, List, Optional, Tuple, Union

  import torch
  import torch.nn as nn
+ import torch.nn.functional as F

  from ...configuration_utils import ConfigMixin, register_to_config
- from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
- from ...models.attention import JointTransformerBlock
- from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
+ from ...models.attention import FeedForward, JointTransformerBlock
+ from ...models.attention_processor import (
+     Attention,
+     AttentionProcessor,
+     FusedJointAttnProcessor2_0,
+     JointAttnProcessor2_0,
+ )
  from ...models.modeling_utils import ModelMixin
- from ...models.normalization import AdaLayerNormContinuous
+ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
  from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.torch_utils import maybe_allow_in_graph
  from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
  from ..modeling_outputs import Transformer2DModelOutput

@@ -32,7 +37,75 @@ from ..modeling_outputs import Transformer2DModelOutput
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


- class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ @maybe_allow_in_graph
+ class SD3SingleTransformerBlock(nn.Module):
+     r"""
+     A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
+
+     Reference: https://arxiv.org/abs/2403.03206
+
+     Parameters:
+         dim (`int`): The number of channels in the input and output.
+         num_attention_heads (`int`): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`): The number of channels in each head.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+     ):
+         super().__init__()
+
+         self.norm1 = AdaLayerNormZero(dim)
+
+         if hasattr(F, "scaled_dot_product_attention"):
+             processor = JointAttnProcessor2_0()
+         else:
+             raise ValueError(
+                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+             )
+
+         self.attn = Attention(
+             query_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             bias=True,
+             processor=processor,
+             eps=1e-6,
+         )
+
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+     def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+         # Attention.
+         attn_output = self.attn(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=None,
+         )
+
+         # Process attention outputs for the `hidden_states`.
+         attn_output = gate_msa.unsqueeze(1) * attn_output
+         hidden_states = hidden_states + attn_output
+
+         norm_hidden_states = self.norm2(hidden_states)
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         ff_output = self.ff(norm_hidden_states)
+         ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+         hidden_states = hidden_states + ff_output
+
+         return hidden_states
+
+
+ class SD3Transformer2DModel(
+     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+ ):
      """
      The Transformer model introduced in Stable Diffusion 3.

@@ -268,6 +341,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
          block_controlnet_hidden_states: List = None,
          joint_attention_kwargs: Optional[Dict[str, Any]] = None,
          return_dict: bool = True,
+         skip_layers: Optional[List[int]] = None,
      ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
          """
          The [`SD3Transformer2DModel`] forward method.
@@ -277,11 +351,11 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                  Input `hidden_states`.
              encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                  Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                 from the embeddings of input conditions.
-             timestep ( `torch.LongTensor`):
+             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+                 Embeddings projected from the embeddings of input conditions.
+             timestep (`torch.LongTensor`):
                  Used to indicate denoising step.
-             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+             block_controlnet_hidden_states (`list` of `torch.Tensor`):
                  A list of tensors that if specified are added to the residuals of transformer blocks.
              joint_attention_kwargs (`dict`, *optional*):
                  A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +364,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
              return_dict (`bool`, *optional*, defaults to `True`):
                  Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                  tuple.
+             skip_layers (`list` of `int`, *optional*):
+                 A list of layer indices to skip during the forward pass.

          Returns:
              If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -316,8 +392,17 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
          temb = self.time_text_embed(timestep, pooled_projections)
          encoder_hidden_states = self.context_embedder(encoder_hidden_states)

+         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+             ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
+
+             joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
+
          for index_block, block in enumerate(self.transformer_blocks):
-             if self.training and self.gradient_checkpointing:
+             # Skip specified layers
+             is_skip = True if skip_layers is not None and index_block in skip_layers else False
+
+             if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

                  def create_custom_forward(module, return_dict=None):
                      def custom_forward(*inputs):
@@ -334,18 +419,21 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                      hidden_states,
                      encoder_hidden_states,
                      temb,
+                     joint_attention_kwargs,
                      **ckpt_kwargs,
                  )
-
-             else:
+             elif not is_skip:
                  encoder_hidden_states, hidden_states = block(
-                     hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     temb=temb,
+                     joint_attention_kwargs=joint_attention_kwargs,
                  )

              # controlnet residual
              if block_controlnet_hidden_states is not None and block.context_pre_only is False:
-                 interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
-                 hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                 interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
+                 hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

          hidden_states = self.norm_out(hidden_states, temb)
          hidden_states = self.proj_out(hidden_states)
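
One behavioural change in the transformer_sd3.py hunks above is the ControlNet-residual indexing, which moves from integer division to float division plus int(). The quick arithmetic check below (plain Python, not diffusers code) shows why: when the number of transformer blocks is not an exact multiple of the number of ControlNet hidden states, the old integer interval can index past the end of the list, while the float interval always lands in range. The block and state counts are made up for illustration.

# Old vs. new ControlNet-residual index mapping (illustrative counts only).
num_blocks, num_controlnet_states = 24, 10

old_indices = [i // (num_blocks // num_controlnet_states) for i in range(num_blocks)]
new_indices = [int(i / (num_blocks / num_controlnet_states)) for i in range(num_blocks)]

print(max(old_indices))  # 11 -> out of range when only 10 controlnet states exist
print(max(new_indices))  # 9  -> always a valid index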

diffusers/models/transformers/transformer_temporal.py
@@ -340,7 +340,7 @@ class TransformerSpatioTemporalModel(nn.Module):

          # 2. Blocks
          for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-             if self.training and self.gradient_checkpointing:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
                  hidden_states = torch.utils.checkpoint.checkpoint(
                      block,
                      hidden_states,
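
Both transformer_sd3.py and transformer_temporal.py above also swap `self.training` for `torch.is_grad_enabled()` in the gradient-checkpointing guard. The two flags are independent: the first tracks train/eval mode, the second tracks the autograd state, so after this change checkpointing engages whenever gradients are being recorded (even in eval mode) and is skipped under `torch.no_grad()` even in train mode. The snippet below simply prints the two flags in those situations; it is an illustration, not diffusers code.

import torch
import torch.nn as nn

# The two conditions the diff swaps: module.training vs. torch.is_grad_enabled().
block = nn.Linear(8, 8)

block.eval()
print(block.training, torch.is_grad_enabled())      # False True  -> old guard skips, new guard fires

block.train()
with torch.no_grad():
    print(block.training, torch.is_grad_enabled())  # True False -> old guard fires, new guard skips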