diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270) hide show
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,7 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
- """ Conversion script for the Stable Diffusion checkpoints."""
15
+ """Conversion script for the Stable Diffusion checkpoints."""
16
16
 
17
17
  import os
18
18
  import re
@@ -26,7 +26,6 @@ import yaml
26
26
  from ..models.modeling_utils import load_state_dict
27
27
  from ..schedulers import (
28
28
  DDIMScheduler,
29
- DDPMScheduler,
30
29
  DPMSolverMultistepScheduler,
31
30
  EDMDPMSolverMultistepScheduler,
32
31
  EulerAncestralDiscreteScheduler,
@@ -35,133 +34,85 @@ from ..schedulers import (
35
34
  LMSDiscreteScheduler,
36
35
  PNDMScheduler,
37
36
  )
38
- from ..utils import is_accelerate_available, is_transformers_available, logging
37
+ from ..utils import (
38
+ SAFETENSORS_WEIGHTS_NAME,
39
+ WEIGHTS_NAME,
40
+ deprecate,
41
+ is_accelerate_available,
42
+ is_transformers_available,
43
+ logging,
44
+ )
39
45
  from ..utils.hub_utils import _get_model_file
40
46
 
41
47
 
42
48
  if is_transformers_available():
43
- from transformers import (
44
- CLIPTextConfig,
45
- CLIPTextModel,
46
- CLIPTextModelWithProjection,
47
- CLIPTokenizer,
48
- )
49
+ from transformers import AutoImageProcessor
49
50
 
50
51
  if is_accelerate_available():
51
52
  from accelerate import init_empty_weights
52
53
 
53
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+ from ..models.modeling_utils import load_model_dict_into_meta
54
55
 
55
- CONFIG_URLS = {
56
- "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
57
- "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml",
58
- "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml",
59
- "xl_refiner": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml",
60
- "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml",
61
- "controlnet": "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml",
62
- }
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
63
57
 
64
58
  CHECKPOINT_KEY_NAMES = {
65
59
  "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
66
60
  "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
67
61
  "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
62
+ "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
63
+ "controlnet": "control_model.time_embed.0.weight",
64
+ "playground-v2-5": "edm_mean",
65
+ "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
66
+ "clip": "cond_stage_model.transformer.text_model.embeddings.position_ids",
67
+ "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
68
+ "open_clip": "cond_stage_model.model.token_embedding.weight",
69
+ "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
70
+ "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
71
+ "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
72
+ "stable_cascade_stage_c": "clip_txt_mapper.weight",
68
73
  }
69
74
 
70
- SCHEDULER_DEFAULT_CONFIG = {
71
- "beta_schedule": "scaled_linear",
72
- "beta_start": 0.00085,
73
- "beta_end": 0.012,
74
- "interpolation_type": "linear",
75
- "num_train_timesteps": 1000,
76
- "prediction_type": "epsilon",
77
- "sample_max_value": 1.0,
78
- "set_alpha_to_one": False,
79
- "skip_prk_steps": True,
80
- "steps_offset": 1,
81
- "timestep_spacing": "leading",
75
+ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
76
+ "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
77
+ "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
78
+ "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
79
+ "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
80
+ "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
81
+ "inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"},
82
+ "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
83
+ "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
84
+ "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
85
+ "v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"},
86
+ "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
87
+ "stable_cascade_stage_b_lite": {
88
+ "pretrained_model_name_or_path": "stabilityai/stable-cascade",
89
+ "subfolder": "decoder_lite",
90
+ },
91
+ "stable_cascade_stage_c": {
92
+ "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
93
+ "subfolder": "prior",
94
+ },
95
+ "stable_cascade_stage_c_lite": {
96
+ "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
97
+ "subfolder": "prior_lite",
98
+ },
82
99
  }
83
100
 
84
-
85
- STABLE_CASCADE_DEFAULT_CONFIGS = {
86
- "stage_c": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior"},
87
- "stage_c_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "prior_lite"},
88
- "stage_b": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder"},
89
- "stage_b_lite": {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": "decoder_lite"},
101
+ # Use to configure model sample size when original config is provided
102
+ DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
103
+ "xl_base": 1024,
104
+ "xl_refiner": 1024,
105
+ "xl_inpaint": 1024,
106
+ "playground-v2-5": 1024,
107
+ "upscale": 512,
108
+ "inpainting": 512,
109
+ "inpainting_v2": 512,
110
+ "controlnet": 512,
111
+ "v2": 768,
112
+ "v1": 512,
90
113
  }
91
114
 
92
115
 
93
- def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
94
- is_stage_c = "clip_txt_mapper.weight" in original_state_dict
95
-
96
- if is_stage_c:
97
- state_dict = {}
98
- for key in original_state_dict.keys():
99
- if key.endswith("in_proj_weight"):
100
- weights = original_state_dict[key].chunk(3, 0)
101
- state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
102
- state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
103
- state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
104
- elif key.endswith("in_proj_bias"):
105
- weights = original_state_dict[key].chunk(3, 0)
106
- state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
107
- state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
108
- state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
109
- elif key.endswith("out_proj.weight"):
110
- weights = original_state_dict[key]
111
- state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
112
- elif key.endswith("out_proj.bias"):
113
- weights = original_state_dict[key]
114
- state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
115
- else:
116
- state_dict[key] = original_state_dict[key]
117
- else:
118
- state_dict = {}
119
- for key in original_state_dict.keys():
120
- if key.endswith("in_proj_weight"):
121
- weights = original_state_dict[key].chunk(3, 0)
122
- state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
123
- state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
124
- state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
125
- elif key.endswith("in_proj_bias"):
126
- weights = original_state_dict[key].chunk(3, 0)
127
- state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
128
- state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
129
- state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
130
- elif key.endswith("out_proj.weight"):
131
- weights = original_state_dict[key]
132
- state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
133
- elif key.endswith("out_proj.bias"):
134
- weights = original_state_dict[key]
135
- state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
136
- # rename clip_mapper to clip_txt_pooled_mapper
137
- elif key.endswith("clip_mapper.weight"):
138
- weights = original_state_dict[key]
139
- state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
140
- elif key.endswith("clip_mapper.bias"):
141
- weights = original_state_dict[key]
142
- state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
143
- else:
144
- state_dict[key] = original_state_dict[key]
145
-
146
- return state_dict
147
-
148
-
149
- def infer_stable_cascade_single_file_config(checkpoint):
150
- is_stage_c = "clip_txt_mapper.weight" in checkpoint
151
- is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
152
-
153
- if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
154
- config_type = "stage_c_lite"
155
- elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
156
- config_type = "stage_c"
157
- elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
158
- config_type = "stage_b_lite"
159
- elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
160
- config_type = "stage_b"
161
-
162
- return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
163
-
164
-
165
116
  DIFFUSERS_TO_LDM_MAPPING = {
166
117
  "unet": {
167
118
  "layers": {
@@ -255,14 +206,6 @@ DIFFUSERS_TO_LDM_MAPPING = {
255
206
  },
256
207
  }
257
208
 
258
- LDM_VAE_KEY = "first_stage_model."
259
- LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
260
- PLAYGROUND_VAE_SCALING_FACTOR = 0.5
261
- LDM_UNET_KEY = "model.diffusion_model."
262
- LDM_CONTROLNET_KEY = "control_model."
263
- LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
264
- LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
265
-
266
209
  SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
267
210
  "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
268
211
  "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
@@ -279,11 +222,51 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
279
222
  "cond_stage_model.model.text_projection",
280
223
  ]
281
224
 
225
+ # To support legacy scheduler_type argument
226
+ SCHEDULER_DEFAULT_CONFIG = {
227
+ "beta_schedule": "scaled_linear",
228
+ "beta_start": 0.00085,
229
+ "beta_end": 0.012,
230
+ "interpolation_type": "linear",
231
+ "num_train_timesteps": 1000,
232
+ "prediction_type": "epsilon",
233
+ "sample_max_value": 1.0,
234
+ "set_alpha_to_one": False,
235
+ "skip_prk_steps": True,
236
+ "steps_offset": 1,
237
+ "timestep_spacing": "leading",
238
+ }
239
+
240
+ LDM_VAE_KEY = "first_stage_model."
241
+ LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
242
+ PLAYGROUND_VAE_SCALING_FACTOR = 0.5
243
+ LDM_UNET_KEY = "model.diffusion_model."
244
+ LDM_CONTROLNET_KEY = "control_model."
245
+ LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
246
+ OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
247
+ LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
282
248
 
283
249
  VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
284
250
 
285
251
 
252
+ class SingleFileComponentError(Exception):
253
+ def __init__(self, message=None):
254
+ self.message = message
255
+ super().__init__(self.message)
256
+
257
+
258
+ def is_valid_url(url):
259
+ result = urlparse(url)
260
+ if result.scheme and result.netloc:
261
+ return True
262
+
263
+ return False
264
+
265
+
286
266
  def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
267
+ if not is_valid_url(pretrained_model_name_or_path):
268
+ raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
269
+
287
270
  pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
288
271
  weights_name = None
289
272
  repo_id = (None,)
@@ -291,6 +274,7 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
291
274
  pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
292
275
  match = re.match(pattern, pretrained_model_name_or_path)
293
276
  if not match:
277
+ logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
294
278
  return repo_id, weights_name
295
279
 
296
280
  repo_id = f"{match.group(1)}/{match.group(2)}"
@@ -299,34 +283,18 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
299
283
  return repo_id, weights_name
300
284
 
301
285
 
302
- def fetch_ldm_config_and_checkpoint(
303
- pretrained_model_link_or_path,
304
- class_name,
305
- original_config_file=None,
306
- resume_download=False,
307
- force_download=False,
308
- proxies=None,
309
- token=None,
310
- cache_dir=None,
311
- local_files_only=None,
312
- revision=None,
313
- ):
314
- checkpoint = load_single_file_model_checkpoint(
315
- pretrained_model_link_or_path,
316
- resume_download=resume_download,
317
- force_download=force_download,
318
- proxies=proxies,
319
- token=token,
320
- cache_dir=cache_dir,
321
- local_files_only=local_files_only,
322
- revision=revision,
323
- )
324
- original_config = fetch_original_config(class_name, checkpoint, original_config_file)
286
+ def _is_model_weights_in_cached_folder(cached_folder, name):
287
+ pretrained_model_name_or_path = os.path.join(cached_folder, name)
288
+ weights_exist = False
325
289
 
326
- return original_config, checkpoint
290
+ for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
291
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
292
+ weights_exist = True
327
293
 
294
+ return weights_exist
328
295
 
329
- def load_single_file_model_checkpoint(
296
+
297
+ def load_single_file_checkpoint(
330
298
  pretrained_model_link_or_path,
331
299
  resume_download=False,
332
300
  force_download=False,
@@ -337,10 +305,11 @@ def load_single_file_model_checkpoint(
337
305
  revision=None,
338
306
  ):
339
307
  if os.path.isfile(pretrained_model_link_or_path):
340
- checkpoint = load_state_dict(pretrained_model_link_or_path)
308
+ pretrained_model_link_or_path = pretrained_model_link_or_path
309
+
341
310
  else:
342
311
  repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
343
- checkpoint_path = _get_model_file(
312
+ pretrained_model_link_or_path = _get_model_file(
344
313
  repo_id,
345
314
  weights_name=weights_name,
346
315
  force_download=force_download,
@@ -351,7 +320,8 @@ def load_single_file_model_checkpoint(
351
320
  token=token,
352
321
  revision=revision,
353
322
  )
354
- checkpoint = load_state_dict(checkpoint_path)
323
+
324
+ checkpoint = load_state_dict(pretrained_model_link_or_path)
355
325
 
356
326
  # some checkpoints contain the model state dict under a "state_dict" key
357
327
  while "state_dict" in checkpoint:
@@ -360,120 +330,154 @@ def load_single_file_model_checkpoint(
360
330
  return checkpoint
361
331
 
362
332
 
363
- def infer_original_config_file(class_name, checkpoint):
364
- if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
365
- config_url = CONFIG_URLS["v2"]
333
+ def fetch_original_config(original_config_file, local_files_only=False):
334
+ if os.path.isfile(original_config_file):
335
+ with open(original_config_file, "r") as fp:
336
+ original_config_file = fp.read()
366
337
 
367
- elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
368
- config_url = CONFIG_URLS["xl"]
338
+ elif is_valid_url(original_config_file):
339
+ if local_files_only:
340
+ raise ValueError(
341
+ "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
342
+ "Please provide a valid local file path."
343
+ )
369
344
 
370
- elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
371
- config_url = CONFIG_URLS["xl_refiner"]
345
+ original_config_file = BytesIO(requests.get(original_config_file).content)
372
346
 
373
- elif class_name == "StableDiffusionUpscalePipeline":
374
- config_url = CONFIG_URLS["upscale"]
347
+ else:
348
+ raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
375
349
 
376
- elif class_name == "ControlNetModel":
377
- config_url = CONFIG_URLS["controlnet"]
350
+ original_config = yaml.safe_load(original_config_file)
378
351
 
379
- else:
380
- config_url = CONFIG_URLS["v1"]
352
+ return original_config
381
353
 
382
- original_config_file = BytesIO(requests.get(config_url).content)
383
354
 
384
- return original_config_file
355
+ def is_clip_model(checkpoint):
356
+ if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
357
+ return True
385
358
 
359
+ return False
386
360
 
387
- def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None):
388
- def is_valid_url(url):
389
- result = urlparse(url)
390
- if result.scheme and result.netloc:
391
- return True
392
361
 
393
- return False
362
+ def is_clip_sdxl_model(checkpoint):
363
+ if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
364
+ return True
394
365
 
395
- if original_config_file is None:
396
- original_config_file = infer_original_config_file(pipeline_class_name, checkpoint)
366
+ return False
397
367
 
398
- elif os.path.isfile(original_config_file):
399
- with open(original_config_file, "r") as fp:
400
- original_config_file = fp.read()
401
368
 
402
- elif is_valid_url(original_config_file):
403
- original_config_file = BytesIO(requests.get(original_config_file).content)
369
+ def is_open_clip_model(checkpoint):
370
+ if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
371
+ return True
404
372
 
405
- else:
406
- raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
373
+ return False
407
374
 
408
- original_config = yaml.safe_load(original_config_file)
409
375
 
410
- return original_config
376
+ def is_open_clip_sdxl_model(checkpoint):
377
+ if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
378
+ return True
411
379
 
380
+ return False
412
381
 
413
- def infer_model_type(original_config, checkpoint, model_type=None):
414
- if model_type is not None:
415
- return model_type
416
382
 
417
- has_cond_stage_config = (
418
- "cond_stage_config" in original_config["model"]["params"]
419
- and original_config["model"]["params"]["cond_stage_config"] is not None
420
- )
421
- has_network_config = (
422
- "network_config" in original_config["model"]["params"]
423
- and original_config["model"]["params"]["network_config"] is not None
383
+ def is_open_clip_sdxl_refiner_model(checkpoint):
384
+ if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
385
+ return True
386
+
387
+ return False
388
+
389
+
390
+ def is_clip_model_in_single_file(class_obj, checkpoint):
391
+ is_clip_in_checkpoint = any(
392
+ [
393
+ is_clip_model(checkpoint),
394
+ is_open_clip_model(checkpoint),
395
+ is_open_clip_sdxl_model(checkpoint),
396
+ is_open_clip_sdxl_refiner_model(checkpoint),
397
+ ]
424
398
  )
399
+ if (
400
+ class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
401
+ ) and is_clip_in_checkpoint:
402
+ return True
403
+
404
+ return False
425
405
 
426
- if has_cond_stage_config:
427
- model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
428
406
 
429
- elif has_network_config:
430
- context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
431
- if "edm_mean" in checkpoint and "edm_std" in checkpoint:
432
- model_type = "Playground"
433
- elif context_dim == 2048:
434
- model_type = "SDXL"
407
+ def infer_diffusers_model_type(checkpoint):
408
+ if (
409
+ CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
410
+ and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
411
+ ):
412
+ if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
413
+ model_type = "inpainting_v2"
435
414
  else:
436
- model_type = "SDXL-Refiner"
437
- else:
438
- raise ValueError("Unable to infer model type from config")
415
+ model_type = "inpainting"
439
416
 
440
- logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}")
417
+ elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
418
+ model_type = "v2"
441
419
 
442
- return model_type
420
+ elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
421
+ model_type = "playground-v2-5"
443
422
 
423
+ elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
424
+ model_type = "xl_base"
444
425
 
445
- def get_default_scheduler_config():
446
- return SCHEDULER_DEFAULT_CONFIG
426
+ elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
427
+ model_type = "xl_refiner"
447
428
 
429
+ elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
430
+ model_type = "upscale"
448
431
 
449
- def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None):
450
- if image_size:
451
- return image_size
432
+ elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
433
+ model_type = "controlnet"
452
434
 
453
- global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
454
- model_type = infer_model_type(original_config, checkpoint, model_type)
435
+ elif (
436
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
437
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
438
+ ):
439
+ model_type = "stable_cascade_stage_c_lite"
455
440
 
456
- if pipeline_class_name == "StableDiffusionUpscalePipeline":
457
- image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
458
- return image_size
441
+ elif (
442
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
443
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
444
+ ):
445
+ model_type = "stable_cascade_stage_c"
459
446
 
460
- elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
461
- image_size = 1024
462
- return image_size
447
+ elif (
448
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
449
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
450
+ ):
451
+ model_type = "stable_cascade_stage_b_lite"
463
452
 
464
453
  elif (
465
- "parameterization" in original_config["model"]["params"]
466
- and original_config["model"]["params"]["parameterization"] == "v"
454
+ CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
455
+ and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
467
456
  ):
468
- # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
469
- # as it relies on a brittle global step parameter here
470
- image_size = 512 if global_step == 875000 else 768
471
- return image_size
457
+ model_type = "stable_cascade_stage_b"
472
458
 
473
459
  else:
474
- image_size = 512
460
+ model_type = "v1"
461
+
462
+ return model_type
463
+
464
+
465
+ def fetch_diffusers_config(checkpoint):
466
+ model_type = infer_diffusers_model_type(checkpoint)
467
+ model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
468
+
469
+ return model_path
470
+
471
+
472
+ def set_image_size(checkpoint, image_size=None):
473
+ if image_size:
475
474
  return image_size
476
475
 
476
+ model_type = infer_diffusers_model_type(checkpoint)
477
+ image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]
478
+
479
+ return image_size
480
+
477
481
 
478
482
  # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
479
483
  def conv_attn_to_linear(checkpoint):
@@ -488,10 +492,21 @@ def conv_attn_to_linear(checkpoint):
488
492
  checkpoint[key] = checkpoint[key][:, :, 0]
489
493
 
490
494
 
491
- def create_unet_diffusers_config(original_config, image_size: int):
495
+ def create_unet_diffusers_config_from_ldm(
496
+ original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
497
+ ):
492
498
  """
493
499
  Creates a config for the diffusers based on the config of the LDM model.
494
500
  """
501
+ if image_size is not None:
502
+ deprecation_message = (
503
+ "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
504
+ "is deprecated and will be ignored in future versions."
505
+ )
506
+ deprecate("image_size", "1.0.0", deprecation_message)
507
+
508
+ image_size = set_image_size(checkpoint, image_size=image_size)
509
+
495
510
  if (
496
511
  "unet_config" in original_config["model"]["params"]
497
512
  and original_config["model"]["params"]["unet_config"] is not None
@@ -500,6 +515,16 @@ def create_unet_diffusers_config(original_config, image_size: int):
500
515
  else:
501
516
  unet_params = original_config["model"]["params"]["network_config"]["params"]
502
517
 
518
+ if num_in_channels is not None:
519
+ deprecation_message = (
520
+ "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`"
521
+ "is deprecated and will be ignored in future versions."
522
+ )
523
+ deprecate("image_size", "1.0.0", deprecation_message)
524
+ in_channels = num_in_channels
525
+ else:
526
+ in_channels = unet_params["in_channels"]
527
+
503
528
  vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
504
529
  block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
505
530
 
@@ -564,7 +589,7 @@ def create_unet_diffusers_config(original_config, image_size: int):
564
589
 
565
590
  config = {
566
591
  "sample_size": image_size // vae_scale_factor,
567
- "in_channels": unet_params["in_channels"],
592
+ "in_channels": in_channels,
568
593
  "down_block_types": down_block_types,
569
594
  "block_out_channels": block_out_channels,
570
595
  "layers_per_block": unet_params["num_res_blocks"],
@@ -578,6 +603,14 @@ def create_unet_diffusers_config(original_config, image_size: int):
578
603
  "transformer_layers_per_block": transformer_layers_per_block,
579
604
  }
580
605
 
606
+ if upcast_attention is not None:
607
+ deprecation_message = (
608
+ "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
609
+ "is deprecated and will be ignored in future versions."
610
+ )
611
+ deprecate("image_size", "1.0.0", deprecation_message)
612
+ config["upcast_attention"] = upcast_attention
613
+
581
614
  if "disable_self_attentions" in unet_params:
582
615
  config["only_cross_attention"] = unet_params["disable_self_attentions"]
583
616
 
@@ -590,9 +623,18 @@ def create_unet_diffusers_config(original_config, image_size: int):
590
623
  return config
591
624
 
592
625
 
593
- def create_controlnet_diffusers_config(original_config, image_size: int):
626
+ def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
627
+ if image_size is not None:
628
+ deprecation_message = (
629
+ "Configuring ControlNetModel with the `image_size` argument"
630
+ "is deprecated and will be ignored in future versions."
631
+ )
632
+ deprecate("image_size", "1.0.0", deprecation_message)
633
+
634
+ image_size = set_image_size(checkpoint, image_size=image_size)
635
+
594
636
  unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
595
- diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
637
+ diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size)
596
638
 
597
639
  controlnet_config = {
598
640
  "conditioning_channels": unet_params["hint_channels"],
@@ -613,15 +655,33 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
613
655
  return controlnet_config
614
656
 
615
657
 
616
- def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
658
+ def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
617
659
  """
618
660
  Creates a config for the diffusers based on the config of the LDM model.
619
661
  """
662
+ if image_size is not None:
663
+ deprecation_message = (
664
+ "Configuring AutoencoderKL with the `image_size` argument"
665
+ "is deprecated and will be ignored in future versions."
666
+ )
667
+ deprecate("image_size", "1.0.0", deprecation_message)
668
+
669
+ image_size = set_image_size(checkpoint, image_size=image_size)
670
+
671
+ if "edm_mean" in checkpoint and "edm_std" in checkpoint:
672
+ latents_mean = checkpoint["edm_mean"]
673
+ latents_std = checkpoint["edm_std"]
674
+ else:
675
+ latents_mean = None
676
+ latents_std = None
677
+
620
678
  vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
621
679
  if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
622
680
  scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
681
+
623
682
  elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
624
683
  scaling_factor = original_config["model"]["params"]["scale_factor"]
684
+
625
685
  elif scaling_factor is None:
626
686
  scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
627
687
 
@@ -658,16 +718,104 @@ def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, ma
658
718
  )
659
719
  if mapping:
660
720
  diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
661
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
721
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
662
722
 
663
723
 
664
724
  def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
665
725
  for ldm_key in ldm_keys:
666
726
  diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
667
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
727
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
728
+
668
729
 
730
+ def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
731
+ for ldm_key in keys:
732
+ diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
733
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
734
+
735
+
736
+ def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
737
+ for ldm_key in keys:
738
+ diffusers_key = (
739
+ ldm_key.replace(mapping["old"], mapping["new"])
740
+ .replace("norm.weight", "group_norm.weight")
741
+ .replace("norm.bias", "group_norm.bias")
742
+ .replace("q.weight", "to_q.weight")
743
+ .replace("q.bias", "to_q.bias")
744
+ .replace("k.weight", "to_k.weight")
745
+ .replace("k.bias", "to_k.bias")
746
+ .replace("v.weight", "to_v.weight")
747
+ .replace("v.bias", "to_v.bias")
748
+ .replace("proj_out.weight", "to_out.0.weight")
749
+ .replace("proj_out.bias", "to_out.0.bias")
750
+ )
751
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
752
+
753
+ # proj_attn.weight has to be converted from conv 1D to linear
754
+ shape = new_checkpoint[diffusers_key].shape
669
755
 
670
- def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
756
+ if len(shape) == 3:
757
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
758
+ elif len(shape) == 4:
759
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
760
+
761
+
762
+ def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
763
+ is_stage_c = "clip_txt_mapper.weight" in checkpoint
764
+
765
+ if is_stage_c:
766
+ state_dict = {}
767
+ for key in checkpoint.keys():
768
+ if key.endswith("in_proj_weight"):
769
+ weights = checkpoint[key].chunk(3, 0)
770
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
771
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
772
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
773
+ elif key.endswith("in_proj_bias"):
774
+ weights = checkpoint[key].chunk(3, 0)
775
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
776
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
777
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
778
+ elif key.endswith("out_proj.weight"):
779
+ weights = checkpoint[key]
780
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
781
+ elif key.endswith("out_proj.bias"):
782
+ weights = checkpoint[key]
783
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
784
+ else:
785
+ state_dict[key] = checkpoint[key]
786
+ else:
787
+ state_dict = {}
788
+ for key in checkpoint.keys():
789
+ if key.endswith("in_proj_weight"):
790
+ weights = checkpoint[key].chunk(3, 0)
791
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
792
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
793
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
794
+ elif key.endswith("in_proj_bias"):
795
+ weights = checkpoint[key].chunk(3, 0)
796
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
797
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
798
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
799
+ elif key.endswith("out_proj.weight"):
800
+ weights = checkpoint[key]
801
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
802
+ elif key.endswith("out_proj.bias"):
803
+ weights = checkpoint[key]
804
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
805
+ # rename clip_mapper to clip_txt_pooled_mapper
806
+ elif key.endswith("clip_mapper.weight"):
807
+ weights = checkpoint[key]
808
+ state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
809
+ elif key.endswith("clip_mapper.bias"):
810
+ weights = checkpoint[key]
811
+ state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
812
+ else:
813
+ state_dict[key] = checkpoint[key]
814
+
815
+ return state_dict
816
+
817
+
818
+ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
671
819
  """
672
820
  Takes a state dict and a config, and returns a converted checkpoint.
673
821
  """
@@ -686,7 +834,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
686
834
  for key in keys:
687
835
  if key.startswith("model.diffusion_model"):
688
836
  flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
689
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
837
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
690
838
  else:
691
839
  if sum(k.startswith("model_ema") for k in keys) > 100:
692
840
  logger.warning(
@@ -695,7 +843,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
695
843
  )
696
844
  for key in keys:
697
845
  if key.startswith(unet_key):
698
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
846
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)
699
847
 
700
848
  new_checkpoint = {}
701
849
  ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
@@ -756,10 +904,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
756
904
  )
757
905
 
758
906
  if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
759
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
907
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
760
908
  f"input_blocks.{i}.0.op.weight"
761
909
  )
762
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
910
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
763
911
  f"input_blocks.{i}.0.op.bias"
764
912
  )
765
913
 
@@ -773,19 +921,22 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
773
921
  )
774
922
 
775
923
  # Mid blocks
776
- resnet_0 = middle_blocks[0]
777
- attentions = middle_blocks[1]
778
- resnet_1 = middle_blocks[2]
779
-
780
- update_unet_resnet_ldm_to_diffusers(
781
- resnet_0, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"}
782
- )
783
- update_unet_resnet_ldm_to_diffusers(
784
- resnet_1, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"}
785
- )
786
- update_unet_attention_ldm_to_diffusers(
787
- attentions, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"}
788
- )
924
+ for key in middle_blocks.keys():
925
+ diffusers_key = max(key - 1, 0)
926
+ if key % 2 == 0:
927
+ update_unet_resnet_ldm_to_diffusers(
928
+ middle_blocks[key],
929
+ new_checkpoint,
930
+ unet_state_dict,
931
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
932
+ )
933
+ else:
934
+ update_unet_attention_ldm_to_diffusers(
935
+ middle_blocks[key],
936
+ new_checkpoint,
937
+ unet_state_dict,
938
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
939
+ )
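The loop above replaces the previous hard-coded handling of exactly three middle-block entries. For the usual three-entry middle block the index arithmetic reduces to the same mapping as before, restated here as comments only:

    # key 0 (even) -> resnet,    mid_block.resnets.0     (max(0 - 1, 0) == 0)
    # key 1 (odd)  -> attention, mid_block.attentions.0  (1 - 1 == 0)
    # key 2 (even) -> resnet,    mid_block.resnets.1     (2 - 1 == 1)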
789
940
 
790
941
  # Up Blocks
791
942
  for i in range(num_output_blocks):
@@ -834,6 +985,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
834
985
  def convert_controlnet_checkpoint(
835
986
  checkpoint,
836
987
  config,
988
+ **kwargs,
837
989
  ):
838
990
  # Some controlnet ckpt files are distributed independently from the rest of the
839
991
  # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
@@ -846,7 +998,7 @@ def convert_controlnet_checkpoint(
846
998
  controlnet_key = LDM_CONTROLNET_KEY
847
999
  for key in keys:
848
1000
  if key.startswith(controlnet_key):
849
- controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key)
1001
+ controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)
850
1002
 
851
1003
  new_checkpoint = {}
852
1004
  ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
@@ -880,10 +1032,10 @@ def convert_controlnet_checkpoint(
880
1032
  )
881
1033
 
882
1034
  if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
883
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop(
1035
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
884
1036
  f"input_blocks.{i}.0.op.weight"
885
1037
  )
886
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop(
1038
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
887
1039
  f"input_blocks.{i}.0.op.bias"
888
1040
  )
889
1041
 
@@ -898,8 +1050,8 @@ def convert_controlnet_checkpoint(
898
1050
 
899
1051
  # controlnet down blocks
900
1052
  for i in range(num_input_blocks):
901
- new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight")
902
- new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias")
1053
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
1054
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")
903
1055
 
904
1056
  # Retrieves the keys for the middle blocks only
905
1057
  num_middle_blocks = len(
@@ -909,33 +1061,28 @@ def convert_controlnet_checkpoint(
909
1061
  layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
910
1062
  for layer_id in range(num_middle_blocks)
911
1063
  }
912
- if middle_blocks:
913
- resnet_0 = middle_blocks[0]
914
- attentions = middle_blocks[1]
915
- resnet_1 = middle_blocks[2]
916
1064
 
917
- update_unet_resnet_ldm_to_diffusers(
918
- resnet_0,
919
- new_checkpoint,
920
- controlnet_state_dict,
921
- mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"},
922
- )
923
- update_unet_resnet_ldm_to_diffusers(
924
- resnet_1,
925
- new_checkpoint,
926
- controlnet_state_dict,
927
- mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"},
928
- )
929
- update_unet_attention_ldm_to_diffusers(
930
- attentions,
931
- new_checkpoint,
932
- controlnet_state_dict,
933
- mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"},
934
- )
1065
+ # Mid blocks
1066
+ for key in middle_blocks.keys():
1067
+ diffusers_key = max(key - 1, 0)
1068
+ if key % 2 == 0:
1069
+ update_unet_resnet_ldm_to_diffusers(
1070
+ middle_blocks[key],
1071
+ new_checkpoint,
1072
+ controlnet_state_dict,
1073
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
1074
+ )
1075
+ else:
1076
+ update_unet_attention_ldm_to_diffusers(
1077
+ middle_blocks[key],
1078
+ new_checkpoint,
1079
+ controlnet_state_dict,
1080
+ mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
1081
+ )
935
1082
 
936
1083
  # mid block
937
- new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight")
938
- new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias")
1084
+ new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
1085
+ new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")
939
1086
 
940
1087
  # controlnet cond embedding blocks
941
1088
  cond_embedding_blocks = {
@@ -949,88 +1096,16 @@ def convert_controlnet_checkpoint(
949
1096
  diffusers_idx = idx - 1
950
1097
  cond_block_id = 2 * idx
951
1098
 
952
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop(
1099
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
953
1100
  f"input_hint_block.{cond_block_id}.weight"
954
1101
  )
955
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop(
1102
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
956
1103
  f"input_hint_block.{cond_block_id}.bias"
957
1104
  )
958
1105
 
959
1106
  return new_checkpoint
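Beyond the same `pop` to `get` switch and mid-block loop as in the UNet converter, the ControlNet converter now also accepts `**kwargs`, which appears intended to let it be called through a generic conversion dispatch without failing on extra arguments. The ControlNet-specific renames performed above can be summarized as:

    # zero_convs.{i}.0.weight    -> controlnet_down_blocks.{i}.weight
    # zero_convs.{i}.0.bias      -> controlnet_down_blocks.{i}.bias
    # middle_block_out.0.weight  -> controlnet_mid_block.weight
    # middle_block_out.0.bias    -> controlnet_mid_block.bias
    # input_hint_block.{2*idx}.* -> controlnet_cond_embedding.blocks.{idx - 1}.*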
960
1107
 
961
1108
 
962
- def create_diffusers_controlnet_model_from_ldm(
963
- pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, torch_dtype=None
964
- ):
965
- # import here to avoid circular imports
966
- from ..models import ControlNetModel
967
-
968
- image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
969
-
970
- diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
971
- diffusers_config["upcast_attention"] = upcast_attention
972
-
973
- diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config)
974
-
975
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
976
- with ctx():
977
- controlnet = ControlNetModel(**diffusers_config)
978
-
979
- if is_accelerate_available():
980
- from ..models.modeling_utils import load_model_dict_into_meta
981
-
982
- unexpected_keys = load_model_dict_into_meta(
983
- controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
984
- )
985
- if controlnet._keys_to_ignore_on_load_unexpected is not None:
986
- for pat in controlnet._keys_to_ignore_on_load_unexpected:
987
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
988
-
989
- if len(unexpected_keys) > 0:
990
- logger.warning(
991
- f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
992
- )
993
- else:
994
- controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
995
-
996
- if torch_dtype is not None:
997
- controlnet = controlnet.to(torch_dtype)
998
-
999
- return {"controlnet": controlnet}
1000
-
1001
-
1002
- def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
1003
- for ldm_key in keys:
1004
- diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
1005
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
1006
-
1007
-
1008
- def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
1009
- for ldm_key in keys:
1010
- diffusers_key = (
1011
- ldm_key.replace(mapping["old"], mapping["new"])
1012
- .replace("norm.weight", "group_norm.weight")
1013
- .replace("norm.bias", "group_norm.bias")
1014
- .replace("q.weight", "to_q.weight")
1015
- .replace("q.bias", "to_q.bias")
1016
- .replace("k.weight", "to_k.weight")
1017
- .replace("k.bias", "to_k.bias")
1018
- .replace("v.weight", "to_v.weight")
1019
- .replace("v.bias", "to_v.bias")
1020
- .replace("proj_out.weight", "to_out.0.weight")
1021
- .replace("proj_out.bias", "to_out.0.bias")
1022
- )
1023
- new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
1024
-
1025
- # proj_attn.weight has to be converted from conv 1D to linear
1026
- shape = new_checkpoint[diffusers_key].shape
1027
-
1028
- if len(shape) == 3:
1029
- new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
1030
- elif len(shape) == 4:
1031
- new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
1032
-
1033
-
1034
1109
  def convert_ldm_vae_checkpoint(checkpoint, config):
1035
1110
  # extract state dict for VAE
1036
1111
  # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
@@ -1063,10 +1138,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
1063
1138
  mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
1064
1139
  )
1065
1140
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
1066
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
1141
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
1067
1142
  f"encoder.down.{i}.downsample.conv.weight"
1068
1143
  )
1069
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
1144
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
1070
1145
  f"encoder.down.{i}.downsample.conv.bias"
1071
1146
  )
1072
1147
 
@@ -1131,18 +1206,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
1131
1206
  return new_checkpoint
1132
1207
 
1133
1208
 
1134
- def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, torch_dtype=None):
1135
- try:
1136
- config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
1137
- except Exception:
1138
- raise ValueError(
1139
- f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
1140
- )
1141
-
1142
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
1143
- with ctx():
1144
- text_model = CLIPTextModel(config)
1145
-
1209
+ def convert_ldm_clip_checkpoint(checkpoint):
1146
1210
  keys = list(checkpoint.keys())
1147
1211
  text_model_dict = {}
1148
1212
 
@@ -1152,57 +1216,26 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
1152
1216
  for prefix in remove_prefixes:
1153
1217
  if key.startswith(prefix):
1154
1218
  diffusers_key = key.replace(prefix, "")
1155
- text_model_dict[diffusers_key] = checkpoint[key]
1156
-
1157
- if is_accelerate_available():
1158
- from ..models.modeling_utils import load_model_dict_into_meta
1159
-
1160
- unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
1161
- if text_model._keys_to_ignore_on_load_unexpected is not None:
1162
- for pat in text_model._keys_to_ignore_on_load_unexpected:
1163
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1219
+ text_model_dict[diffusers_key] = checkpoint.get(key)
1164
1220
 
1165
- if len(unexpected_keys) > 0:
1166
- logger.warning(
1167
- f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
1168
- )
1169
- else:
1170
- if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
1171
- text_model_dict.pop("text_model.embeddings.position_ids", None)
1221
+ return text_model_dict
1172
1222
 
1173
- text_model.load_state_dict(text_model_dict)
1174
1223
 
1175
- if torch_dtype is not None:
1176
- text_model = text_model.to(torch_dtype)
1177
-
1178
- return text_model
1179
-
1180
-
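`convert_ldm_clip_checkpoint` now only strips the legacy prefixes and returns a plain state dict; instantiating the text encoder and loading the weights moves into `create_diffusers_clip_model_from_ldm` further down. A hedged sketch of the resulting two-step flow (the class argument and keyword values are illustrative assumptions, not taken from this diff):

    from transformers import CLIPTextModel

    text_model_dict = convert_ldm_clip_checkpoint(checkpoint)      # key conversion only
    text_encoder = create_diffusers_clip_model_from_ldm(           # config lookup + weight loading
        CLIPTextModel, checkpoint, subfolder="text_encoder"
    )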
1181
- def create_text_encoder_from_open_clip_checkpoint(
1182
- config_name,
1224
+ def convert_open_clip_checkpoint(
1225
+ text_model,
1183
1226
  checkpoint,
1184
1227
  prefix="cond_stage_model.model.",
1185
- has_projection=False,
1186
- local_files_only=False,
1187
- torch_dtype=None,
1188
- **config_kwargs,
1189
1228
  ):
1190
- try:
1191
- config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
1192
- except Exception:
1193
- raise ValueError(
1194
- f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
1195
- )
1196
-
1197
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
1198
- with ctx():
1199
- text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
1200
-
1201
1229
  text_model_dict = {}
1202
1230
  text_proj_key = prefix + "text_projection"
1203
- text_proj_dim = (
1204
- int(checkpoint[text_proj_key].shape[0]) if text_proj_key in checkpoint else LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
1205
- )
1231
+
1232
+ if text_proj_key in checkpoint:
1233
+ text_proj_dim = int(checkpoint[text_proj_key].shape[0])
1234
+ elif hasattr(text_model.config, "projection_dim"):
1235
+ text_proj_dim = text_model.config.projection_dim
1236
+ else:
1237
+ text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
1238
+
1206
1239
  text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
1207
1240
 
1208
1241
  keys = list(checkpoint.keys())
@@ -1235,309 +1268,165 @@ def create_text_encoder_from_open_clip_checkpoint(
1235
1268
  )
1236
1269
 
1237
1270
  if key.endswith(".in_proj_weight"):
1238
- weight_value = checkpoint[key]
1271
+ weight_value = checkpoint.get(key)
1239
1272
 
1240
- text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :]
1241
- text_model_dict[diffusers_key + ".k_proj.weight"] = weight_value[text_proj_dim : text_proj_dim * 2, :]
1242
- text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :]
1273
+ text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
1274
+ text_model_dict[diffusers_key + ".k_proj.weight"] = (
1275
+ weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
1276
+ )
1277
+ text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()
1243
1278
 
1244
1279
  elif key.endswith(".in_proj_bias"):
1245
- weight_value = checkpoint[key]
1246
- text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
1247
- text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
1248
- text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
1249
- else:
1250
- text_model_dict[diffusers_key] = checkpoint[key]
1251
-
1252
- if is_accelerate_available():
1253
- from ..models.modeling_utils import load_model_dict_into_meta
1254
-
1255
- unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
1256
- if text_model._keys_to_ignore_on_load_unexpected is not None:
1257
- for pat in text_model._keys_to_ignore_on_load_unexpected:
1258
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1259
-
1260
- if len(unexpected_keys) > 0:
1261
- logger.warning(
1262
- f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
1280
+ weight_value = checkpoint.get(key)
1281
+ text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
1282
+ text_model_dict[diffusers_key + ".k_proj.bias"] = (
1283
+ weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
1263
1284
  )
1285
+ text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
1286
+ else:
1287
+ text_model_dict[diffusers_key] = checkpoint.get(key)
1264
1288
 
1265
- else:
1266
- if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
1267
- text_model_dict.pop("text_model.embeddings.position_ids", None)
1268
-
1269
- text_model.load_state_dict(text_model_dict)
1270
-
1271
- if torch_dtype is not None:
1272
- text_model = text_model.to(torch_dtype)
1289
+ if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
1290
+ text_model_dict.pop("text_model.embeddings.position_ids", None)
1273
1291
 
1274
- return text_model
1292
+ return text_model_dict
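OpenCLIP text encoders store each attention layer's query/key/value projection as a single fused `in_proj_weight`/`in_proj_bias`; the converter above slices them into separate `q_proj`/`k_proj`/`v_proj` tensors, and the new `.clone().detach()` calls make each slice an independent tensor rather than a view into the fused checkpoint tensor (presumably to avoid keeping the larger shared storage alive). A minimal PyTorch illustration with a toy hidden size:

    import torch

    hidden = 4
    in_proj_weight = torch.randn(3 * hidden, hidden)               # fused q|k|v rows

    q = in_proj_weight[:hidden, :].clone().detach()
    k = in_proj_weight[hidden : 2 * hidden, :].clone().detach()
    v = in_proj_weight[2 * hidden :, :].clone().detach()
    assert q.shape == k.shape == v.shape == (hidden, hidden)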
1275
1293
 
1276
1294
 
1277
- def create_diffusers_unet_model_from_ldm(
1278
- pipeline_class_name,
1279
- original_config,
1295
+ def create_diffusers_clip_model_from_ldm(
1296
+ cls,
1280
1297
  checkpoint,
1281
- num_in_channels=None,
1282
- upcast_attention=None,
1283
- extract_ema=False,
1284
- image_size=None,
1298
+ subfolder="",
1299
+ config=None,
1285
1300
  torch_dtype=None,
1286
- model_type=None,
1301
+ local_files_only=None,
1302
+ is_legacy_loading=False,
1287
1303
  ):
1288
- from ..models import UNet2DConditionModel
1304
+ if config:
1305
+ config = {"pretrained_model_name_or_path": config}
1306
+ else:
1307
+ config = fetch_diffusers_config(checkpoint)
1289
1308
 
1290
- if num_in_channels is None:
1291
- if pipeline_class_name in [
1292
- "StableDiffusionInpaintPipeline",
1293
- "StableDiffusionControlNetInpaintPipeline",
1294
- "StableDiffusionXLInpaintPipeline",
1295
- "StableDiffusionXLControlNetInpaintPipeline",
1296
- ]:
1297
- num_in_channels = 9
1309
+ # For backwards compatibility
1310
+ # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo
1311
+ # in the cache_dir, rather than in a subfolder of the Diffusers model
1312
+ if is_legacy_loading:
1313
+ logger.warning(
1314
+ (
1315
+ "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
1316
+ "the local cache directory with the necessary CLIP model config files. "
1317
+ "Attempting to load CLIP model from legacy cache directory."
1318
+ )
1319
+ )
1298
1320
 
1299
- elif pipeline_class_name == "StableDiffusionUpscalePipeline":
1300
- num_in_channels = 7
1321
+ if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
1322
+ clip_config = "openai/clip-vit-large-patch14"
1323
+ config["pretrained_model_name_or_path"] = clip_config
1324
+ subfolder = ""
1301
1325
 
1302
- else:
1303
- num_in_channels = 4
1326
+ elif is_open_clip_model(checkpoint):
1327
+ clip_config = "stabilityai/stable-diffusion-2"
1328
+ config["pretrained_model_name_or_path"] = clip_config
1329
+ subfolder = "text_encoder"
1304
1330
 
1305
- image_size = set_image_size(
1306
- pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
1307
- )
1308
- unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
1309
- unet_config["in_channels"] = num_in_channels
1310
- if upcast_attention is not None:
1311
- unet_config["upcast_attention"] = upcast_attention
1331
+ else:
1332
+ clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1333
+ config["pretrained_model_name_or_path"] = clip_config
1334
+ subfolder = ""
1312
1335
 
1313
- diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
1336
+ model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
1314
1337
  ctx = init_empty_weights if is_accelerate_available() else nullcontext
1315
-
1316
1338
  with ctx():
1317
- unet = UNet2DConditionModel(**unet_config)
1339
+ model = cls(model_config)
1318
1340
 
1319
- if is_accelerate_available():
1320
- from ..models.modeling_utils import load_model_dict_into_meta
1341
+ position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]
1321
1342
 
1322
- unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint, dtype=torch_dtype)
1323
- if unet._keys_to_ignore_on_load_unexpected is not None:
1324
- for pat in unet._keys_to_ignore_on_load_unexpected:
1325
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1326
-
1327
- if len(unexpected_keys) > 0:
1328
- logger.warning(
1329
- f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
1330
- )
1331
- else:
1332
- unet.load_state_dict(diffusers_format_unet_checkpoint)
1333
-
1334
- if torch_dtype is not None:
1335
- unet = unet.to(torch_dtype)
1343
+ if is_clip_model(checkpoint):
1344
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
1336
1345
 
1337
- return {"unet": unet}
1346
+ elif (
1347
+ is_clip_sdxl_model(checkpoint)
1348
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
1349
+ ):
1350
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
1338
1351
 
1352
+ elif is_open_clip_model(checkpoint):
1353
+ prefix = "cond_stage_model.model."
1354
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1339
1355
 
1340
- def create_diffusers_vae_model_from_ldm(
1341
- pipeline_class_name,
1342
- original_config,
1343
- checkpoint,
1344
- image_size=None,
1345
- scaling_factor=None,
1346
- torch_dtype=None,
1347
- model_type=None,
1348
- ):
1349
- # import here to avoid circular imports
1350
- from ..models import AutoencoderKL
1356
+ elif (
1357
+ is_open_clip_sdxl_model(checkpoint)
1358
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
1359
+ ):
1360
+ prefix = "conditioner.embedders.1.model."
1361
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1351
1362
 
1352
- image_size = set_image_size(
1353
- pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
1354
- )
1355
- model_type = infer_model_type(original_config, checkpoint, model_type)
1363
+ elif is_open_clip_sdxl_refiner_model(checkpoint):
1364
+ prefix = "conditioner.embedders.0.model."
1365
+ diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1356
1366
 
1357
- if model_type == "Playground":
1358
- edm_mean = (
1359
- checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
1360
- )
1361
- edm_std = (
1362
- checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
1363
- )
1364
1367
  else:
1365
- edm_mean = None
1366
- edm_std = None
1367
-
1368
- vae_config = create_vae_diffusers_config(
1369
- original_config,
1370
- image_size=image_size,
1371
- scaling_factor=scaling_factor,
1372
- latents_mean=edm_mean,
1373
- latents_std=edm_std,
1374
- )
1375
- diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
1376
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
1377
-
1378
- with ctx():
1379
- vae = AutoencoderKL(**vae_config)
1368
+ raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
1380
1369
 
1381
1370
  if is_accelerate_available():
1382
- from ..models.modeling_utils import load_model_dict_into_meta
1383
-
1384
- unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint, dtype=torch_dtype)
1385
- if vae._keys_to_ignore_on_load_unexpected is not None:
1386
- for pat in vae._keys_to_ignore_on_load_unexpected:
1371
+ unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1372
+ if model._keys_to_ignore_on_load_unexpected is not None:
1373
+ for pat in model._keys_to_ignore_on_load_unexpected:
1387
1374
  unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1388
1375
 
1389
1376
  if len(unexpected_keys) > 0:
1390
1377
  logger.warning(
1391
- f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
1378
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1392
1379
  )
1380
+
1393
1381
  else:
1394
- vae.load_state_dict(diffusers_format_vae_checkpoint)
1382
+ model.load_state_dict(diffusers_format_checkpoint)
1395
1383
 
1396
1384
  if torch_dtype is not None:
1397
- vae = vae.to(torch_dtype)
1385
+ model.to(torch_dtype)
1398
1386
 
1399
- return {"vae": vae}
1387
+ model.eval()
1400
1388
 
1389
+ return model
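`create_diffusers_clip_model_from_ldm` now owns both config resolution and weight loading for every CLIP variant. When no explicit `config` is given it falls back to `fetch_diffusers_config(checkpoint)`, and, judging from the flow above, the per-family repo fallbacks (openai/clip-vit-large-patch14 for SD 1.x and SDXL CLIP-L, stabilityai/stable-diffusion-2 text_encoder for OpenCLIP, laion/CLIP-ViT-bigG-14-laion2B-39B-b160k otherwise) serve the backwards-compatible legacy loading path. A hedged usage sketch (class, dtype, and subfolder are illustrative assumptions):

    import torch
    from transformers import CLIPTextModelWithProjection

    text_encoder_2 = create_diffusers_clip_model_from_ldm(
        CLIPTextModelWithProjection,
        checkpoint,
        subfolder="text_encoder_2",
        torch_dtype=torch.float16,
    )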
1401
1390
 
1402
- def create_text_encoders_and_tokenizers_from_ldm(
1403
- original_config,
1391
+
1392
+ def _legacy_load_scheduler(
1393
+ cls,
1404
1394
  checkpoint,
1405
- model_type=None,
1406
- local_files_only=False,
1407
- torch_dtype=None,
1395
+ component_name,
1396
+ original_config=None,
1397
+ **kwargs,
1408
1398
  ):
1409
- model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
1399
+ scheduler_type = kwargs.get("scheduler_type", None)
1400
+ prediction_type = kwargs.get("prediction_type", None)
1410
1401
 
1411
- if model_type == "FrozenOpenCLIPEmbedder":
1412
- config_name = "stabilityai/stable-diffusion-2"
1413
- config_kwargs = {"subfolder": "text_encoder"}
1414
-
1415
- try:
1416
- text_encoder = create_text_encoder_from_open_clip_checkpoint(
1417
- config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype, **config_kwargs
1418
- )
1419
- tokenizer = CLIPTokenizer.from_pretrained(
1420
- config_name, subfolder="tokenizer", local_files_only=local_files_only
1421
- )
1422
- except Exception:
1423
- raise ValueError(
1424
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'."
1425
- )
1426
- else:
1427
- return {"text_encoder": text_encoder, "tokenizer": tokenizer}
1428
-
1429
- elif model_type == "FrozenCLIPEmbedder":
1430
- try:
1431
- config_name = "openai/clip-vit-large-patch14"
1432
- text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
1433
- config_name,
1434
- checkpoint,
1435
- local_files_only=local_files_only,
1436
- torch_dtype=torch_dtype,
1437
- )
1438
- tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
1439
-
1440
- except Exception:
1441
- raise ValueError(
1442
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
1443
- )
1444
- else:
1445
- return {"text_encoder": text_encoder, "tokenizer": tokenizer}
1446
-
1447
- elif model_type == "SDXL-Refiner":
1448
- config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1449
- config_kwargs = {"projection_dim": 1280}
1450
- prefix = "conditioner.embedders.0.model."
1451
-
1452
- try:
1453
- tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
1454
- text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
1455
- config_name,
1456
- checkpoint,
1457
- prefix=prefix,
1458
- has_projection=True,
1459
- local_files_only=local_files_only,
1460
- torch_dtype=torch_dtype,
1461
- **config_kwargs,
1462
- )
1463
- except Exception:
1464
- raise ValueError(
1465
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
1466
- )
1467
-
1468
- else:
1469
- return {
1470
- "text_encoder": None,
1471
- "tokenizer": None,
1472
- "tokenizer_2": tokenizer_2,
1473
- "text_encoder_2": text_encoder_2,
1474
- }
1475
-
1476
- elif model_type in ["SDXL", "Playground"]:
1477
- try:
1478
- config_name = "openai/clip-vit-large-patch14"
1479
- tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
1480
- text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
1481
- config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype
1482
- )
1483
-
1484
- except Exception:
1485
- raise ValueError(
1486
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'."
1487
- )
1488
-
1489
- try:
1490
- config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1491
- config_kwargs = {"projection_dim": 1280}
1492
- prefix = "conditioner.embedders.1.model."
1493
- tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
1494
- text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
1495
- config_name,
1496
- checkpoint,
1497
- prefix=prefix,
1498
- has_projection=True,
1499
- local_files_only=local_files_only,
1500
- torch_dtype=torch_dtype,
1501
- **config_kwargs,
1502
- )
1503
- except Exception:
1504
- raise ValueError(
1505
- f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
1506
- )
1507
-
1508
- return {
1509
- "tokenizer": tokenizer,
1510
- "text_encoder": text_encoder,
1511
- "tokenizer_2": tokenizer_2,
1512
- "text_encoder_2": text_encoder_2,
1513
- }
1514
-
1515
- return
1402
+ if scheduler_type is not None:
1403
+ deprecation_message = (
1404
+ "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
1405
+ )
1406
+ deprecate("scheduler_type", "1.0.0", deprecation_message)
1516
1407
 
1408
+ if prediction_type is not None:
1409
+ deprecation_message = (
1410
+ "Please configure an instance of a Scheduler with the appropriate `prediction_type` "
1411
+ "and pass the object directly to the `scheduler` argument in `from_single_file`."
1412
+ )
1413
+ deprecate("prediction_type", "1.0.0", deprecation_message)
1517
1414
 
1518
- def create_scheduler_from_ldm(
1519
- pipeline_class_name,
1520
- original_config,
1521
- checkpoint,
1522
- prediction_type=None,
1523
- scheduler_type="ddim",
1524
- model_type=None,
1525
- ):
1526
- scheduler_config = get_default_scheduler_config()
1527
- model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
1415
+ scheduler_config = SCHEDULER_DEFAULT_CONFIG
1416
+ model_type = infer_diffusers_model_type(checkpoint=checkpoint)
1528
1417
 
1529
1418
  global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
1530
1419
 
1531
- num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
1420
+ if original_config:
1421
+ num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", 1000)
1422
+ else:
1423
+ num_train_timesteps = 1000
1424
+
1532
1425
  scheduler_config["num_train_timesteps"] = num_train_timesteps
1533
1426
 
1534
- if (
1535
- "parameterization" in original_config["model"]["params"]
1536
- and original_config["model"]["params"]["parameterization"] == "v"
1537
- ):
1427
+ if model_type == "v2":
1538
1428
  if prediction_type is None:
1539
- # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
1540
- # as it relies on a brittle global step parameter here
1429
+ # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here
1541
1430
  prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
1542
1431
 
1543
1432
  else:
@@ -1545,20 +1434,44 @@ def create_scheduler_from_ldm(
1545
1434
 
1546
1435
  scheduler_config["prediction_type"] = prediction_type
1547
1436
 
1548
- if model_type in ["SDXL", "SDXL-Refiner"]:
1437
+ if model_type in ["xl_base", "xl_refiner"]:
1549
1438
  scheduler_type = "euler"
1550
- elif model_type == "Playground":
1439
+ elif model_type == "playground":
1551
1440
  scheduler_type = "edm_dpm_solver_multistep"
1552
1441
  else:
1553
- beta_start = original_config["model"]["params"].get("linear_start", 0.02)
1554
- beta_end = original_config["model"]["params"].get("linear_end", 0.085)
1442
+ if original_config:
1443
+ beta_start = original_config["model"]["params"].get("linear_start")
1444
+ beta_end = original_config["model"]["params"].get("linear_end")
1445
+
1446
+ else:
1447
+ beta_start = 0.02
1448
+ beta_end = 0.085
1449
+
1555
1450
  scheduler_config["beta_start"] = beta_start
1556
1451
  scheduler_config["beta_end"] = beta_end
1557
1452
  scheduler_config["beta_schedule"] = "scaled_linear"
1558
1453
  scheduler_config["clip_sample"] = False
1559
1454
  scheduler_config["set_alpha_to_one"] = False
1560
1455
 
1561
- if scheduler_type == "pndm":
1456
+ # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers
1457
+ if component_name == "low_res_scheduler":
1458
+ return cls.from_config(
1459
+ {
1460
+ "beta_end": 0.02,
1461
+ "beta_schedule": "scaled_linear",
1462
+ "beta_start": 0.0001,
1463
+ "clip_sample": True,
1464
+ "num_train_timesteps": 1000,
1465
+ "prediction_type": "epsilon",
1466
+ "trained_betas": None,
1467
+ "variance_type": "fixed_small",
1468
+ }
1469
+ )
1470
+
1471
+ if scheduler_type is None:
1472
+ return cls.from_config(scheduler_config)
1473
+
1474
+ elif scheduler_type == "pndm":
1562
1475
  scheduler_config["skip_prk_steps"] = True
1563
1476
  scheduler = PNDMScheduler.from_config(scheduler_config)
1564
1477
 
@@ -1603,15 +1516,46 @@ def create_scheduler_from_ldm(
1603
1516
  else:
1604
1517
  raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
1605
1518
 
1606
- if pipeline_class_name == "StableDiffusionUpscalePipeline":
1607
- scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")
1608
- low_res_scheduler = DDPMScheduler.from_pretrained(
1609
- "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
1610
- )
1519
+ return scheduler
1611
1520
 
1612
- return {
1613
- "scheduler": scheduler,
1614
- "low_res_scheduler": low_res_scheduler,
1615
- }
1616
1521
 
1617
- return {"scheduler": scheduler}
1522
+ def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
1523
+ if config:
1524
+ config = {"pretrained_model_name_or_path": config}
1525
+ else:
1526
+ config = fetch_diffusers_config(checkpoint)
1527
+
1528
+ if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
1529
+ clip_config = "openai/clip-vit-large-patch14"
1530
+ config["pretrained_model_name_or_path"] = clip_config
1531
+ subfolder = ""
1532
+
1533
+ elif is_open_clip_model(checkpoint):
1534
+ clip_config = "stabilityai/stable-diffusion-2"
1535
+ config["pretrained_model_name_or_path"] = clip_config
1536
+ subfolder = "tokenizer"
1537
+
1538
+ else:
1539
+ clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
1540
+ config["pretrained_model_name_or_path"] = clip_config
1541
+ subfolder = ""
1542
+
1543
+ tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
1544
+
1545
+ return tokenizer
1546
+
1547
+
1548
+ def _legacy_load_safety_checker(local_files_only, torch_dtype):
1549
+ # Support for loading safety checker components using the deprecated
1550
+ # `load_safety_checker` argument.
1551
+
1552
+ from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
1553
+
1554
+ feature_extractor = AutoImageProcessor.from_pretrained(
1555
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
1556
+ )
1557
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
1558
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
1559
+ )
1560
+
1561
+ return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}