diffusers-0.29.1-py3-none-any.whl → diffusers-0.30.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (222)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_conversion_utils.py +145 -110
  10. diffusers/loaders/lora_pipeline.py +2222 -0
  11. diffusers/loaders/peft.py +213 -5
  12. diffusers/loaders/single_file.py +1 -12
  13. diffusers/loaders/single_file_model.py +31 -10
  14. diffusers/loaders/single_file_utils.py +262 -2
  15. diffusers/loaders/textual_inversion.py +1 -6
  16. diffusers/loaders/unet.py +23 -208
  17. diffusers/models/__init__.py +20 -0
  18. diffusers/models/activations.py +22 -0
  19. diffusers/models/attention.py +386 -7
  20. diffusers/models/attention_processor.py +1795 -629
  21. diffusers/models/autoencoders/__init__.py +2 -0
  22. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  23. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  25. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  26. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  27. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  28. diffusers/models/autoencoders/vq_model.py +4 -4
  29. diffusers/models/controlnet.py +2 -3
  30. diffusers/models/controlnet_hunyuan.py +401 -0
  31. diffusers/models/controlnet_sd3.py +12 -12
  32. diffusers/models/controlnet_sparsectrl.py +789 -0
  33. diffusers/models/controlnet_xs.py +40 -10
  34. diffusers/models/downsampling.py +68 -0
  35. diffusers/models/embeddings.py +319 -36
  36. diffusers/models/model_loading_utils.py +1 -3
  37. diffusers/models/modeling_flax_utils.py +1 -6
  38. diffusers/models/modeling_utils.py +4 -16
  39. diffusers/models/normalization.py +203 -12
  40. diffusers/models/transformers/__init__.py +6 -0
  41. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  42. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  43. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  44. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  45. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  46. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  47. diffusers/models/transformers/prior_transformer.py +1 -1
  48. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  49. diffusers/models/transformers/transformer_2d.py +4 -2
  50. diffusers/models/transformers/transformer_flux.py +455 -0
  51. diffusers/models/transformers/transformer_sd3.py +19 -5
  52. diffusers/models/unets/unet_1d_blocks.py +1 -1
  53. diffusers/models/unets/unet_2d_condition.py +8 -1
  54. diffusers/models/unets/unet_3d_blocks.py +51 -920
  55. diffusers/models/unets/unet_3d_condition.py +4 -1
  56. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  57. diffusers/models/unets/unet_kandinsky3.py +1 -1
  58. diffusers/models/unets/unet_motion_model.py +1330 -84
  59. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  60. diffusers/models/unets/unet_stable_cascade.py +1 -3
  61. diffusers/models/unets/uvit_2d.py +1 -1
  62. diffusers/models/upsampling.py +64 -0
  63. diffusers/models/vq_model.py +8 -4
  64. diffusers/optimization.py +1 -1
  65. diffusers/pipelines/__init__.py +100 -3
  66. diffusers/pipelines/animatediff/__init__.py +4 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  68. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  70. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  71. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  74. diffusers/pipelines/aura_flow/__init__.py +48 -0
  75. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  76. diffusers/pipelines/auto_pipeline.py +97 -19
  77. diffusers/pipelines/cogvideo/__init__.py +48 -0
  78. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  80. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  81. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  82. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  83. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  84. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  85. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  86. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  87. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  88. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  89. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  91. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  92. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  97. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  105. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  106. diffusers/pipelines/flux/__init__.py +47 -0
  107. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  108. diffusers/pipelines/flux/pipeline_output.py +21 -0
  109. diffusers/pipelines/free_init_utils.py +2 -0
  110. diffusers/pipelines/free_noise_utils.py +236 -0
  111. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  112. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  113. diffusers/pipelines/kolors/__init__.py +54 -0
  114. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  115. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  116. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  117. diffusers/pipelines/kolors/text_encoder.py +889 -0
  118. diffusers/pipelines/kolors/tokenizer.py +334 -0
  119. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  120. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  121. diffusers/pipelines/latte/__init__.py +48 -0
  122. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  123. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  124. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  125. diffusers/pipelines/lumina/__init__.py +48 -0
  126. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  127. diffusers/pipelines/pag/__init__.py +67 -0
  128. diffusers/pipelines/pag/pag_utils.py +237 -0
  129. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  130. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  131. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  132. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  133. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  138. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  139. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  140. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  141. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  142. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  143. diffusers/pipelines/pipeline_utils.py +2 -14
  144. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  145. diffusers/pipelines/stable_audio/__init__.py +50 -0
  146. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  147. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  148. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  156. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  157. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  158. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  160. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  161. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  162. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  163. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  164. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  165. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  166. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  167. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  168. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  169. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  170. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  173. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  174. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  175. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  176. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  180. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  181. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  182. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  183. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  184. diffusers/schedulers/__init__.py +8 -0
  185. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  186. diffusers/schedulers/scheduling_ddim.py +1 -1
  187. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  188. diffusers/schedulers/scheduling_ddpm.py +1 -1
  189. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  190. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  191. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  192. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  193. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  194. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  195. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  196. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  197. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  198. diffusers/schedulers/scheduling_ipndm.py +1 -1
  199. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  200. diffusers/schedulers/scheduling_utils.py +1 -3
  201. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  202. diffusers/training_utils.py +99 -14
  203. diffusers/utils/__init__.py +2 -2
  204. diffusers/utils/dummy_pt_objects.py +210 -0
  205. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  206. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  207. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  208. diffusers/utils/dynamic_modules_utils.py +1 -11
  209. diffusers/utils/export_utils.py +1 -4
  210. diffusers/utils/hub_utils.py +45 -42
  211. diffusers/utils/import_utils.py +19 -16
  212. diffusers/utils/loading_utils.py +76 -3
  213. diffusers/utils/testing_utils.py +11 -8
  214. {diffusers-0.29.1.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  215. {diffusers-0.29.1.dist-info → diffusers-0.30.0.dist-info}/RECORD +219 -166
  216. {diffusers-0.29.1.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  217. diffusers/loaders/autoencoder.py +0 -146
  218. diffusers/loaders/controlnet.py +0 -136
  219. diffusers/loaders/lora.py +0 -1729
  220. {diffusers-0.29.1.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  221. {diffusers-0.29.1.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  222. {diffusers-0.29.1.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
123
123
  return new_state_dict
124
124
 
125
125
 
126
- def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
126
+ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
127
+ """
128
+ Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
129
+
130
+ Args:
131
+ state_dict (`dict`): The state dict to convert.
132
+ unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
133
+ text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
134
+ "text_encoder".
135
+
136
+ Returns:
137
+ `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
138
+ """
127
139
  unet_state_dict = {}
128
140
  te_state_dict = {}
129
141
  te2_state_dict = {}
130
142
  network_alphas = {}
131
- is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
132
- is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
133
- is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
134
143
 
135
- if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
144
+ # Check for DoRA-enabled LoRAs.
145
+ dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
146
+ dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
147
+ dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
148
+ if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
136
149
  if is_peft_version("<", "0.9.0"):
137
150
  raise ValueError(
138
151
  "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
139
152
  )
140
153
 
141
- # every down weight has a corresponding up weight and potentially an alpha weight
142
- lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
143
- for key in lora_keys:
154
+ # Iterate over all LoRA weights.
155
+ all_lora_keys = list(state_dict.keys())
156
+ for key in all_lora_keys:
157
+ if not key.endswith("lora_down.weight"):
158
+ continue
159
+
160
+ # Extract LoRA name.
144
161
  lora_name = key.split(".")[0]
162
+
163
+ # Find corresponding up weight and alpha.
145
164
  lora_name_up = lora_name + ".lora_up.weight"
146
165
  lora_name_alpha = lora_name + ".alpha"
147
166
 
167
+ # Handle U-Net LoRAs.
148
168
  if lora_name.startswith("lora_unet_"):
149
- diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
150
-
151
- if "input.blocks" in diffusers_name:
152
- diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
153
- else:
154
- diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
169
+ diffusers_name = _convert_unet_lora_key(key)
155
170
 
156
- if "middle.block" in diffusers_name:
157
- diffusers_name = diffusers_name.replace("middle.block", "mid_block")
158
- else:
159
- diffusers_name = diffusers_name.replace("mid.block", "mid_block")
160
- if "output.blocks" in diffusers_name:
161
- diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
162
- else:
163
- diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
164
-
165
- diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
166
- diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
167
- diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
168
- diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
169
- diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
170
- diffusers_name = diffusers_name.replace("proj.in", "proj_in")
171
- diffusers_name = diffusers_name.replace("proj.out", "proj_out")
172
- diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
173
-
174
- # SDXL specificity.
175
- if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
176
- pattern = r"\.\d+(?=\D*$)"
177
- diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
178
- if ".in." in diffusers_name:
179
- diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
180
- if ".out." in diffusers_name:
181
- diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
182
- if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
183
- diffusers_name = diffusers_name.replace("op", "conv")
184
- if "skip" in diffusers_name:
185
- diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
186
-
187
- # LyCORIS specificity.
188
- if "time.emb.proj" in diffusers_name:
189
- diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
190
- if "conv.shortcut" in diffusers_name:
191
- diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
192
-
193
- # General coverage.
194
- if "transformer_blocks" in diffusers_name:
195
- if "attn1" in diffusers_name or "attn2" in diffusers_name:
196
- diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
197
- diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
198
- unet_state_dict[diffusers_name] = state_dict.pop(key)
199
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
200
- elif "ff" in diffusers_name:
201
- unet_state_dict[diffusers_name] = state_dict.pop(key)
202
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
203
- elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
204
- unet_state_dict[diffusers_name] = state_dict.pop(key)
205
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
206
- else:
207
- unet_state_dict[diffusers_name] = state_dict.pop(key)
208
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
171
+ # Store down and up weights.
172
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
173
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
209
174
 
210
- if is_unet_dora_lora:
175
+ # Store DoRA scale if present.
176
+ if dora_present_in_unet:
211
177
  dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
212
178
  unet_state_dict[
213
179
  diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
214
180
  ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
215
181
 
182
+ # Handle text encoder LoRAs.
216
183
  elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
184
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
185
+
186
+ # Store down and up weights for te or te2.
217
187
  if lora_name.startswith(("lora_te_", "lora_te1_")):
218
- key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
188
+ te_state_dict[diffusers_name] = state_dict.pop(key)
189
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
219
190
  else:
220
- key_to_replace = "lora_te2_"
221
-
222
- diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
223
- diffusers_name = diffusers_name.replace("text.model", "text_model")
224
- diffusers_name = diffusers_name.replace("self.attn", "self_attn")
225
- diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
226
- diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
227
- diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
228
- diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
229
- diffusers_name = diffusers_name.replace("text.projection", "text_projection")
230
-
231
- if "self_attn" in diffusers_name:
232
- if lora_name.startswith(("lora_te_", "lora_te1_")):
233
- te_state_dict[diffusers_name] = state_dict.pop(key)
234
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
235
- else:
236
- te2_state_dict[diffusers_name] = state_dict.pop(key)
237
- te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
238
- elif "mlp" in diffusers_name:
239
- # Be aware that this is the new diffusers convention and the rest of the code might
240
- # not utilize it yet.
241
- diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
242
- if lora_name.startswith(("lora_te_", "lora_te1_")):
243
- te_state_dict[diffusers_name] = state_dict.pop(key)
244
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
245
- else:
246
- te2_state_dict[diffusers_name] = state_dict.pop(key)
247
- te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
248
- # OneTrainer specificity
249
- elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
250
191
  te2_state_dict[diffusers_name] = state_dict.pop(key)
251
192
  te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
252
193
 
253
- if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
194
+ # Store DoRA scale if present.
195
+ if dora_present_in_te or dora_present_in_te2:
254
196
  dora_scale_key_to_replace_te = (
255
197
  "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
256
198
  )
@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
263
205
  diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
264
206
  ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
265
207
 
266
- # Rename the alphas so that they can be mapped appropriately.
208
+ # Store alpha if present.
267
209
  if lora_name_alpha in state_dict:
268
210
  alpha = state_dict.pop(lora_name_alpha).item()
269
- if lora_name_alpha.startswith("lora_unet_"):
270
- prefix = "unet."
271
- elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
272
- prefix = "text_encoder."
273
- else:
274
- prefix = "text_encoder_2."
275
- new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
276
- network_alphas.update({new_name: alpha})
211
+ network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
277
212
 
213
+ # Check if any keys remain.
278
214
  if len(state_dict) > 0:
279
215
  raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
280
216
 
281
- logger.info("Kohya-style checkpoint detected.")
217
+ logger.info("Non-diffusers checkpoint detected.")
218
+
219
+ # Construct final state dict.
282
220
  unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
283
221
  te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
284
222
  te2_state_dict = (
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
291
229
 
292
230
  new_state_dict = {**unet_state_dict, **te_state_dict}
293
231
  return new_state_dict, network_alphas
232
+
233
+
234
+ def _convert_unet_lora_key(key):
235
+ """
236
+ Converts a U-Net LoRA key to a Diffusers compatible key.
237
+ """
238
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
239
+
240
+ # Replace common U-Net naming patterns.
241
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
242
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
243
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
244
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
245
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
246
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
247
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
248
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
249
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
250
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
251
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
252
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
253
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
254
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
255
+
256
+ # SDXL specific conversions.
257
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
258
+ pattern = r"\.\d+(?=\D*$)"
259
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
260
+ if ".in." in diffusers_name:
261
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
262
+ if ".out." in diffusers_name:
263
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
264
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
265
+ diffusers_name = diffusers_name.replace("op", "conv")
266
+ if "skip" in diffusers_name:
267
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
268
+
269
+ # LyCORIS specific conversions.
270
+ if "time.emb.proj" in diffusers_name:
271
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
272
+ if "conv.shortcut" in diffusers_name:
273
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
274
+
275
+ # General conversions.
276
+ if "transformer_blocks" in diffusers_name:
277
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
278
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
279
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
280
+ elif "ff" in diffusers_name:
281
+ pass
282
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
283
+ pass
284
+ else:
285
+ pass
286
+
287
+ return diffusers_name
288
+
289
+
290
+ def _convert_text_encoder_lora_key(key, lora_name):
291
+ """
292
+ Converts a text encoder LoRA key to a Diffusers compatible key.
293
+ """
294
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
295
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
296
+ else:
297
+ key_to_replace = "lora_te2_"
298
+
299
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
300
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
301
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
302
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
303
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
304
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
305
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
306
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
307
+
308
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
309
+ pass
310
+ elif "mlp" in diffusers_name:
311
+ # Be aware that this is the new diffusers convention and the rest of the code might
312
+ # not utilize it yet.
313
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
314
+ return diffusers_name
315
+
316
+
317
+ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
318
+ """
319
+ Gets the correct alpha name for the Diffusers model.
320
+ """
321
+ if lora_name_alpha.startswith("lora_unet_"):
322
+ prefix = "unet."
323
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
324
+ prefix = "text_encoder."
325
+ else:
326
+ prefix = "text_encoder_2."
327
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
328
+ return {new_name: alpha}