diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0 (new quantization backend; see the sketch after this list)
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
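Items 131-137 above introduce a new diffusers/quantizers package with a bitsandbytes backend. A minimal sketch of how the new quantization path is typically wired up, assuming the 0.31.0 top-level export of BitsAndBytesConfig and an installed bitsandbytes; the checkpoint id is illustrative:

```python
# Sketch: loading a model through the new diffusers.quantizers path (0.31.0).
# Assumes `bitsandbytes` is installed and `BitsAndBytesConfig` is exported at the
# top level; the checkpoint id is illustrative.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```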
@@ -1,4 +1,4 @@
- # Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -13,8 +13,9 @@
  # limitations under the License.


- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, Optional, Tuple, Union

+ import numpy as np
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -22,52 +23,23 @@ import torch.nn.functional as F
  from ...configuration_utils import ConfigMixin, register_to_config
  from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
  from ...models.attention import FeedForward
- from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+ from ...models.attention_processor import (
+ Attention,
+ AttentionProcessor,
+ FluxAttnProcessor2_0,
+ FusedFluxAttnProcessor2_0,
+ )
  from ...models.modeling_utils import ModelMixin
  from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
  from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import maybe_allow_in_graph
- from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+ from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
  from ..modeling_outputs import Transformer2DModelOutput


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name


- # YiYi to-do: refactor rope related functions/classes
- def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
- assert dim % 2 == 0, "The dimension must be even."
-
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
- omega = 1.0 / (theta**scale)
-
- batch_size, seq_length = pos.shape
- out = torch.einsum("...n,d->...nd", pos, omega)
- cos_out = torch.cos(out)
- sin_out = torch.sin(out)
-
- stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
- out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
- return out.float()
-
-
- # YiYi to-do: refactor rope related functions/classes
- class EmbedND(nn.Module):
- def __init__(self, dim: int, theta: int, axes_dim: List[int]):
- super().__init__()
- self.dim = dim
- self.theta = theta
- self.axes_dim = axes_dim
-
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
- n_axes = ids.shape[-1]
- emb = torch.cat(
- [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
- dim=-3,
- )
- return emb.unsqueeze(1)
-
-
  @maybe_allow_in_graph
  class FluxSingleTransformerBlock(nn.Module):
  r"""
@@ -92,7 +64,7 @@ class FluxSingleTransformerBlock(nn.Module):
  self.act_mlp = nn.GELU(approximate="tanh")
  self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

- processor = FluxSingleAttnProcessor2_0()
+ processor = FluxAttnProcessor2_0()
  self.attn = Attention(
  query_dim=dim,
  cross_attention_dim=None,
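The hunks above (apparently from diffusers/models/transformers/transformer_flux.py) remove the module-level rope() helper, the EmbedND class, and the single-block use of FluxSingleAttnProcessor2_0; rotary embeddings now come from FluxPosEmbed in diffusers.models.embeddings, and both block types share FluxAttnProcessor2_0. A rough migration sketch for code that imported the removed helpers; the id tensor is illustrative:

```python
# Sketch: replacements for the helpers removed from transformer_flux.py.
# FluxPosEmbed takes the same (theta, axes_dim) arguments that EmbedND took.
import torch
from diffusers.models.embeddings import FluxPosEmbed
from diffusers.models.attention_processor import FluxAttnProcessor2_0

pos_embed = FluxPosEmbed(theta=10000, axes_dim=(16, 56, 56))
ids = torch.zeros(4096, 3)           # packed (text + image) position ids, now 2D
image_rotary_emb = pos_embed(ids)    # rotary embedding consumed by the attention processors
processor = FluxAttnProcessor2_0()   # also used where FluxSingleAttnProcessor2_0 was before
```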
@@ -111,14 +83,16 @@ class FluxSingleTransformerBlock(nn.Module):
  hidden_states: torch.FloatTensor,
  temb: torch.FloatTensor,
  image_rotary_emb=None,
+ joint_attention_kwargs=None,
  ):
  residual = hidden_states
  norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
  mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+ joint_attention_kwargs = joint_attention_kwargs or {}
  attn_output = self.attn(
  hidden_states=norm_hidden_states,
  image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
  )

  hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -189,18 +163,20 @@ class FluxTransformerBlock(nn.Module):
  encoder_hidden_states: torch.FloatTensor,
  temb: torch.FloatTensor,
  image_rotary_emb=None,
+ joint_attention_kwargs=None,
  ):
  norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

  norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
  encoder_hidden_states, emb=temb
  )
-
+ joint_attention_kwargs = joint_attention_kwargs or {}
  # Attention.
  attn_output, context_attn_output = self.attn(
  hidden_states=norm_hidden_states,
  encoder_hidden_states=norm_encoder_hidden_states,
  image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
  )

  # Process attention outputs for the `hidden_states`.
@@ -250,6 +226,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  """

  _supports_gradient_checkpointing = True
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

  @register_to_config
  def __init__(
@@ -263,13 +240,14 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  joint_attention_dim: int = 4096,
  pooled_projection_dim: int = 768,
  guidance_embeds: bool = False,
- axes_dims_rope: List[int] = [16, 56, 56],
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
  ):
  super().__init__()
  self.out_channels = in_channels
  self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

- self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
  text_time_guidance_cls = (
  CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
  )
@@ -307,6 +285,106 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig

  self.gradient_checkpointing = False

+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
  def _set_gradient_checkpointing(self, module, value=False):
  if hasattr(module, "gradient_checkpointing"):
  module.gradient_checkpointing = value
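The added block above copies the attn_processors property, set_attn_processor, and the experimental fused-QKV helpers from UNet2DConditionModel onto FluxTransformer2DModel. A usage sketch; the checkpoint id is illustrative:

```python
# Sketch: the experimental fused-QKV helpers now available on FluxTransformer2DModel.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.fuse_qkv_projections()            # all Attention layers switch to FusedFluxAttnProcessor2_0
print(list(transformer.attn_processors)[:1])  # processors are keyed by module path
transformer.unfuse_qkv_projections()          # restores the original processors
```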
@@ -321,7 +399,10 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  txt_ids: torch.Tensor = None,
  guidance: torch.Tensor = None,
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
  return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
  ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
  """
  The [`FluxTransformer2DModel`] forward method.
@@ -377,7 +458,20 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  )
  encoder_hidden_states = self.context_embedder(encoder_hidden_states)

- ids = torch.cat((txt_ids, img_ids), dim=1)
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
  image_rotary_emb = self.pos_embed(ids)

  for index_block, block in enumerate(self.transformer_blocks):
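The hunk above changes the expected shape of txt_ids and img_ids from (batch, seq, 3) to (seq, 3); 3D tensors still work but emit a deprecation warning and are reduced to their first batch entry. A sketch of how callers that build their own id tensors adapt; the sizes are illustrative:

```python
# Sketch: building 2D position ids for FluxTransformer2DModel.forward (0.31.0 layout).
# Axis 0 is left at zero; axes 1 and 2 carry the latent row/column index, as in the Flux pipelines.
import torch

text_seq_len, latent_h, latent_w = 512, 64, 64
txt_ids = torch.zeros(text_seq_len, 3)             # was torch.zeros(batch, text_seq_len, 3)
img_ids = torch.zeros(latent_h, latent_w, 3)
img_ids[..., 1] = torch.arange(latent_h)[:, None]  # row index of each latent token
img_ids[..., 2] = torch.arange(latent_w)[None, :]  # column index of each latent token
img_ids = img_ids.reshape(-1, 3)                   # (latent_h * latent_w, 3), no batch dimension
```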
@@ -408,8 +502,21 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  encoder_hidden_states=encoder_hidden_states,
  temb=temb,
  image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
  )

+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
  hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

  for index_block, block in enumerate(self.single_transformer_blocks):
@@ -438,6 +545,16 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
  hidden_states=hidden_states,
  temb=temb,
  image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
  )

  hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
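The two "controlnet residual" hunks above let FluxTransformer2DModel consume residuals from the new FluxControlNetModel (controlnet_flux.py), which the new pipeline_flux_controlnet.py pipelines feed in. An end-to-end sketch; the class names match the new modules, while the checkpoint ids and image URL are illustrative:

```python
# Sketch: the new Flux ControlNet path added in 0.31.0.
# Checkpoint ids and the conditioning-image URL are illustrative placeholders.
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
control_image = load_image("https://example.com/canny.png")  # a Canny edge map
image = pipe(
    "a cat sitting on a windowsill",
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
).images[0]
```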
@@ -13,7 +13,7 @@
  # limitations under the License.


- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

  import torch
  import torch.nn as nn
@@ -69,6 +69,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
  pooled_projection_dim: int = 2048,
  out_channels: int = 16,
  pos_embed_max_size: int = 96,
+ dual_attention_layers: Tuple[
+ int, ...
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+ qk_norm: Optional[str] = None,
  ):
  super().__init__()
  default_out_channels = in_channels
@@ -97,6 +101,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
  num_attention_heads=self.config.num_attention_heads,
  attention_head_dim=self.config.attention_head_dim,
  context_pre_only=i == num_layers - 1,
+ qk_norm=qk_norm,
+ use_dual_attention=True if i in dual_attention_layers else False,
  )
  for i in range(self.config.num_layers)
  ]
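The SD3Transformer2DModel hunks above add two config fields, dual_attention_layers and qk_norm; per the inline comment, SD3.5-style checkpoints enable dual attention on layers 0-12 while SD3.0 keeps the defaults. A minimal construction sketch with deliberately small, illustrative sizes:

```python
# Sketch: the two new SD3Transformer2DModel config fields (0.31.0).
# Sizes are illustrative toy values, not a real SD3.5 configuration.
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    num_layers=4,
    attention_head_dim=64,
    num_attention_heads=4,
    joint_attention_dim=256,
    caption_projection_dim=256,   # kept equal to num_attention_heads * attention_head_dim
    pooled_projection_dim=256,
    dual_attention_layers=(0, 1), # () preserves the 0.30.x / SD3.0 behavior
    qk_norm="rms_norm",           # None preserves the 0.30.x behavior
)
```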
@@ -463,7 +463,6 @@ class UNet2DConditionModel(
  dropout=dropout,
  )
  self.up_blocks.append(up_block)
- prev_output_channel = output_channel

  # out
  if norm_num_groups is not None:
@@ -599,7 +598,7 @@
  )
  elif encoder_hid_dim_type is not None:
  raise ValueError(
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
  )
  else:
  self.encoder_hid_proj = None
@@ -679,7 +678,9 @@
  # Kandinsky 2.2 ControlNet
  self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
  elif addition_embed_type is not None:
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+ raise ValueError(
+ f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
+ )

  def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
  if attention_type in ["gated", "gated-text-image"]:
@@ -990,7 +991,7 @@
  image_embs = added_cond_kwargs.get("image_embeds")
  aug_emb = self.add_embedding(image_embs)
  elif self.config.addition_embed_type == "image_hint":
- # Kandinsky 2.2 - style
+ # Kandinsky 2.2 ControlNet - style
  if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
  raise ValueError(
  f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
@@ -1009,7 +1010,7 @@
  # Kandinsky 2.1 - style
  if "image_embeds" not in added_cond_kwargs:
  raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
  )

  image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1018,14 +1019,14 @@
  # Kandinsky 2.2 - style
  if "image_embeds" not in added_cond_kwargs:
  raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
  )
  image_embeds = added_cond_kwargs.get("image_embeds")
  encoder_hidden_states = self.encoder_hid_proj(image_embeds)
  elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
  if "image_embeds" not in added_cond_kwargs:
  raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
  )

  if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
@@ -1140,7 +1141,6 @@
  # 1. time
  t_emb = self.get_time_embed(sample=sample, timestep=timestep)
  emb = self.time_embedding(t_emb, timestep_cond)
- aug_emb = None

  class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
  if class_emb is not None:
@@ -116,7 +116,7 @@ class AnimateDiffTransformer3D(nn.Module):

  self.in_channels = in_channels

- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
  self.proj_in = nn.Linear(in_channels, inner_dim)

  # 3. Define transformers blocks
@@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module):
  hidden_states = self.norm(hidden_states)
  hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

- hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.proj_in(input=hidden_states)

  # 2. Blocks
  for block in self.transformer_blocks:
  hidden_states = block(
- hidden_states,
+ hidden_states=hidden_states,
  encoder_hidden_states=encoder_hidden_states,
  timestep=timestep,
  cross_attention_kwargs=cross_attention_kwargs,
@@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module):
  )

  # 3. Output
- hidden_states = self.proj_out(hidden_states)
+ hidden_states = self.proj_out(input=hidden_states)
  hidden_states = (
  hidden_states[None, None, :]
  .reshape(batch_size, height, width, num_frames, channel)
@@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module):
  )

  else:
- hidden_states = resnet(hidden_states, temb)
+ hidden_states = resnet(input_tensor=hidden_states, temb=temb)

  hidden_states = motion_module(hidden_states, num_frames=num_frames)

@@ -352,7 +352,7 @@

  if self.downsamplers is not None:
  for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
+ hidden_states = downsampler(hidden_states=hidden_states)

  output_states = output_states + (hidden_states,)

@@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module):
  temb,
  **ckpt_kwargs,
  )
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
  else:
- hidden_states = resnet(hidden_states, temb)
+ hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+ hidden_states = attn(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]

- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
  hidden_states = motion_module(
  hidden_states,
  num_frames=num_frames,
@@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module):

  if self.downsamplers is not None:
  for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
+ hidden_states = downsampler(hidden_states=hidden_states)

  output_states = output_states + (hidden_states,)

@@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module):
  temb,
  **ckpt_kwargs,
  )
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
  else:
- hidden_states = resnet(hidden_states, temb)
+ hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+ hidden_states = attn(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]

- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
  hidden_states = motion_module(
  hidden_states,
  num_frames=num_frames,
@@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module):

  if self.upsamplers is not None:
  for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size)
+ hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

  return hidden_states

@@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module):
  create_custom_forward(resnet), hidden_states, temb
  )
  else:
- hidden_states = resnet(hidden_states, temb)
+ hidden_states = resnet(input_tensor=hidden_states, temb=temb)

  hidden_states = motion_module(hidden_states, num_frames=num_frames)

  if self.upsamplers is not None:
  for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size)
+ hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

  return hidden_states

@@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
  if cross_attention_kwargs.get("scale", None) is not None:
  logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

- hidden_states = self.resnets[0](hidden_states, temb)
+ hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)

  blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
  for attn, resnet, motion_module in blocks:
+ hidden_states = attn(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
  if self.training and self.gradient_checkpointing:

  def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@
  return custom_forward

  ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
  hidden_states = torch.utils.checkpoint.checkpoint(
  create_custom_forward(motion_module),
  hidden_states,
@@ -1117,19 +1104,11 @@
  **ckpt_kwargs,
  )
  else:
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
  hidden_states = motion_module(
  hidden_states,
  num_frames=num_frames,
  )
- hidden_states = resnet(hidden_states, temb)
+ hidden_states = resnet(input_tensor=hidden_states, temb=temb)

  return hidden_states

@@ -2178,7 +2157,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft

  emb = emb if aug_emb is None else emb + aug_emb
  emb = emb.repeat_interleave(repeats=num_frames, dim=0)
- encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

  if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
  if "image_embeds" not in added_cond_kwargs:
@@ -19,6 +19,7 @@ import torch.nn as nn
  import torch.nn.functional as F

  from ..utils import deprecate
+ from ..utils.import_utils import is_torch_version
  from .normalization import RMSNorm


@@ -151,11 +152,10 @@ class Upsample2D(nn.Module):
  if self.use_conv_transpose:
  return self.conv(hidden_states)

- # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
- # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
- # https://github.com/pytorch/pytorch/issues/86679
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
+ # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
  dtype = hidden_states.dtype
- if dtype == torch.bfloat16:
+ if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
  hidden_states = hidden_states.to(torch.float32)

  # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -170,8 +170,8 @@
  else:
  hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

- # If the input is bfloat16, we cast back to bfloat16
- if dtype == torch.bfloat16:
+ # Cast back to original dtype
+ if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
  hidden_states = hidden_states.to(dtype)

  # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed