diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0 (see the GGUF loading sketch after this list)
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
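Beyond the file-level churn, two headline additions in this list are the GGUF and torchao quantization backends (items 172–179). A hedged sketch of the GGUF path: `GGUFQuantizationConfig` and the `from_single_file` support come from this diff, while the checkpoint repo id below is an assumption taken from a community GGUF export, not part of the package itself.

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Assumed community GGUF export of FLUX.1-dev; any GGUF single-file checkpoint should work.
ckpt_url = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_0.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
```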
diffusers/loaders/transformer_flux.py ADDED
@@ -0,0 +1,181 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+
+from ..models.embeddings import (
+    ImageProjection,
+    MultiIPAdapterImageProjection,
+)
+from ..models.modeling_utils import load_model_dict_into_meta
+from ..utils import (
+    is_accelerate_available,
+    is_torch_version,
+    logging,
+)
+
+
+if is_accelerate_available():
+    pass
+
+logger = logging.get_logger(__name__)
+
+
+class FluxTransformer2DLoadersMixin:
+    """
+    Load layers into a [`FluxTransformer2DModel`].
+    """
+
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        updated_state_dict = {}
+        image_projection = None
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            if state_dict["proj.weight"].shape[0] == 65536:
+                num_image_text_embeds = 16
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
+
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict, strict=True)
+        else:
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
+        return image_projection
+
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+        from ..models.attention_processor import (
+            FluxIPAdapterJointAttnProcessor2_0,
+        )
+
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        # set ip-adapter cross-attention processors & load state_dict
+        attn_procs = {}
+        key_id = 0
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+        for name in self.attn_processors.keys():
+            if name.startswith("single_transformer_blocks"):
+                attn_processor_class = self.attn_processors[name].__class__
+                attn_procs[name] = attn_processor_class()
+            else:
+                cross_attention_dim = self.config.joint_attention_dim
+                hidden_size = self.inner_dim
+                attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+                num_image_text_embeds = []
+                for state_dict in state_dicts:
+                    if "proj.weight" in state_dict["image_proj"]:
+                        num_image_text_embed = 4
+                        if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
+                            num_image_text_embed = 16
+                        # IP-Adapter
+                        num_image_text_embeds += [num_image_text_embed]
+
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                        dtype=self.dtype,
+                        device=self.device,
+                    )
+
+                value_dict = {}
+                for i, state_dict in enumerate(state_dicts):
+                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+                    value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
+                    value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
+
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = self.device
+                    dtype = self.dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
+                key_id += 1
+
+        return attn_procs
+
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+        if not isinstance(state_dicts, list):
+            state_dicts = [state_dicts]
+
+        self.encoder_hid_proj = None
+
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        self.set_attn_processor(attn_procs)
+
+        image_projection_layers = []
+        for state_dict in state_dicts:
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
+            image_projection_layers.append(image_projection_layer)
+
+        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+        self.config.encoder_hid_dim_type = "ip_image_proj"
+
+        self.to(dtype=self.dtype, device=self.device)
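The mixin above is the model-side half of the new Flux IP-Adapter support; the pipeline-side entry point lives in loaders/ip_adapter.py (+538 in this diff). A hedged usage sketch: the repo id, weight name, and image-encoder argument follow the community XLabs adapter release and are assumptions, not part of this hunk.

```python
import torch

from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Assumed repo/weight names from the XLabs Flux IP-Adapter release; the call ends
# up in FluxTransformer2DLoadersMixin._load_ip_adapter_weights above.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)

style_image = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(
    prompt="a cat, in the style of the reference image",
    ip_adapter_image=style_image,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
```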
diffusers/loaders/transformer_sd3.py ADDED
@@ -0,0 +1,89 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+
+from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
+from ..models.embeddings import IPAdapterTimeImageProjection
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+
+
+class SD3Transformer2DLoadersMixin:
+    """Load IP-Adapters and LoRA layers into a [`SD3Transformer2DModel`]."""
+
+    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+        Args:
+            state_dict (`Dict`):
+                State dict with keys "ip_adapter", which contains parameters for attention processors, and
+                "image_proj", which contains parameters for the image projection net.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting
+                this argument to `True` will raise an error.
+        """
+        # IP-Adapter cross attention parameters
+        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
+        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
+        timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
+
+        # Dict where key is transformer layer index, value is attention processor's state dict
+        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
+        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
+        for key, weights in state_dict["ip_adapter"].items():
+            idx, name = key.split(".", maxsplit=1)
+            layer_state_dict[int(idx)][name] = weights
+
+        # Create IP-Adapter attention processors
+        attn_procs = {}
+        for idx, name in enumerate(self.attn_processors.keys()):
+            attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+                hidden_size=hidden_size,
+                ip_hidden_states_dim=ip_hidden_states_dim,
+                head_dim=self.config.attention_head_dim,
+                timesteps_emb_dim=timesteps_emb_dim,
+            ).to(self.device, dtype=self.dtype)
+
+            if not low_cpu_mem_usage:
+                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
+            else:
+                load_model_dict_into_meta(
+                    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
+                )
+
+        self.set_attn_processor(attn_procs)
+
+        # Image projection parameters
+        embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
+        output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
+        hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
+        heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
+        num_queries = state_dict["image_proj"]["latents"].shape[1]
+        timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
+
+        # Image projection
+        self.image_proj = IPAdapterTimeImageProjection(
+            embed_dim=embed_dim,
+            output_dim=output_dim,
+            hidden_dim=hidden_dim,
+            heads=heads,
+            num_queries=num_queries,
+            timestep_in_dim=timestep_in_dim,
+        ).to(device=self.device, dtype=self.dtype)
+
+        if not low_cpu_mem_usage:
+            self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+        else:
+            load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
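As with the Flux mixin, `_load_ip_adapter_weights` here is called by the public `load_ip_adapter()` that this release adds to the SD3 pipelines via loaders/ip_adapter.py. A sketch under assumptions: the InstantX repo id and the prompt are illustrative.

```python
import torch

from diffusers import StableDiffusion3Pipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

# Assumed adapter repo; its "ip_adapter" and "image_proj" state dicts are what
# the loader above consumes.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipe.set_ip_adapter_scale(0.6)

ref = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(prompt="a portrait, studio lighting", ip_adapter_image=ref).images[0]
```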
diffusers/loaders/unet.py CHANGED
@@ -36,6 +36,7 @@ from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
+    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -209,6 +210,10 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

+        if is_lora:
+            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
@@ -487,6 +492,9 @@ class UNet2DConditionLoadersMixin:
             )
             state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
         else:
+            deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
+            deprecate("save_attn_procs", "0.40.0", deprecation_message)
+
             if not USE_PEFT_BACKEND:
                 raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

@@ -765,6 +773,7 @@ class UNet2DConditionLoadersMixin:
         from ..models.attention_processor import (
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
+            IPAdapterXFormersAttnProcessor,
         )

         if low_cpu_mem_usage:
@@ -804,11 +813,15 @@
             if cross_attention_dim is None or "motion_modules" in name:
                 attn_processor_class = self.attn_processors[name].__class__
                 attn_procs[name] = attn_processor_class()
-
             else:
-                attn_processor_class = (
-                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
-                )
+                if "XFormers" in str(self.attn_processors[name].__class__):
+                    attn_processor_class = IPAdapterXFormersAttnProcessor
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
                 num_image_text_embeds = []
                 for state_dict in state_dicts:
                     if "proj.weight" in state_dict["image_proj"]:
diffusers/models/__init__.py CHANGED
@@ -27,19 +27,29 @@ _import_structure = {}
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+    _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
+    _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+    _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
-    _import_structure["controlnet"] = ["ControlNetModel"]
-    _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
-    _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
-    _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
-    _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
-    _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+    _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+    _import_structure["controlnets.controlnet_hunyuan"] = [
+        "HunyuanDiT2DControlNetModel",
+        "HunyuanDiT2DMultiControlNetModel",
+    ]
+    _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+    _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
+    _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
+    _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+    _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -51,11 +61,16 @@ if is_torch_available():
     _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+    _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -70,7 +85,7 @@ if is_torch_available():
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]

 if is_flax_available():
-    _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
+    _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
     _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
     _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]

@@ -80,23 +95,37 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .adapter import MultiAdapter, T2IAdapter
     from .autoencoders import (
         AsymmetricAutoencoderKL,
+        AutoencoderDC,
         AutoencoderKL,
+        AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
+        AutoencoderKLHunyuanVideo,
+        AutoencoderKLLTXVideo,
+        AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
         AutoencoderTiny,
         ConsistencyDecoderVAE,
         VQModel,
     )
-    from .controlnet import ControlNetModel
-    from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
-    from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
-    from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
-    from .controlnet_sparsectrl import SparseControlNetModel
-    from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
+    from .controlnets import (
+        ControlNetModel,
+        ControlNetUnionModel,
+        ControlNetXSAdapter,
+        FluxControlNetModel,
+        FluxMultiControlNetModel,
+        HunyuanDiT2DControlNetModel,
+        HunyuanDiT2DMultiControlNetModel,
+        MultiControlNetModel,
+        SD3ControlNetModel,
+        SD3MultiControlNetModel,
+        SparseControlNetModel,
+        UNetControlNetXSModel,
+    )
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        AllegroTransformer3DModel,
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
@@ -104,10 +133,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         DualTransformer2DModel,
         FluxTransformer2DModel,
         HunyuanDiT2DModel,
+        HunyuanVideoTransformer3DModel,
         LatteTransformer3DModel,
+        LTXVideoTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        SanaTransformer2DModel,
         SD3Transformer2DModel,
         StableAudioDiTModel,
         T5FilmDecoder,
@@ -129,7 +162,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     )

     if is_flax_available():
-        from .controlnet_flax import FlaxControlNetModel
+        from .controlnets import FlaxControlNetModel
         from .unets import FlaxUNet2DConditionModel
         from .vae_flax import FlaxAutoencoderKL

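The net effect of this restructuring for downstream code: the controlnet modules now live in a `controlnets` subpackage, but every public class is still re-exported from `diffusers.models` (and the top-level `diffusers` namespace), so existing imports keep resolving. A quick check:

```python
# Old-style imports keep working through the re-exports above...
from diffusers.models import ControlNetModel, SD3ControlNetModel

# ...while the new subpackage is the canonical home, including the new
# ControlNetUnionModel and the relocated MultiControlNetModel.
from diffusers.models.controlnets import ControlNetUnionModel, MultiControlNetModel
```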
diffusers/models/activations.py CHANGED
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn

 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


 if is_torch_npu_available():
@@ -79,10 +79,10 @@ class GELU(nn.Module):
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)

     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)

     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
@@ -136,6 +136,7 @@ class SwiGLU(nn.Module):

     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
+
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
         self.activation = nn.SiLU()

@@ -163,3 +164,15 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class LinearActivation(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
+        super().__init__()
+
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.activation = get_activation(activation)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        return self.activation(hidden_states)
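A quick sketch of the new `LinearActivation` block: a plain linear projection followed by a configurable activation, resolved through the module's existing `get_activation` helper. Shapes below are illustrative.

```python
import torch

from diffusers.models.activations import LinearActivation

act = LinearActivation(dim_in=64, dim_out=128, activation="silu")
out = act(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 128]) -- project, then SiLU
```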
diffusers/models/attention.py CHANGED
@@ -19,7 +19,7 @@ from torch import nn

 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,7 +211,9 @@

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2

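What the new `joint_attention_kwargs` plumbing buys: pipeline-level kwargs (for example, the IP-Adapter image embeddings consumed by the new SD3 processors) are now unpacked into the block's attention calls instead of being dropped. A minimal sketch with illustrative shapes; passing `None` or `{}` preserves the old behavior.

```python
import torch

from diffusers.models.attention import JointTransformerBlock

block = JointTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
hidden_states = torch.randn(1, 32, 64)         # image tokens
encoder_hidden_states = torch.randn(1, 8, 64)  # text tokens
temb = torch.randn(1, 64)                      # timestep embedding

# Extra kwargs are forwarded to self.attn(...) via **joint_attention_kwargs.
encoder_out, hidden_out = block(
    hidden_states, encoder_hidden_states, temb, joint_attention_kwargs={}
)
```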
@@ -1222,6 +1229,8 @@ class FeedForward(nn.Module):
             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
         elif activation_fn == "swiglu":
             act_fn = SwiGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "linear-silu":
+            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")

         self.net = nn.ModuleList([])
         # project in
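And the matching `FeedForward` hook-up: `activation_fn="linear-silu"` routes the input projection through the `LinearActivation` block added in activations.py. A minimal sketch:

```python
import torch

from diffusers.models.attention import FeedForward

ff = FeedForward(dim=64, mult=4, activation_fn="linear-silu")
out = ff(torch.randn(2, 16, 64))
print(out.shape)  # torch.Size([2, 16, 64])
```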
diffusers/models/attention_flax.py CHANGED
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-
             hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim: