diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
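
Beyond the attention-module hunks excerpted below, the largest additions in this release are the new diffusers/quantizers package (entries 130-136) and the new Flux ControlNet/img2img/inpaint, CogView3, and CogVideoX fun-control pipelines. A hedged sketch of how the bitsandbytes-backed quantization introduced by those quantizer files is typically wired up (the repo id and 4-bit settings are illustrative assumptions, not taken from this diff):

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# assumes a CUDA machine with bitsandbytes installed; repo id and settings are illustrative
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
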
diffusers/models/attention.py

@@ -22,7 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
 logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
     processing of `context` conditions.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        context_pre_only: bool = False,
+        qk_norm: Optional[str] = None,
+        use_dual_attention: bool = False,
+    ):
         super().__init__()
 
+        self.use_dual_attention = use_dual_attention
         self.context_pre_only = context_pre_only
         context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
 
-        self.norm1 = AdaLayerNormZero(dim)
+        if use_dual_attention:
+            self.norm1 = SD35AdaLayerNormZeroX(dim)
+        else:
+            self.norm1 = AdaLayerNormZero(dim)
 
         if context_norm_type == "ada_norm_continous":
             self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ class JointTransformerBlock(nn.Module):
             raise ValueError(
                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
             )
+
         if hasattr(F, "scaled_dot_product_attention"):
             processor = JointAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
             )
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -134,8 +148,25 @@ class JointTransformerBlock(nn.Module):
             context_pre_only=context_pre_only,
             bias=True,
             processor=processor,
+            qk_norm=qk_norm,
+            eps=1e-6,
         )
 
+        if use_dual_attention:
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=None,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                out_dim=dim,
+                bias=True,
+                processor=processor,
+                qk_norm=qk_norm,
+                eps=1e-6,
+            )
+        else:
+            self.attn2 = None
+
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
@@ -159,7 +190,12 @@ class JointTransformerBlock(nn.Module):
     def forward(
         self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        if self.use_dual_attention:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+                hidden_states, emb=temb
+            )
+        else:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         if self.context_pre_only:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -177,6 +213,11 @@ class JointTransformerBlock(nn.Module):
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        if self.use_dual_attention:
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+            hidden_states = hidden_states + attn_output2
+
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
         if self._chunk_size is not None:
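
The dual-attention path added above mirrors SD3.5-style blocks: SD35AdaLayerNormZeroX emits a second set of normalized hidden states and a second gate, and attn2 contributes an extra gated residual on top of the joint-attention output. A minimal standalone sketch of that residual composition, with made-up toy shapes and random tensors standing in for the real attention outputs:

import torch

# toy shapes only: batch=2, seq=8, dim=16 (assumed for illustration)
hidden_states = torch.randn(2, 8, 16)
attn_output = torch.randn(2, 8, 16)    # stand-in for self.attn(...)
attn_output2 = torch.randn(2, 8, 16)   # stand-in for self.attn2(...)
gate_msa = torch.randn(2, 16)
gate_msa2 = torch.randn(2, 16)

# first gated residual (always applied)
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output
# second gated residual, only taken when use_dual_attention is True
hidden_states = hidden_states + gate_msa2.unsqueeze(1) * attn_output2
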
@@ -972,15 +1013,32 @@ class FreeNoiseTransformerBlock(nn.Module):
         return frame_indices
 
     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
-        if weighting_scheme == "pyramid":
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
             if num_frames % 2 == 0:
                 # num_frames = 4 => [1, 2, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
                 weights = weights + weights[::-1]
             else:
                 # num_frames = 5 => [1, 2, 3, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
-                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * mid
+                weights = weights + list(range(mid, 0, -1))
         else:
             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
 
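
For reference, a standalone re-computation of the "pyramid" branch above (pyramid_weights is a helper name introduced here for illustration; in the library this logic lives on FreeNoiseTransformerBlock._get_frame_weights):

def pyramid_weights(num_frames: int) -> list:
    # ramp up to the middle frame, then back down, matching the inline comments above
    if num_frames % 2 == 0:
        mid = num_frames // 2
        half = list(range(1, mid + 1))
        return half + half[::-1]
    mid = (num_frames + 1) // 2
    half = list(range(1, mid))
    return half + [mid] + half[::-1]

print(pyramid_weights(4))  # [1, 2, 2, 1]
print(pyramid_weights(5))  # [1, 2, 3, 2, 1]
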
@@ -1087,8 +1145,26 @@ class FreeNoiseTransformerBlock(nn.Module):
             accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
             num_times_accumulated[:, frame_start:frame_end] += weights
 
-        hidden_states = torch.where(
-            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        # TODO(aryan): Maybe this could be done in a better way.
+        #
+        # Previously, this was:
+        #     hidden_states = torch.where(
+        #         num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        #     )
+        #
+        # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
+        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
+        # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
+        # looked into this deeply because other memory optimizations led to more pronounced reductions.
+        hidden_states = torch.cat(
+            [
+                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                for accumulated_split, num_times_split in zip(
+                    accumulated_values.split(self.context_length, dim=1),
+                    num_times_accumulated.split(self.context_length, dim=1),
+                )
+            ],
+            dim=1,
         ).to(dtype)
 
         # 3. Feed-forward
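
The comment explains the trade-off: one torch.where over the full frame axis is replaced by the same computation applied per context-length chunk and concatenated back, trading a single large temporary for several smaller ones. A small standalone sketch of the equivalence (shapes and context_length are arbitrary toy values):

import torch

accumulated = torch.randn(1, 16, 8)
counts = torch.randint(0, 3, (1, 16, 8)).float()
context_length = 4  # toy value

# single torch.where over the whole frame axis
full = torch.where(counts > 0, accumulated / counts, accumulated)

# same result, computed chunk by chunk and concatenated back together
chunked = torch.cat(
    [
        torch.where(c > 0, a / c, a)
        for a, c in zip(accumulated.split(context_length, dim=1), counts.split(context_length, dim=1))
    ],
    dim=1,
)
assert torch.allclose(full, chunked)
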
diffusers/models/attention_processor.py

@@ -122,6 +122,7 @@ class Attention(nn.Module):
         out_dim: int = None,
         context_pre_only=None,
         pre_only=False,
+        elementwise_affine: bool = True,
     ):
         super().__init__()
 
@@ -179,8 +180,8 @@ class Attention(nn.Module):
             self.norm_q = None
             self.norm_k = None
         elif qk_norm == "layer_norm":
-            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
-            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         elif qk_norm == "fp32_layer_norm":
             self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
@@ -192,7 +193,7 @@ class Attention(nn.Module):
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
         else:
-            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
 
         if cross_attention_norm is None:
             self.norm_cross = None
@@ -249,6 +250,10 @@ class Attention(nn.Module):
             elif qk_norm == "rms_norm":
                 self.norm_added_q = RMSNorm(dim_head, eps=eps)
                 self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                raise ValueError(
+                    f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+                )
         else:
             self.norm_added_q = None
             self.norm_added_k = None
@@ -1049,61 +1054,72 @@ class JointAttnProcessor2_0:
     ) -> torch.FloatTensor:
         residual = hidden_states
 
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size = hidden_states.shape[0]
 
         # `sample` projections.
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        # attention
-        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
-        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
-        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
+
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # Split the attention outputs.
-        hidden_states, encoder_hidden_states = (
-            hidden_states[:, : residual.shape[1]],
-            hidden_states[:, residual.shape[1] :],
-        )
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
-        if not attn.context_pre_only:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
-        return hidden_states, encoder_hidden_states
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
 class PAGJointAttnProcessor2_0:
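
The reworked JointAttnProcessor2_0 now tolerates encoder_hidden_states=None and returns a single tensor instead of a (sample, context) pair; this is what lets the new attn2 branch in JointTransformerBlock call it with hidden states only. A minimal sketch of that calling convention, with toy sizes that are assumptions for illustration rather than any real model configuration:

import torch
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0

# toy configuration: 64-dim queries, 4 heads of size 16
attn = Attention(query_dim=64, heads=4, dim_head=16, bias=True, processor=JointAttnProcessor2_0())
x = torch.randn(2, 8, 64)

out = attn(x)        # no context tensor: a single (2, 8, 64) tensor comes back
print(out.shape)
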
@@ -1695,52 +1711,32 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states
 
 
-# YiYi to-do: refactor rope related functions/classes
-def apply_rope(xq, xk, freqs_cis):
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
-
-class FluxSingleAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
+class FluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
+    ) -> torch.FloatTensor:
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
 
+        # `sample` projections.
         query = attn.to_q(hidden_states)
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
@@ -1749,33 +1745,68 @@ class FluxSingleAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
-        # Apply RoPE if needed
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
         if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
 
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
 
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
 
-        return hidden_states
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
-class FluxAttnProcessor2_0:
+class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
 
     def __call__(
         self,
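
Taken together, the Flux hunks fold the former FluxSingleAttnProcessor2_0 into FluxAttnProcessor2_0: one processor now serves both the single-stream block (no encoder_hidden_states, single tensor returned) and the joint block (context passed, a pair returned), and the module-level apply_rope helper is replaced by the shared apply_rotary_emb from embeddings. A hedged sketch of the two calling shapes, with toy sizes, no rotary embeddings, and added_kv_proj_dim/context_pre_only settings that are illustrative assumptions:

import torch
from diffusers.models.attention_processor import Attention, FluxAttnProcessor2_0

attn = Attention(
    query_dim=64, heads=4, dim_head=16, bias=True,
    added_kv_proj_dim=64, context_pre_only=False,
    processor=FluxAttnProcessor2_0(),
)
x = torch.randn(2, 8, 64)          # image-stream tokens
context = torch.randn(2, 4, 64)    # text-stream tokens

single = attn(x)                                            # single-stream path: one tensor
img_out, txt_out = attn(x, encoder_hidden_states=context)   # joint path: a (sample, context) pair
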
@@ -1785,21 +1816,12 @@ class FluxAttnProcessor2_0:
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
 
         # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1813,59 +1835,62 @@ class FluxAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
         # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
 
-        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-
-        if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-        if attn.norm_added_k is not None:
-            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
 
-        # attention
-        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
         if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        encoder_hidden_states, hidden_states = (
-            hidden_states[:, : encoder_hidden_states.shape[1]],
-            hidden_states[:, encoder_hidden_states.shape[1] :],
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
 
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
-        return hidden_states, encoder_hidden_states
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
 class CogVideoXAttnProcessor2_0:
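
The fused variant relies on projections that were merged ahead of time (to_qkv / to_added_qkv, created by Attention.fuse_projections) and then split three ways along the channel dimension. A standalone sketch of that split, with a plain nn.Linear standing in for the fused projection and toy sizes only:

import torch
import torch.nn as nn

hidden = torch.randn(2, 8, 32)
to_qkv = nn.Linear(32, 3 * 32)        # stand-in for the fused attn.to_qkv projection
qkv = to_qkv(hidden)                  # (2, 8, 96): one matmul instead of three
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
print(q.shape, k.shape, v.shape)      # each (2, 8, 32)
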
@@ -4247,6 +4272,17 @@ class LoRAAttnAddedKVProcessor:
     pass
 
 
+class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+        super().__init__()
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
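
With this shim, existing imports keep working: instantiating the old name warns at construction time and behaves exactly like the merged processor.

from diffusers.models.attention_processor import FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0

proc = FluxSingleAttnProcessor2_0()   # emits a deprecation warning (removal planned for 0.32.0)
assert isinstance(proc, FluxAttnProcessor2_0)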