diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,9 @@ if is_accelerate_available():
 
 logger = logging.get_logger(__name__)
 
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
 
 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
     """
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder):
         text_encoder._hf_peft_config_loaded = None
 
 
+def _fetch_state_dict(
+    pretrained_model_name_or_path_or_dict,
+    weight_name,
+    use_safetensors,
+    local_files_only,
+    cache_dir,
+    force_download,
+    proxies,
+    token,
+    revision,
+    subfolder,
+    user_agent,
+    allow_pickle,
+):
+    model_file = None
+    if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+        # Let's first try to load .safetensors weights
+        if (use_safetensors and weight_name is None) or (
+            weight_name is not None and weight_name.endswith(".safetensors")
+        ):
+            try:
+                # Here we're relaxing the loading check to enable more Inference API
+                # friendliness where sometimes, it's not at all possible to automatically
+                # determine `weight_name`.
+                if weight_name is None:
+                    weight_name = _best_guess_weight_name(
+                        pretrained_model_name_or_path_or_dict,
+                        file_extension=".safetensors",
+                        local_files_only=local_files_only,
+                    )
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = safetensors.torch.load_file(model_file, device="cpu")
+            except (IOError, safetensors.SafetensorError) as e:
+                if not allow_pickle:
+                    raise e
+                # try loading non-safetensors weights
+                model_file = None
+                pass
+
+        if model_file is None:
+            if weight_name is None:
+                weight_name = _best_guess_weight_name(
+                    pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
+                )
+            model_file = _get_model_file(
+                pretrained_model_name_or_path_or_dict,
+                weights_name=weight_name or LORA_WEIGHT_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            state_dict = load_state_dict(model_file)
+    else:
+        state_dict = pretrained_model_name_or_path_or_dict
+
+    return state_dict
+
+
+def _best_guess_weight_name(
+    pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
+):
+    if local_files_only or HF_HUB_OFFLINE:
+        raise ValueError("When using the offline mode, you must specify a `weight_name`.")
+
+    targeted_files = []
+
+    if os.path.isfile(pretrained_model_name_or_path_or_dict):
+        return
+    elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+        targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
+    else:
+        files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+        targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+    if len(targeted_files) == 0:
+        return
+
+    # "scheduler" does not correspond to a LoRA checkpoint.
+    # "optimizer" does not correspond to a LoRA checkpoint
+    # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
+    unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
+    targeted_files = list(
+        filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+    )
+
+    if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+        targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+    elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+        targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
+    if len(targeted_files) > 1:
+        raise ValueError(
+            f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+        )
+    weight_name = targeted_files[0]
+    return weight_name
+
+
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""
 
@@ -234,124 +350,16 @@ class LoraBaseMixin:
         return (is_model_cpu_offload, is_sequential_cpu_offload)
 
     @classmethod
-    def _fetch_state_dict(
-        cls,
-        pretrained_model_name_or_path_or_dict,
-        weight_name,
-        use_safetensors,
-        local_files_only,
-        cache_dir,
-        force_download,
-        proxies,
-        token,
-        revision,
-        subfolder,
-        user_agent,
-        allow_pickle,
-    ):
-        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
-        model_file = None
-        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            # Let's first try to load .safetensors weights
-            if (use_safetensors and weight_name is None) or (
-                weight_name is not None and weight_name.endswith(".safetensors")
-            ):
-                try:
-                    # Here we're relaxing the loading check to enable more Inference API
-                    # friendliness where sometimes, it's not at all possible to automatically
-                    # determine `weight_name`.
-                    if weight_name is None:
-                        weight_name = cls._best_guess_weight_name(
-                            pretrained_model_name_or_path_or_dict,
-                            file_extension=".safetensors",
-                            local_files_only=local_files_only,
-                        )
-                    model_file = _get_model_file(
-                        pretrained_model_name_or_path_or_dict,
-                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                    )
-                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                except (IOError, safetensors.SafetensorError) as e:
-                    if not allow_pickle:
-                        raise e
-                    # try loading non-safetensors weights
-                    model_file = None
-                    pass
-
-            if model_file is None:
-                if weight_name is None:
-                    weight_name = cls._best_guess_weight_name(
-                        pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
-                    )
-                model_file = _get_model_file(
-                    pretrained_model_name_or_path_or_dict,
-                    weights_name=weight_name or LORA_WEIGHT_NAME,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    token=token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                )
-                state_dict = load_state_dict(model_file)
-        else:
-            state_dict = pretrained_model_name_or_path_or_dict
-
-        return state_dict
+    def _fetch_state_dict(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
+        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
+        return _fetch_state_dict(*args, **kwargs)
 
     @classmethod
-    def _best_guess_weight_name(
-        cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
-    ):
-        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
-        if local_files_only or HF_HUB_OFFLINE:
-            raise ValueError("When using the offline mode, you must specify a `weight_name`.")
-
-        targeted_files = []
-
-        if os.path.isfile(pretrained_model_name_or_path_or_dict):
-            return
-        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
-            targeted_files = [
-                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
-            ]
-        else:
-            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
-            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
-        if len(targeted_files) == 0:
-            return
-
-        # "scheduler" does not correspond to a LoRA checkpoint.
-        # "optimizer" does not correspond to a LoRA checkpoint
-        # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
-        unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
-        targeted_files = list(
-            filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
-        )
-
-        if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
-            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
-        elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
-            targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
-
-        if len(targeted_files) > 1:
-            raise ValueError(
-                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
-            )
-        weight_name = targeted_files[0]
-        return weight_name
+    def _best_guess_weight_name(cls, *args, **kwargs):
+        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
+        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
+        return _best_guess_weight_name(*args, **kwargs)
 
     def unload_lora_weights(self):
         """
@@ -725,8 +733,6 @@ class LoraBaseMixin:
         save_function: Callable,
         safe_serialization: bool,
     ):
-        from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
-
         if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return
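
Taken together, the four hunks above (all in `diffusers/loaders/lora_base.py`) promote `_fetch_state_dict` and `_best_guess_weight_name` from `LoraBaseMixin` classmethods to module-level functions; the classmethods remain as shims that emit a deprecation warning until 0.35.0. A hedged usage sketch against 0.32.1, using the exact signature from the hunk (the repo id is a hypothetical placeholder):

```python
# Sketch only: call the now module-level helper directly.
# "some-user/some-lora" is a placeholder for a Hub repo or local directory holding one LoRA file.
from diffusers.loaders.lora_base import _fetch_state_dict

state_dict = _fetch_state_dict(
    "some-user/some-lora",
    weight_name=None,       # None lets _best_guess_weight_name pick the single .safetensors file
    use_safetensors=True,
    local_files_only=False,
    cache_dir=None,
    force_download=False,
    proxies=None,
    token=None,
    revision=None,
    subfolder=None,
    user_agent={"file_type": "lora"},  # free-form user-agent metadata
    allow_pickle=False,
)
print(list(state_dict)[:3])
```

Calling `LoraBaseMixin._fetch_state_dict(...)` or `LoraBaseMixin._best_guess_weight_name(...)` still works in 0.32.1 but routes through `deprecate(...)` first, as the shims above show.
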
@@ -636,10 +636,19 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
             block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
             new_key = f"transformer.single_transformer_blocks.{block_num}"
 
-            if "proj_lora1" in old_key or "proj_lora2" in old_key:
+            if "proj_lora" in old_key:
                 new_key += ".proj_out"
-            elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
-                new_key += ".norm.linear"
+            elif "qkv_lora" in old_key and "up" not in old_key:
+                handle_qkv(
+                    old_state_dict,
+                    new_state_dict,
+                    old_key,
+                    [
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+                        f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+                    ],
+                )
 
             if "down" in old_key:
                 new_key += ".lora_A.weight"
@@ -658,3 +667,309 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
         raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
 
     return new_state_dict
+
+
+def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {}
+    original_state_dict_keys = list(original_state_dict.keys())
+    num_layers = 19
+    num_single_layers = 38
+    inner_dim = 3072
+    mlp_ratio = 4.0
+
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    for lora_key in ["lora_A", "lora_B"]:
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+        if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+
+        converted_state_dict[
+            f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
+        ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+        if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[
+                f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
+            ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+
+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.in_layer.{lora_key}.weight"
+        )
+        if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.in_layer.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
+            f"vector_in.out_layer.{lora_key}.weight"
+        )
+        if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
+                f"vector_in.out_layer.{lora_key}.bias"
+            )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in original_state_dict)
+        if has_guidance:
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+            if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+
+            converted_state_dict[
+                f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
+            ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+            if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[
+                    f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
+                ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+
+        # context_embedder
+        converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
+            f"txt_in.{lora_key}.weight"
+        )
+        if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
+                f"txt_in.{lora_key}.bias"
+            )
+
+        # x_embedder
+        converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
+        if f"img_in.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norms
+            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V
+            if lora_key == "lora_A":
+                sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+                context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+                    [context_lora_weight]
+                )
+            else:
+                sample_q, sample_k, sample_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+                context_q, context_k, context_v = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+            if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+            if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+                context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+                    original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+                converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+                converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+            # ff img_mlp
+            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+                )
+
+            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+                )
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+                )
+            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+            )
+            if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+                )
+
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+
+    # single transfomer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+
+        for lora_key in ["lora_A", "lora_B"]:
+            # norm.linear <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
+                )
+
+            # Q, K, V, mlp
+            mlp_hidden_dim = int(inner_dim * mlp_ratio)
+            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+            if lora_key == "lora_A":
+                lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+            else:
+                q, k, v, mlp = torch.split(
+                    original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
+                )
+                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+                if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
+                        original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
+                    )
+                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+            # output projections.
+            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+                f"single_blocks.{i}.linear2.{lora_key}.weight"
+            )
+            if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                    f"single_blocks.{i}.linear2.{lora_key}.bias"
+                )
+
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+
+    for lora_key in ["lora_A", "lora_B"]:
+        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+            f"final_layer.linear.{lora_key}.weight"
+        )
+        if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+                f"final_layer.linear.{lora_key}.bias"
+            )
+
+        converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
+            original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
+        )
+        if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
+            converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
+                original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
+            )
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
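
The function above is added to `diffusers/loaders/lora_conversion_utils.py`; it remaps a Flux Control LoRA stored in the original Black Forest Labs key layout (`time_in.*`, `vector_in.*`, `double_blocks.*`, `single_blocks.*`, `final_layer.*`) onto diffusers `FluxTransformer2DModel` module names and prefixes every converted key with `transformer.`. The intent is that the regular Flux LoRA-loading path can consume such checkpoints directly; a hedged sketch (the LoRA repo id and weight file name are hypothetical placeholders):

```python
# Sketch only: loading a BFL-layout Flux Control LoRA through the public API in 0.32.1.
import torch
from diffusers import FluxControlPipeline

pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Hypothetical repo/file; any checkpoint using the original "double_blocks.*"/"single_blocks.*" keys applies.
pipe.load_lora_weights("some-user/flux-control-lora-bfl-format", weight_name="lora.safetensors")
```
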