diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (299) hide show
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,686 @@
1
+ # Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from ..configuration_utils import ConfigMixin, register_to_config
26
+ from ..schedulers.scheduling_utils import SchedulerMixin
27
+ from ..utils import BaseOutput, logging
28
+ from ..utils.torch_utils import randn_tensor
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
@dataclass
class TCDSchedulerOutput(BaseOutput):
    """
    Result container returned by the scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` computed for the previous timestep. Feed it back to the model as the next
            input of the denoising loop.
        pred_noised_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The noised sample `(x_{s})` predicted from the model output at the current timestep. Optional;
            defaults to `None`.
    """

    prev_sample: torch.FloatTensor
    pred_noised_sample: Optional[torch.FloatTensor] = None
49
+
50
+
51
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
52
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Internally selects a function `alpha_bar_fn` that takes an argument t and transforms it to the cumulative
    product of (1-beta) up to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta so the
    # final steps (where alpha_bar approaches 0) do not produce a beta of 1.
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
94
+
95
+
96
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
97
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # betas -> sqrt of the cumulative alpha product
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the endpoints before transforming.
    first_value = sqrt_alpha_bar[0]
    last_value = sqrt_alpha_bar[-1]

    # Shift so the last timestep becomes zero, then scale so the first
    # timestep returns to its original value.
    shifted = sqrt_alpha_bar - last_value
    rescaled = shifted * (first_value / (first_value - last_value))

    # sqrt(alpha_bar) -> alpha_bar -> per-step alphas -> betas
    alpha_bar = rescaled**2  # undo the sqrt
    step_alphas = alpha_bar[1:] / alpha_bar[:-1]  # undo the cumprod
    step_alphas = torch.cat([alpha_bar[0:1], step_alphas])

    return 1 - step_alphas
131
+
132
+
133
+ class TCDScheduler(SchedulerMixin, ConfigMixin):
134
+ """
135
+ `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency Distillation`,
136
+ extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
137
+
138
+ This code is based on the official repo of TCD(https://github.com/jabir-zheng/TCD).
139
+
140
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
141
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
142
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
143
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
144
+
145
+ Args:
146
+ num_train_timesteps (`int`, defaults to 1000):
147
+ The number of diffusion steps to train the model.
148
+ beta_start (`float`, defaults to 0.00085):
149
+ The starting `beta` value of inference.
150
+ beta_end (`float`, defaults to 0.012):
151
+ The final `beta` value.
152
+ beta_schedule (`str`, defaults to `"scaled_linear"`):
153
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
154
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
155
+ trained_betas (`np.ndarray`, *optional*):
156
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
157
+ original_inference_steps (`int`, *optional*, defaults to 50):
158
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
159
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
160
+ clip_sample (`bool`, defaults to `False`):
161
+ Clip the predicted sample for numerical stability.
162
+ clip_sample_range (`float`, defaults to 1.0):
163
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
164
+ set_alpha_to_one (`bool`, defaults to `True`):
165
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
166
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
167
+ otherwise it uses the alpha value at step 0.
168
+ steps_offset (`int`, defaults to 0):
169
+ An offset added to the inference steps, as required by some model families.
170
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
171
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
172
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
173
+ Video](https://imagen.research.google/video/paper.pdf) paper).
174
+ thresholding (`bool`, defaults to `False`):
175
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
176
+ as Stable Diffusion.
177
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
178
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
179
+ sample_max_value (`float`, defaults to 1.0):
180
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
181
+ timestep_spacing (`str`, defaults to `"leading"`):
182
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
183
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
184
+ timestep_scaling (`float`, defaults to 10.0):
185
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
186
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
187
+ error at the default of `10.0` is already pretty small).
188
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
189
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
190
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
191
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
192
+ """
193
+
194
+ order = 1
195
+
196
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "scaled_linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    original_inference_steps: int = 50,
    clip_sample: bool = False,
    clip_sample_range: float = 1.0,
    set_alpha_to_one: bool = True,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    timestep_spacing: str = "leading",
    timestep_scaling: float = 10.0,
    rescale_betas_zero_snr: bool = False,
):
    """
    Build the beta/alpha tables and reset the mutable inference state.

    Only `num_train_timesteps`, `beta_start`, `beta_end`, `beta_schedule`, `trained_betas`, `set_alpha_to_one`
    and `rescale_betas_zero_snr` are read in this body. The remaining arguments are not used here; per the class
    documentation they are recorded in the scheduler config through `register_to_config`.

    Raises:
        NotImplementedError: if `trained_betas` is `None` and `beta_schedule` is not one of `linear`,
            `scaled_linear` or `squaredcos_cap_v2`.
    """
    # Explicit betas take precedence over any named schedule.
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    # Rescale for zero SNR
    if rescale_betas_zero_snr:
        betas = rescale_zero_terminal_snr(betas)

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # At every step in ddim, we are looking into the previous alphas_cumprod
    # For the final step, there is no previous alphas_cumprod because we are already at 0
    # `set_alpha_to_one` decides whether we set this parameter simply to one or
    # whether we use the final alpha of the "non-previous" one.
    self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # settable values
    self.num_inference_steps = None
    # Default schedule: every training timestep, in descending order.
    self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
    self.custom_timesteps = False

    self._step_index = None
    self._begin_index = None
253
+
254
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
255
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Look up the position of `timestep` inside a timestep schedule.

    Args:
        timestep:
            The timestep value to search for.
        schedule_timesteps (*optional*):
            The schedule to search in. Falls back to `self.timesteps` when `None`.

    Returns:
        The index of `timestep` in the schedule. When the value occurs more than once, the second occurrence is
        returned; otherwise the only occurrence is returned.
    """
    candidates = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (candidates == timestep).nonzero()

    # For the very first `step` the second match is preferred (the only match if there is just one), so that a
    # sigma is not accidentally skipped when denoising starts mid-schedule (e.g. for image-to-image).
    if len(matches) > 1:
        return matches[1].item()
    return matches[0].item()
268
+
269
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
270
def _init_step_index(self, timestep):
    """
    Initialise `self._step_index` before the first call to `step`.

    If a begin index has been set, it is used directly. Otherwise the index is looked up from `timestep` via
    `index_for_timestep`, after moving a tensor `timestep` onto the device of `self.timesteps`.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
277
+
278
@property
def step_index(self):
    """
    The value of `self._step_index`: `None` until it is initialised by `_init_step_index`, then advanced by one
    after every call to `step`.
    """
    return self._step_index
281
+
282
@property
def begin_index(self):
    """
    The index of the first timestep, or `None` if it has not been set. Pipelines are expected to set it through
    the `set_begin_index` method.
    """
    return self._begin_index
288
+
289
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
290
def set_begin_index(self, begin_index: int = 0):
    """
    Store the begin index of the scheduler. Meant to be called by the pipeline before inference starts.

    Args:
        begin_index (`int`, defaults to 0):
            The index of the first timestep to use.
    """
    self._begin_index = begin_index
299
+
300
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
    """
    Interface-compatibility hook for schedulers that rescale the denoising model input per timestep. This
    scheduler applies no scaling, so the sample is handed back unchanged.

    Args:
        sample (`torch.FloatTensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain. Unused here.

    Returns:
        `torch.FloatTensor`:
            The same `sample` object that was passed in.
    """
    return sample
315
+
316
+ # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
317
def _get_variance(self, timestep, prev_timestep):
    """
    Compute the variance term for the transition from `timestep` to `prev_timestep`.

    The cumulative alpha product of the previous step is read from `self.alphas_cumprod` when `prev_timestep` is
    non-negative, and taken from `self.final_alpha_cumprod` otherwise.
    """
    alpha_bar_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_bar_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_bar_prev = self.final_alpha_cumprod

    beta_bar_t = 1 - alpha_bar_t
    beta_bar_prev = 1 - alpha_bar_prev

    return (beta_bar_prev / beta_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
326
+
327
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
328
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    """
    Apply dynamic thresholding to `sample`, as described in https://arxiv.org/abs/2205.11487:

    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    The percentile is `self.config.dynamic_thresholding_ratio` and `s` is capped at
    `self.config.sample_max_value`. The result has the shape and dtype of the input.
    """
    original_dtype = sample.dtype
    batch_size, channels, *spatial_dims = sample.shape

    # Quantile needs float32/float64, and clamp is not implemented for half precision on CPU.
    if original_dtype not in (torch.float32, torch.float64):
        sample = sample.float()

    # One row per image, so the quantile is taken over each image separately.
    flat = sample.reshape(batch_size, channels * np.prod(spatial_dims))

    # "a certain percentile absolute pixel value"
    s = torch.quantile(flat.abs(), self.config.dynamic_thresholding_ratio, dim=1)
    # With the lower bound of 1 this degenerates to plain clipping to [-1, 1] when nothing saturates.
    s = torch.clamp(s, min=1, max=self.config.sample_max_value)
    # (batch_size, 1), because clamp broadcasts along dim=0.
    s = s.unsqueeze(1)

    # "we threshold xt0 to the range [-s, s] and then divide by s"
    flat = torch.clamp(flat, -s, s) / s

    restored = flat.reshape(batch_size, channels, *spatial_dims)
    return restored.to(original_dtype)
360
+
361
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    original_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    strength: float = 1.0,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        original_inference_steps (`int`, *optional*):
            The original number of inference steps, which will be used to generate a linearly-spaced timestep
            schedule (which is different from the standard `diffusers` implementation). We will then take
            `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
            our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
            schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
        strength (`float`, defaults to 1.0):
            Fraction of the schedule to keep (e.g. for img2img pipelines). It shortens the
            training/distillation schedule that timesteps are drawn from and, for custom `timesteps`, drops the
            leading `(1 - strength)` share of the given schedule.

    Raises:
        `ValueError`: If neither or both of `num_inference_steps` and `timesteps` are given, if `timesteps` is
            empty, not strictly descending or starts at or beyond `num_train_timesteps`, or if the requested
            number of steps does not fit into the available schedule.
    """
    # 0. Check inputs
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")

    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

    # 1. Calculate the TCD original training/distillation timestep schedule.
    original_steps = (
        original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
    )

    if original_inference_steps is None:
        # default option, timesteps align with discrete inference steps
        if original_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`original_steps`: {original_steps} cannot be larger than `self.config.num_train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )
        # TCD Timesteps Setting
        # The skipping step parameter k from the paper.
        k = self.config.num_train_timesteps // original_steps
        # TCD Training/Distillation Steps Schedule
        tcd_origin_timesteps = np.arange(1, int(original_steps * strength) + 1) * k - 1
    else:
        # customised option, sampled timesteps can be any arbitrary value
        tcd_origin_timesteps = np.arange(0, int(self.config.num_train_timesteps * strength))

    # From here on `original_steps` is never `None`: it is either the explicit argument or a config value that
    # already survived the comparison above.

    # 2. Calculate the TCD inference timestep schedule.
    if timesteps is not None:
        # 2.1 Handle custom timestep schedules.
        if len(timesteps) == 0:
            raise ValueError("`timesteps` must contain at least one timestep.")

        for earlier, later in zip(timesteps, timesteps[1:]):
            if later >= earlier:
                raise ValueError("`timesteps` must be in descending order.")

        if timesteps[0] >= self.config.num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.num_train_timesteps`:"
                f" {self.config.num_train_timesteps}."
            )

        # Every entry, including the first one, is compared against the training/distillation schedule.
        train_timesteps = set(tcd_origin_timesteps)
        non_train_timesteps = [t for t in timesteps if t not in train_timesteps]

        # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
        if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
            logger.warning(
                f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
                f" unexpected results when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
        if non_train_timesteps:
            logger.warning(
                f"The custom timestep schedule contains the following timesteps which are not on the original"
                f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                f" when using this timestep schedule."
            )

        # Raise warning if custom timestep schedule is longer than original_steps
        if len(timesteps) > original_steps:
            logger.warning(
                f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds"
                f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                f" unexpected results when using this timestep schedule."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.num_inference_steps = len(timesteps)
        self.custom_timesteps = True

        # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
        init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
        t_start = max(self.num_inference_steps - init_timestep, 0)
        timesteps = timesteps[t_start * self.order :]
        # TODO: also reset self.num_inference_steps?
    else:
        # 2.2 Create the "standard" TCD inference timestep schedule.
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps`: {num_inference_steps} must be a positive integer.")

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained"
                f" with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps."
            )

        skipping_step = len(tcd_origin_timesteps) // num_inference_steps
        if skipping_step < 1:
            raise ValueError(
                f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than"
                f" `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps`"
                f" to a value smaller than {int(original_steps * strength)} or increase `strength` to a value"
                f" higher than {float(num_inference_steps / original_steps)}."
            )

        if num_inference_steps > original_steps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                f" {original_steps} because the final timestep schedule will be a subset of the"
                f" `original_inference_steps`-sized initial timestep schedule."
            )

        # Only record the step count once every check has passed, so a rejected call leaves state untouched.
        self.num_inference_steps = num_inference_steps

        # TCD Inference Steps Schedule
        tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
        # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
        inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        timesteps = tcd_origin_timesteps[inference_indices]

    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

    # A new schedule invalidates any position tracked from a previous run.
    self._step_index = None
    self._begin_index = None
521
+
522
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    eta: float = 0.3,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union["TCDSchedulerOutput", Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        eta (`float`):
            A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in
            every step. When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full
            stochastic sampling. Must lie in `[0, 1]`.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`.

    Returns:
        [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] is returned, otherwise a
            tuple `(prev_sample, pred_noised_sample)` is returned.

    Raises:
        `ValueError`: If `set_timesteps` has not been called, if `eta` is outside `[0, 1]`, or if
            `self.config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Explicit check instead of `assert`, so the validation also runs under `python -O`.
    if not 0 <= eta <= 1.0:
        raise ValueError(f"`eta` (gamma) must be in the range [0, 1], but got {eta}.")

    # 1. get previous step value
    prev_step_index = self.step_index + 1
    if prev_step_index < len(self.timesteps):
        prev_timestep = self.timesteps[prev_step_index]
    else:
        prev_timestep = torch.tensor(0)

    # Intermediate timestep s = floor((1 - gamma) * t_prev) that the sample is first moved to.
    timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long)

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t

    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

    alpha_prod_s = self.alphas_cumprod[timestep_s]
    beta_prod_s = 1 - alpha_prod_s

    # 3. Compute the predicted noised sample x_s based on the model parameterization
    if self.config.prediction_type == "epsilon":  # noise-prediction
        pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":  # x-prediction
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":  # v-prediction
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for `TCDScheduler`."
        )

    # The re-noising to timestep s is the same for every parameterization.
    pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon

    # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference
    # Noise is not used on the final timestep of the timestep schedule.
    # This also means that noise is not used for one-step sampling.
    # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step.
    # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
    if eta > 0 and self.step_index != self.num_inference_steps - 1:
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype
        )
        prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (
            1 - alpha_prod_t_prev / alpha_prod_s
        ).sqrt() * noise
    else:
        prev_sample = pred_noised_sample

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample, pred_noised_sample)

    return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
626
+
627
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Mix `noise` into `original_samples` at the given `timesteps`.

    Each sample is computed as `sqrt(alpha_bar_t) * original + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t`
    is read from `self.alphas_cumprod` at the corresponding timestep.

    Returns:
        `torch.FloatTensor`: The noised samples.
    """
    # Bring the schedule and the timesteps onto the device (and dtype) of the samples before indexing.
    alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    alpha_bar = alphas_cumprod[timesteps]
    signal_scale = (alpha_bar**0.5).flatten()
    noise_scale = ((1 - alpha_bar) ** 0.5).flatten()

    # Append trailing singleton dims so the per-timestep factors broadcast over the sample dims.
    sample_rank = len(original_samples.shape)
    for _ in range(sample_rank - len(signal_scale.shape)):
        signal_scale = signal_scale.unsqueeze(-1)
    for _ in range(sample_rank - len(noise_scale.shape)):
        noise_scale = noise_scale.unsqueeze(-1)

    return signal_scale * original_samples + noise_scale * noise
649
+
650
def get_velocity(
    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
    """
    Compute the velocity `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` at the given `timesteps`,
    with `alpha_bar_t` read from `self.alphas_cumprod`.

    Returns:
        `torch.FloatTensor`: The velocity.
    """

    def _broadcastable(factor):
        # Flatten, then append trailing singleton dims until the factor has as many dims as `sample`.
        factor = factor.flatten()
        while len(factor.shape) < len(sample.shape):
            factor = factor.unsqueeze(-1)
        return factor

    # Bring the schedule and the timesteps onto the device (and dtype) of the sample before indexing.
    alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    alpha_bar = alphas_cumprod[timesteps]
    noise_weight = _broadcastable(alpha_bar**0.5)
    sample_weight = _broadcastable((1 - alpha_bar) ** 0.5)

    return noise_weight * noise - sample_weight * sample
669
+
670
def __len__(self):
    """Return the configured number of training timesteps (`self.config.num_train_timesteps`)."""
    train_steps = self.config.num_train_timesteps
    return train_steps
672
+
673
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order.

    With a custom schedule, the next entry of `self.timesteps` is returned, or `torch.tensor(-1)` when `timestep`
    is the last entry. Otherwise a fixed stride of `num_train_timesteps // num_inference_steps` is subtracted,
    using `num_train_timesteps` as the step count while `self.num_inference_steps` is unset.
    """
    if not self.custom_timesteps:
        step_count = self.num_inference_steps or self.config.num_train_timesteps
        stride = self.config.num_train_timesteps // step_count
        return timestep - stride

    # Position of the first entry equal to `timestep`.
    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    if position == self.timesteps.shape[0] - 1:
        return torch.tensor(-1)
    return self.timesteps[position + 1]