diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
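The most visible additions in 0.32.0 are new video and image generation stacks (Allegro, HunyuanVideo, LTX-Video, Mochi, Sana), ControlNet-Union pipelines, Flux Control/Fill/Redux pipelines, and GGUF/TorchAO quantization backends. As a rough orientation only (this sketch is not part of the diff), loading one of the newly added pipelines through the generic DiffusionPipeline entry point might look like the following; the repository id comes from the single-file loading defaults added further down, while the call arguments and the `frames` output attribute are assumptions to verify against the 0.32.0 documentation:

    import torch
    from diffusers import DiffusionPipeline

    # "genmo/mochi-1-preview" is one of the repos registered as a single-file
    # loading default in this release (see DIFFUSERS_DEFAULT_PIPELINE_PATHS below).
    pipe = DiffusionPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()

    # The `.frames` attribute is assumed from the new video pipeline outputs.
    frames = pipe("a corgi surfing a wave", num_inference_steps=28).frames[0]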
@@ -62,7 +62,14 @@ CHECKPOINT_KEY_NAMES = {
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
     "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
-    "controlnet": "control_model.time_embed.0.weight",
+    "controlnet": [
+        "control_model.time_embed.0.weight",
+        "controlnet_cond_embedding.conv_in.weight",
+    ],
+    # TODO: find non-Diffusers keys for controlnet_xl
+    "controlnet_xl": "add_embedding.linear_1.weight",
+    "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
+    "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight",
     "playground-v2-5": "edm_mean",
     "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
     "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
@@ -74,8 +81,14 @@ CHECKPOINT_KEY_NAMES = {
     "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
-    "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
-    "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
+    "sd3": [
+        "joint_blocks.0.context_block.adaLN_modulation.1.bias",
+        "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    ],
+    "sd35_large": [
+        "joint_blocks.37.x_block.mlp.fc1.weight",
+        "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
+    ],
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -85,6 +98,17 @@ CHECKPOINT_KEY_NAMES = {
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
     ],
+    "ltx-video": [
+        "model.diffusion_model.patchify_proj.weight",
+        "model.diffusion_model.transformer_blocks.27.scale_shift_table",
+        "patchify_proj.weight",
+        "transformer_blocks.27.scale_shift_table",
+        "vae.per_channel_statistics.mean-of-means",
+    ],
+    "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
+    "autoencoder-dc-sana": "encoder.project_in.conv.bias",
+    "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
+    "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -96,6 +120,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
     "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
     "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
+    "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"},
+    "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"},
+    "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"},
     "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
     "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
     "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
@@ -117,6 +144,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "sd35_large": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
     },
+    "sd35_medium": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -124,7 +154,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
+    "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
+    "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+    "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
+    "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+    "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
+    "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
+    "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
+    "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
+    "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
 }

 # Use to configure model sample size when original config is provided
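The entries above are the fallback diffusers-format repositories that single-file loading falls back to for model configs when only a raw checkpoint is supplied. A toy sketch of the lookup (hypothetical helper, not diffusers code):

    # Hypothetical helper illustrating how a detected model type resolves to the
    # default diffusers-format repo whose config will be fetched.
    DEFAULT_PATHS = {
        "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
        "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
    }

    def default_repo_for(model_type: str) -> str:
        return DEFAULT_PATHS[model_type]["pretrained_model_name_or_path"]

    print(default_repo_for("mochi-1-preview"))  # genmo/mochi-1-preview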
@@ -481,8 +521,16 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
         model_type = "upscale"

-    elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
-        model_type = "controlnet"
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
+        if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
+            if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
+                model_type = "controlnet_xl_large"
+            elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
+                model_type = "controlnet_xl_mid"
+            else:
+                model_type = "controlnet_xl_small"
+        else:
+            model_type = "controlnet"

     elif (
         CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
@@ -508,10 +556,20 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"

-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
-        model_type = "sd3"
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
+        checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
+    ):
+        if "model.diffusion_model.pos_embed" in checkpoint:
+            key = "model.diffusion_model.pos_embed"
+        else:
+            key = "pos_embed"
+
+        if checkpoint[key].shape[1] == 36864:
+            model_type = "sd3"
+        elif checkpoint[key].shape[1] == 147456:
+            model_type = "sd35_medium"

-    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
         model_type = "sd35_large"

     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
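The reworked SD3 branch first matches on the joint-block adaLN bias (width 9216) and then separates Stable Diffusion 3.0 from 3.5 medium by the length of the positional embedding. A toy sketch of that shape-based dispatch, using random tensors whose only meaningful dimensions are the ones the branch actually checks:

    import torch

    # Toy checkpoints exercising the pos_embed length check from the branch above.
    sd3_ckpt = {
        "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": torch.zeros(9216),
        "model.diffusion_model.pos_embed": torch.zeros(1, 36864),
    }
    sd35_medium_ckpt = {
        "joint_blocks.0.context_block.adaLN_modulation.1.bias": torch.zeros(9216),
        "pos_embed": torch.zeros(1, 147456),
    }

    def classify(checkpoint):
        key = "model.diffusion_model.pos_embed" if "model.diffusion_model.pos_embed" in checkpoint else "pos_embed"
        return {36864: "sd3", 147456: "sd35_medium"}.get(checkpoint[key].shape[1], "unknown")

    print(classify(sd3_ckpt), classify(sd35_medium_ckpt))  # sd3 sd35_medium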
@@ -537,9 +595,44 @@ def infer_diffusers_model_type(checkpoint):
         if any(
             g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
         ):
-            model_type = "flux-dev"
+            if checkpoint["img_in.weight"].shape[1] == 384:
+                model_type = "flux-fill"
+
+            elif checkpoint["img_in.weight"].shape[1] == 128:
+                model_type = "flux-depth"
+            else:
+                model_type = "flux-dev"
         else:
             model_type = "flux-schnell"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
+        if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+            model_type = "ltx-video-0.9.1"
+        else:
+            model_type = "ltx-video"
+
+    elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
+        encoder_key = "encoder.project_in.conv.conv.bias"
+        decoder_key = "decoder.project_in.main.conv.weight"
+
+        if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
+            model_type = "autoencoder-dc-f32c32-sana"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
+            model_type = "autoencoder-dc-f32c32"
+
+        elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
+            model_type = "autoencoder-dc-f64c128"
+
+        else:
+            model_type = "autoencoder-dc-f128c512"
+
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
+        model_type = "mochi-1-preview"
+
+    elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
+        model_type = "hunyuan-video"
+
     else:
         model_type = "v1"

@@ -1072,6 +1165,9 @@ def convert_controlnet_checkpoint(
     config,
     **kwargs,
 ):
+    # Return checkpoint if it's already been converted
+    if "time_embedding.linear_1.weight" in checkpoint:
+        return checkpoint
     # Some controlnet ckpt files are distributed independently from the rest of the
     # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
     if "time_embed.0.weight" in checkpoint:
@@ -1677,6 +1773,12 @@ def swap_scale_shift(weight, dim):
     return new_weight


+def swap_proj_gate(weight):
+    proj, gate = weight.chunk(2, dim=0)
+    new_weight = torch.cat([gate, proj], dim=0)
+    return new_weight
+
+
 def get_attn2_layers(state_dict):
     attn2_layers = []
     for key in state_dict.keys():
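swap_proj_gate is a small helper used by the Mochi converter further down: it splits a fused MLP weight into its two row-halves and stacks them back in the opposite order. A quick demonstration on a toy tensor:

    import torch

    def swap_proj_gate(weight):
        proj, gate = weight.chunk(2, dim=0)
        return torch.cat([gate, proj], dim=0)

    w = torch.arange(8).reshape(4, 2)  # rows 0-1 form one half, rows 2-3 the other
    print(swap_proj_gate(w))           # rows 2-3 now come before rows 0-1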
@@ -2171,3 +2273,411 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     )

     return converted_state_dict
+
+
+def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "model.diffusion_model.": "",
+        "patchify_proj": "proj_in",
+        "adaln_single": "time_embed",
+        "q_norm": "norm_q",
+        "k_norm": "norm_k",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
+
+    def remove_keys_(key: str, state_dict):
+        state_dict.pop(key)
+
+    VAE_KEYS_RENAME_DICT = {
+        # common
+        "vae.": "",
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0",
+        "up_blocks.2": "up_blocks.1.upsamplers.0",
+        "up_blocks.3": "up_blocks.1",
+        "up_blocks.4": "up_blocks.2.conv_in",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.conv_in",
+        "up_blocks.8": "up_blocks.3.upsamplers.0",
+        "up_blocks.9": "up_blocks.3",
+        # encoder
+        "down_blocks.0": "down_blocks.0",
+        "down_blocks.1": "down_blocks.0.downsamplers.0",
+        "down_blocks.2": "down_blocks.0.conv_out",
+        "down_blocks.3": "down_blocks.1",
+        "down_blocks.4": "down_blocks.1.downsamplers.0",
+        "down_blocks.5": "down_blocks.1.conv_out",
+        "down_blocks.6": "down_blocks.2",
+        "down_blocks.7": "down_blocks.2.downsamplers.0",
+        "down_blocks.8": "down_blocks.3",
+        "down_blocks.9": "mid_block",
+        # common
+        "conv_shortcut": "conv_shortcut.conv",
+        "res_blocks": "resnets",
+        "norm3.norm": "norm3",
+        "per_channel_statistics.mean-of-means": "latents_mean",
+        "per_channel_statistics.std-of-means": "latents_std",
+    }
+
+    VAE_091_RENAME_DICT = {
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0.upsamplers.0",
+        "up_blocks.2": "up_blocks.0",
+        "up_blocks.3": "up_blocks.1.upsamplers.0",
+        "up_blocks.4": "up_blocks.1",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.upsamplers.0",
+        "up_blocks.8": "up_blocks.3",
+        # common
+        "last_time_embedder": "time_embedder",
+        "last_scale_shift_table": "scale_shift_table",
+    }
+
+    VAE_SPECIAL_KEYS_REMAP = {
+        "per_channel_statistics.channel": remove_keys_,
+        "per_channel_statistics.mean-of-means": remove_keys_,
+        "per_channel_statistics.mean-of-stds": remove_keys_,
+        "timestep_scale_multiplier": remove_keys_,
+    }
+
+    if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+    def remap_qkv_(key: str, state_dict):
+        qkv = state_dict.pop(key)
+        q, k, v = torch.chunk(qkv, 3, dim=0)
+        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+        state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+        state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+        state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+    def remap_proj_conv_(key: str, state_dict):
+        parent_module, _, _ = key.rpartition(".proj.conv.weight")
+        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+    AE_KEYS_RENAME_DICT = {
+        # common
+        "main.": "",
+        "op_list.": "",
+        "context_module": "attn",
+        "local_module": "conv_out",
+        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+        # If there were more scales, there would be more layers, so a loop would be better to handle this
+        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+        "depth_conv.conv": "conv_depth",
+        "inverted_conv.conv": "conv_inverted",
+        "point_conv.conv": "conv_point",
+        "point_conv.norm": "norm",
+        "conv.conv.": "conv.",
+        "conv1.conv": "conv1",
+        "conv2.conv": "conv2",
+        "conv2.norm": "norm",
+        "proj.norm": "norm_out",
+        # encoder
+        "encoder.project_in.conv": "encoder.conv_in",
+        "encoder.project_out.0.conv": "encoder.conv_out",
+        "encoder.stages": "encoder.down_blocks",
+        # decoder
+        "decoder.project_in.conv": "decoder.conv_in",
+        "decoder.project_out.0": "decoder.norm_out",
+        "decoder.project_out.2.conv": "decoder.conv_out",
+        "decoder.stages": "decoder.up_blocks",
+    }
+
+    AE_F32C32_F64C128_F128C512_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_SPECIAL_KEYS_REMAP = {
+        "qkv.conv.weight": remap_qkv_,
+        "proj.conv.weight": remap_proj_conv_,
+    }
+    if "encoder.project_in.conv.bias" not in converted_state_dict:
+        AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
+
+
+def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    new_state_dict = {}
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
+
+    # Convert time_embed
+    new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
+    new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
+    new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
+    new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
+    new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
+    new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
+    new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
+    new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
+    new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
+    new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
+    new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
+
+    # Convert transformer blocks
+    num_layers = 48
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        old_prefix = f"blocks.{i}."
+
+        # norm1
+        new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
+        new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
+            new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+        else:
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
+                old_prefix + "mod_y.weight"
+            )
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
+
+        # Visual attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
+        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
+        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
+        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
+
+        # Context attention
+        qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
+        q, k, v = qkv_weight.chunk(3, dim=0)
+
+        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
+            old_prefix + "attn.q_norm_y.weight"
+        )
+        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
+            old_prefix + "attn.k_norm_y.weight"
+        )
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
+                old_prefix + "attn.proj_y.weight"
+            )
+            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
+
+        # MLP
+        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
+            checkpoint.pop(old_prefix + "mlp_x.w1.weight")
+        )
+        new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
+        if i < num_layers - 1:
+            new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
+                checkpoint.pop(old_prefix + "mlp_y.w1.weight")
+            )
+            new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
+
+    # Output layers
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
+    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
+    new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
+
+    return new_state_dict
+
+
+def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+        state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+        state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+        state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+        state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+            q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
+            state_dict[f"{new_key}.attn.to_q.weight"] = q
+            state_dict[f"{new_key}.attn.to_k.weight"] = k
+            state_dict[f"{new_key}.attn.to_v.weight"] = v
+            state_dict[f"{new_key}.proj_mlp.weight"] = mlp
+
+        elif "linear1.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+            q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+            new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
+            state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
+            state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
+            state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
+            state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    def update_state_dict_(state_dict, old_key, new_key):
+        state_dict[new_key] = state_dict.pop(old_key)
+
+    for key in list(checkpoint.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_(checkpoint, key, new_key)
+
+    for key in list(checkpoint.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, checkpoint)
+
+    return checkpoint
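All four new converters above follow the same pattern: a flat rename dictionary applied by substring replacement, plus special-case handlers for fused tensors such as packed QKV projections. A minimal self-contained sketch of that pattern on a toy state dict (the keys and shapes here are illustrative, not real checkpoint keys):

    import torch

    # Plain renames are substring replacements; fused tensors get a handler function.
    RENAME = {"patchify_proj": "proj_in", "q_norm": "norm_q"}

    def split_qkv(key, state_dict):
        q, k, v = state_dict.pop(key).chunk(3, dim=0)
        base = key.removesuffix(".qkv.weight")
        state_dict[f"{base}.to_q.weight"] = q
        state_dict[f"{base}.to_k.weight"] = k
        state_dict[f"{base}.to_v.weight"] = v

    SPECIAL = {"qkv.weight": split_qkv}

    ckpt = {
        "patchify_proj.weight": torch.zeros(4, 4),
        "blocks.0.attn.qkv.weight": torch.zeros(9, 3),
    }

    # Pass 1: rename keys in place.
    for key in list(ckpt.keys()):
        new_key = key
        for old, new in RENAME.items():
            new_key = new_key.replace(old, new)
        ckpt[new_key] = ckpt.pop(key)

    # Pass 2: split or otherwise rewrite tensors that need structural changes.
    for key in list(ckpt.keys()):
        for pattern, handler in SPECIAL.items():
            if pattern in key:
                handler(key, ckpt)

    print(sorted(ckpt))  # proj_in.weight plus blocks.0.attn.to_{q,k,v}.weight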
@@ -497,19 +497,19 @@ class TextualInversionLoaderMixin:
         # load embeddings of text_encoder 1 (CLIP ViT-L/14)
         pipeline.load_textual_inversion(
             state_dict["clip_l"],
-            token=["<s0>", "<s1>"],
+            tokens=["<s0>", "<s1>"],
             text_encoder=pipeline.text_encoder,
             tokenizer=pipeline.tokenizer,
         )
         # load embeddings of text_encoder 2 (CLIP ViT-G/14)
         pipeline.load_textual_inversion(
             state_dict["clip_g"],
-            token=["<s0>", "<s1>"],
+            tokens=["<s0>", "<s1>"],
             text_encoder=pipeline.text_encoder_2,
             tokenizer=pipeline.tokenizer_2,
         )

-        # Unload explicitly from both text encoders abd tokenizers
+        # Unload explicitly from both text encoders and tokenizers
         pipeline.unload_textual_inversion(
             tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
         )