diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
61
61
  return math.exp(t * -12.0)
62
62
 
63
63
  else:
64
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
64
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
65
65
 
66
66
  betas = []
67
67
  for i in range(num_diffusion_timesteps):
@@ -71,6 +71,43 @@ def betas_for_alpha_bar(
71
71
  return torch.tensor(betas, dtype=torch.float32)
72
72
 
73
73
 
74
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
75
+ def rescale_zero_terminal_snr(betas):
76
+ """
77
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
78
+
79
+
80
+ Args:
81
+ betas (`torch.Tensor`):
82
+ the betas that the scheduler is being initialized with.
83
+
84
+ Returns:
85
+ `torch.Tensor`: rescaled betas with zero terminal SNR
86
+ """
87
+ # Convert betas to alphas_bar_sqrt
88
+ alphas = 1.0 - betas
89
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
90
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
91
+
92
+ # Store old values.
93
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
94
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
95
+
96
+ # Shift so the last timestep is zero.
97
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
98
+
99
+ # Scale so the first timestep is back to the old value.
100
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
101
+
102
+ # Convert alphas_bar_sqrt to betas
103
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
104
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
105
+ alphas = torch.cat([alphas_bar[0:1], alphas])
106
+ betas = 1 - alphas
107
+
108
+ return betas
109
+
110
+
74
111
  class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
75
112
  """
76
113
  `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
@@ -127,6 +164,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
127
164
  Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
128
165
  steps_offset (`int`, defaults to 0):
129
166
  An offset added to the inference steps, as required by some model families.
167
+ final_sigmas_type (`str`, defaults to `"zero"`):
168
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
169
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
170
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
171
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
172
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
173
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
130
174
  """
131
175
 
132
176
  _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -153,6 +197,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
153
197
  use_karras_sigmas: Optional[bool] = False,
154
198
  timestep_spacing: str = "linspace",
155
199
  steps_offset: int = 0,
200
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
201
+ rescale_betas_zero_snr: bool = False,
156
202
  ):
157
203
  if trained_betas is not None:
158
204
  self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -165,10 +211,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
165
211
  # Glide cosine schedule
166
212
  self.betas = betas_for_alpha_bar(num_train_timesteps)
167
213
  else:
168
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
214
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
215
+
216
+ if rescale_betas_zero_snr:
217
+ self.betas = rescale_zero_terminal_snr(self.betas)
169
218
 
170
219
  self.alphas = 1.0 - self.betas
171
220
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
221
+
222
+ if rescale_betas_zero_snr:
223
+ # Close to 0 without being 0 so first sigma is not inf
224
+ # FP16 smallest positive subnormal works well here
225
+ self.alphas_cumprod[-1] = 2**-24
226
+
172
227
  # Currently we only support VP-type noise schedule
173
228
  self.alpha_t = torch.sqrt(self.alphas_cumprod)
174
229
  self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
@@ -182,7 +237,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
182
237
  if solver_type in ["midpoint", "heun", "logrho"]:
183
238
  self.register_to_config(solver_type="bh2")
184
239
  else:
185
- raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
240
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
186
241
 
187
242
  self.predict_x0 = predict_x0
188
243
  # setable values
@@ -202,7 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
202
257
  @property
203
258
  def step_index(self):
204
259
  """
205
- The index counter for current timestep. It will increae 1 after each scheduler step.
260
+ The index counter for current timestep. It will increase 1 after each scheduler step.
206
261
  """
207
262
  return self._step_index
208
263
 
@@ -265,10 +320,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
265
320
  sigmas = np.flip(sigmas).copy()
266
321
  sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
267
322
  timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
268
- sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
323
+ if self.config.final_sigmas_type == "sigma_min":
324
+ sigma_last = sigmas[-1]
325
+ elif self.config.final_sigmas_type == "zero":
326
+ sigma_last = 0
327
+ else:
328
+ raise ValueError(
329
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
330
+ )
331
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
269
332
  else:
270
333
  sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
271
- sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
334
+ if self.config.final_sigmas_type == "sigma_min":
335
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
336
+ elif self.config.final_sigmas_type == "zero":
337
+ sigma_last = 0
338
+ else:
339
+ raise ValueError(
340
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
341
+ )
272
342
  sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
273
343
 
274
344
  self.sigmas = torch.from_numpy(sigmas)
@@ -290,7 +360,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
290
360
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
291
361
 
292
362
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
293
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
363
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
294
364
  """
295
365
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
296
366
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -355,7 +425,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
355
425
  return alpha_t, sigma_t
356
426
 
357
427
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
358
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
428
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
359
429
  """Constructs the noise schedule of Karras et al. (2022)."""
360
430
 
361
431
  # Hack to make sure that other schedulers which copy this function don't break
@@ -382,24 +452,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
382
452
 
383
453
  def convert_model_output(
384
454
  self,
385
- model_output: torch.FloatTensor,
455
+ model_output: torch.Tensor,
386
456
  *args,
387
- sample: torch.FloatTensor = None,
457
+ sample: torch.Tensor = None,
388
458
  **kwargs,
389
- ) -> torch.FloatTensor:
459
+ ) -> torch.Tensor:
390
460
  r"""
391
461
  Convert the model output to the corresponding type the UniPC algorithm needs.
392
462
 
393
463
  Args:
394
- model_output (`torch.FloatTensor`):
464
+ model_output (`torch.Tensor`):
395
465
  The direct output from the learned diffusion model.
396
466
  timestep (`int`):
397
467
  The current discrete timestep in the diffusion chain.
398
- sample (`torch.FloatTensor`):
468
+ sample (`torch.Tensor`):
399
469
  A current instance of a sample created by the diffusion process.
400
470
 
401
471
  Returns:
402
- `torch.FloatTensor`:
472
+ `torch.Tensor`:
403
473
  The converted model output.
404
474
  """
405
475
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -452,27 +522,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
452
522
 
453
523
  def multistep_uni_p_bh_update(
454
524
  self,
455
- model_output: torch.FloatTensor,
525
+ model_output: torch.Tensor,
456
526
  *args,
457
- sample: torch.FloatTensor = None,
527
+ sample: torch.Tensor = None,
458
528
  order: int = None,
459
529
  **kwargs,
460
- ) -> torch.FloatTensor:
530
+ ) -> torch.Tensor:
461
531
  """
462
532
  One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
463
533
 
464
534
  Args:
465
- model_output (`torch.FloatTensor`):
535
+ model_output (`torch.Tensor`):
466
536
  The direct output from the learned diffusion model at the current timestep.
467
537
  prev_timestep (`int`):
468
538
  The previous discrete timestep in the diffusion chain.
469
- sample (`torch.FloatTensor`):
539
+ sample (`torch.Tensor`):
470
540
  A current instance of a sample created by the diffusion process.
471
541
  order (`int`):
472
542
  The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
473
543
 
474
544
  Returns:
475
- `torch.FloatTensor`:
545
+ `torch.Tensor`:
476
546
  The sample tensor at the previous timestep.
477
547
  """
478
548
  prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
@@ -557,7 +627,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
557
627
  if order == 2:
558
628
  rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
559
629
  else:
560
- rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
630
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
561
631
  else:
562
632
  D1s = None
563
633
 
@@ -581,30 +651,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
581
651
 
582
652
  def multistep_uni_c_bh_update(
583
653
  self,
584
- this_model_output: torch.FloatTensor,
654
+ this_model_output: torch.Tensor,
585
655
  *args,
586
- last_sample: torch.FloatTensor = None,
587
- this_sample: torch.FloatTensor = None,
656
+ last_sample: torch.Tensor = None,
657
+ this_sample: torch.Tensor = None,
588
658
  order: int = None,
589
659
  **kwargs,
590
- ) -> torch.FloatTensor:
660
+ ) -> torch.Tensor:
591
661
  """
592
662
  One step for the UniC (B(h) version).
593
663
 
594
664
  Args:
595
- this_model_output (`torch.FloatTensor`):
665
+ this_model_output (`torch.Tensor`):
596
666
  The model outputs at `x_t`.
597
667
  this_timestep (`int`):
598
668
  The current timestep `t`.
599
- last_sample (`torch.FloatTensor`):
669
+ last_sample (`torch.Tensor`):
600
670
  The generated sample before the last predictor `x_{t-1}`.
601
- this_sample (`torch.FloatTensor`):
671
+ this_sample (`torch.Tensor`):
602
672
  The generated sample after the last predictor `x_{t}`.
603
673
  order (`int`):
604
674
  The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
605
675
 
606
676
  Returns:
607
- `torch.FloatTensor`:
677
+ `torch.Tensor`:
608
678
  The corrected sample tensor at the current timestep.
609
679
  """
610
680
  this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
@@ -695,7 +765,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
695
765
  if order == 1:
696
766
  rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
697
767
  else:
698
- rhos_c = torch.linalg.solve(R, b)
768
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
699
769
 
700
770
  if self.predict_x0:
701
771
  x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
@@ -751,9 +821,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
751
821
 
752
822
  def step(
753
823
  self,
754
- model_output: torch.FloatTensor,
824
+ model_output: torch.Tensor,
755
825
  timestep: int,
756
- sample: torch.FloatTensor,
826
+ sample: torch.Tensor,
757
827
  return_dict: bool = True,
758
828
  ) -> Union[SchedulerOutput, Tuple]:
759
829
  """
@@ -761,11 +831,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
761
831
  the multistep UniPC.
762
832
 
763
833
  Args:
764
- model_output (`torch.FloatTensor`):
834
+ model_output (`torch.Tensor`):
765
835
  The direct output from learned diffusion model.
766
836
  timestep (`int`):
767
837
  The current discrete timestep in the diffusion chain.
768
- sample (`torch.FloatTensor`):
838
+ sample (`torch.Tensor`):
769
839
  A current instance of a sample created by the diffusion process.
770
840
  return_dict (`bool`):
771
841
  Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -830,17 +900,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
830
900
 
831
901
  return SchedulerOutput(prev_sample=prev_sample)
832
902
 
833
- def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
903
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
834
904
  """
835
905
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
836
906
  current timestep.
837
907
 
838
908
  Args:
839
- sample (`torch.FloatTensor`):
909
+ sample (`torch.Tensor`):
840
910
  The input sample.
841
911
 
842
912
  Returns:
843
- `torch.FloatTensor`:
913
+ `torch.Tensor`:
844
914
  A scaled input sample.
845
915
  """
846
916
  return sample
@@ -848,10 +918,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
848
918
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
849
919
  def add_noise(
850
920
  self,
851
- original_samples: torch.FloatTensor,
852
- noise: torch.FloatTensor,
921
+ original_samples: torch.Tensor,
922
+ noise: torch.Tensor,
853
923
  timesteps: torch.IntTensor,
854
- ) -> torch.FloatTensor:
924
+ ) -> torch.Tensor:
855
925
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
856
926
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
857
927
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -862,10 +932,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
862
932
  schedule_timesteps = self.timesteps.to(original_samples.device)
863
933
  timesteps = timesteps.to(original_samples.device)
864
934
 
865
- # begin_index is None when the scheduler is used for training
935
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
866
936
  if self.begin_index is None:
867
937
  step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
938
+ elif self.step_index is not None:
939
+ # add_noise is called after first denoising step (for inpainting)
940
+ step_indices = [self.step_index] * timesteps.shape[0]
868
941
  else:
942
+ # add noise is called before first denoising step to create initial latent(img2img)
869
943
  step_indices = [self.begin_index] * timesteps.shape[0]
870
944
 
871
945
  sigma = sigmas[step_indices].flatten()
@@ -48,18 +48,27 @@ class KarrasDiffusionSchedulers(Enum):
48
48
  EDMEulerScheduler = 15
49
49
 
50
50
 
51
+ AysSchedules = {
52
+ "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24],
53
+ "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0],
54
+ "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13],
55
+ "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
56
+ "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0],
57
+ }
58
+
59
+
51
60
  @dataclass
52
61
  class SchedulerOutput(BaseOutput):
53
62
  """
54
63
  Base class for the output of a scheduler's `step` function.
55
64
 
56
65
  Args:
57
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
66
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
58
67
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
59
68
  denoising loop.
60
69
  """
61
70
 
62
- prev_sample: torch.FloatTensor
71
+ prev_sample: torch.Tensor
63
72
 
64
73
 
65
74
  class SchedulerMixin(PushToHubMixin):
@@ -112,9 +121,9 @@ class SchedulerMixin(PushToHubMixin):
112
121
  force_download (`bool`, *optional*, defaults to `False`):
113
122
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
114
123
  cached versions if they exist.
115
- resume_download (`bool`, *optional*, defaults to `False`):
116
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
117
- incompletely downloaded files are deleted.
124
+ resume_download:
125
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
126
+ of Diffusers.
118
127
  proxies (`Dict[str, str]`, *optional*):
119
128
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
120
129
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -102,9 +102,9 @@ class FlaxSchedulerMixin(PushToHubMixin):
102
102
  force_download (`bool`, *optional*, defaults to `False`):
103
103
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
104
104
  cached versions if they exist.
105
- resume_download (`bool`, *optional*, defaults to `False`):
106
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
107
- file exists.
105
+ resume_download:
106
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
107
+ of Diffusers.
108
108
  proxies (`Dict[str, str]`, *optional*):
109
109
  A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
110
110
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -38,7 +38,7 @@ class VQDiffusionSchedulerOutput(BaseOutput):
38
38
  prev_sample: torch.LongTensor
39
39
 
40
40
 
41
- def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
41
+ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.Tensor:
42
42
  """
43
43
  Convert batch of vector of class indices into batch of log onehot vectors
44
44
 
@@ -50,7 +50,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
50
50
  number of classes to be used for the onehot vectors
51
51
 
52
52
  Returns:
53
- `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
53
+ `torch.Tensor` of shape `(batch size, num classes, vector length)`:
54
54
  Log onehot vectors
55
55
  """
56
56
  x_onehot = F.one_hot(x, num_classes)
@@ -59,7 +59,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
59
59
  return log_x
60
60
 
61
61
 
62
- def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
62
+ def gumbel_noised(logits: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
63
63
  """
64
64
  Apply gumbel noise to `logits`
65
65
  """
@@ -199,7 +199,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
199
199
 
200
200
  def step(
201
201
  self,
202
- model_output: torch.FloatTensor,
202
+ model_output: torch.Tensor,
203
203
  timestep: torch.long,
204
204
  sample: torch.LongTensor,
205
205
  generator: Optional[torch.Generator] = None,
@@ -210,7 +210,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
210
210
  [`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computed.
211
211
 
212
212
  Args:
213
- log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
213
+ log_p_x_0: (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
214
214
  The log probabilities for the predicted classes of the initial latent pixels. Does not include a
215
215
  prediction for the masked class as the initial unnoised image cannot be masked.
216
216
  t (`torch.long`):
@@ -251,7 +251,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
251
251
  ```
252
252
 
253
253
  Args:
254
- log_p_x_0 (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
254
+ log_p_x_0 (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
255
255
  The log probabilities for the predicted classes of the initial latent pixels. Does not include a
256
256
  prediction for the masked class as the initial unnoised image cannot be masked.
257
257
  x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
@@ -260,7 +260,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
260
260
  The timestep that determines which transition matrix is used.
261
261
 
262
262
  Returns:
263
- `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
263
+ `torch.Tensor` of shape `(batch size, num classes, num latent pixels)`:
264
264
  The log probabilities for the predicted classes of the image at timestep `t-1`.
265
265
  """
266
266
  log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
@@ -354,7 +354,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
354
354
  return log_p_x_t_min_1
355
355
 
356
356
  def log_Q_t_transitioning_to_known_class(
357
- self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
357
+ self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.Tensor, cumulative: bool
358
358
  ):
359
359
  """
360
360
  Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
@@ -365,14 +365,14 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
365
365
  The timestep that determines which transition matrix is used.
366
366
  x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
367
367
  The classes of each latent pixel at time `t`.
368
- log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
368
+ log_onehot_x_t (`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`):
369
369
  The log one-hot vectors of `x_t`.
370
370
  cumulative (`bool`):
371
371
  If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is
372
372
  `True`, the cumulative transition matrix `0`->`t` is used.
373
373
 
374
374
  Returns:
375
- `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
375
+ `torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`:
376
376
  Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
377
377
  transition matrix.
378
378
 
@@ -1,12 +1,13 @@
1
1
  import contextlib
2
2
  import copy
3
3
  import random
4
- from typing import Any, Dict, Iterable, List, Optional, Union
4
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
5
5
 
6
6
  import numpy as np
7
7
  import torch
8
8
 
9
9
  from .models import UNet2DConditionModel
10
+ from .schedulers import SchedulerMixin
10
11
  from .utils import (
11
12
  convert_state_dict_to_diffusers,
12
13
  convert_state_dict_to_peft,
@@ -117,6 +118,60 @@ def resolve_interpolation_mode(interpolation_type: str):
117
118
  return interpolation_mode
118
119
 
119
120
 
121
def compute_dream_and_update_latents(
    unet: "UNet2DConditionModel",
    noise_scheduler: "SchedulerMixin",
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
    DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
    forward step without gradients.

    Args:
        `unet`: The UNet used to make a prediction. It is called as
            `unet(noisy_latents, timesteps, encoder_hidden_states)` and its `.sample` attribute is read.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep. Its `alphas_cumprod` and
            `config.prediction_type` are read.
        `timesteps`: The timesteps for the noise_scheduler to use.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target. The input tensors are not modified.

    Raises:
        `NotImplementedError`: If the scheduler is configured for `v_prediction`.
        `ValueError`: If the scheduler's prediction type is neither `epsilon` nor `v_prediction`.
    """
    # Reject unsupported configurations before paying for the extra UNet forward pass.
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {prediction_type}")

    # Index per-sample alphas and append three singleton axes so they broadcast over the latents.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    with torch.no_grad():
        predicted_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # `detach()` plus the subtraction yields a fresh tensor, so the in-place scaling below cannot touch `noise`.
    delta_noise = (noise - predicted_noise).detach()
    delta_noise.mul_(dream_lambda)

    # Results are kept in new names: rebinding `noisy_latents`/`target` to None before reading them (as the
    # previous implementation did) made the `.add` calls fail with AttributeError on every invocation.
    dream_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
    dream_target = target.add(delta_noise)

    return dream_noisy_latents, dream_target
173
+
174
+
120
175
  def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
121
176
  r"""
122
177
  Returns:
@@ -58,19 +58,26 @@ from .import_utils import (
58
58
  get_objects_from_module,
59
59
  is_accelerate_available,
60
60
  is_accelerate_version,
61
+ is_bitsandbytes_available,
61
62
  is_bs4_available,
62
63
  is_flax_available,
63
64
  is_ftfy_available,
65
+ is_google_colab,
64
66
  is_inflect_available,
65
67
  is_invisible_watermark_available,
66
68
  is_k_diffusion_available,
67
69
  is_k_diffusion_version,
68
70
  is_librosa_available,
71
+ is_matplotlib_available,
69
72
  is_note_seq_available,
73
+ is_notebook,
70
74
  is_onnx_available,
71
75
  is_peft_available,
76
+ is_peft_version,
77
+ is_safetensors_available,
72
78
  is_scipy_available,
73
79
  is_tensorboard_available,
80
+ is_timm_available,
74
81
  is_torch_available,
75
82
  is_torch_npu_available,
76
83
  is_torch_version,
@@ -14,6 +14,7 @@
14
14
  """
15
15
  Doc utilities: Utilities related to documentation
16
16
  """
17
+
17
18
  import re
18
19
 
19
20
 
@@ -92,6 +92,21 @@ class ControlNetModel(metaclass=DummyObject):
92
92
  requires_backends(cls, ["torch"])
93
93
 
94
94
 
95
+ class ControlNetXSAdapter(metaclass=DummyObject):
96
+ _backends = ["torch"]
97
+
98
+ def __init__(self, *args, **kwargs):
99
+ requires_backends(self, ["torch"])
100
+
101
+ @classmethod
102
+ def from_config(cls, *args, **kwargs):
103
+ requires_backends(cls, ["torch"])
104
+
105
+ @classmethod
106
+ def from_pretrained(cls, *args, **kwargs):
107
+ requires_backends(cls, ["torch"])
108
+
109
+
95
110
  class I2VGenXLUNet(metaclass=DummyObject):
96
111
  _backends = ["torch"]
97
112
 
@@ -287,6 +302,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
287
302
  requires_backends(cls, ["torch"])
288
303
 
289
304
 
305
+ class UNetControlNetXSModel(metaclass=DummyObject):
306
+ _backends = ["torch"]
307
+
308
+ def __init__(self, *args, **kwargs):
309
+ requires_backends(self, ["torch"])
310
+
311
+ @classmethod
312
+ def from_config(cls, *args, **kwargs):
313
+ requires_backends(cls, ["torch"])
314
+
315
+ @classmethod
316
+ def from_pretrained(cls, *args, **kwargs):
317
+ requires_backends(cls, ["torch"])
318
+
319
+
290
320
  class UNetMotionModel(metaclass=DummyObject):
291
321
  _backends = ["torch"]
292
322