diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ import torch
22
22
 
23
23
  from ..configuration_utils import ConfigMixin, register_to_config
24
24
  from ..utils import deprecate, logging
25
+ from ..utils.torch_utils import randn_tensor
25
26
  from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
26
27
 
27
28
 
@@ -108,11 +109,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
108
109
  The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
109
110
  `algorithm_type="dpmsolver++"`.
110
111
  algorithm_type (`str`, defaults to `dpmsolver++`):
111
- Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
112
- algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
113
- implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
114
- recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
115
- Stable Diffusion.
112
+ Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
113
+ type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
114
+ `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
115
+ paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
116
+ sampling like in Stable Diffusion.
116
117
  solver_type (`str`, defaults to `midpoint`):
117
118
  Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
118
119
  sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -186,7 +187,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
186
187
  self.init_noise_sigma = 1.0
187
188
 
188
189
  # settings for DPM-Solver
189
- if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
190
+ if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver++"]:
190
191
  if algorithm_type == "deis":
191
192
  self.register_to_config(algorithm_type="dpmsolver++")
192
193
  else:
@@ -197,7 +198,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
197
198
  else:
198
199
  raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
199
200
 
200
- if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
201
+ if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
201
202
  raise ValueError(
202
203
  f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
203
204
  )
@@ -493,10 +494,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
493
494
  "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
494
495
  )
495
496
  # DPM-Solver++ needs to solve an integral of the data prediction model.
496
- if self.config.algorithm_type == "dpmsolver++":
497
+ if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
497
498
  if self.config.prediction_type == "epsilon":
498
499
  # DPM-Solver and DPM-Solver++ only need the "mean" output.
499
- if self.config.variance_type in ["learned_range"]:
500
+ if self.config.variance_type in ["learned", "learned_range"]:
500
501
  model_output = model_output[:, :3]
501
502
  sigma = self.sigmas[self.step_index]
502
503
  alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
@@ -517,34 +518,43 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
517
518
  x0_pred = self._threshold_sample(x0_pred)
518
519
 
519
520
  return x0_pred
521
+
520
522
  # DPM-Solver needs to solve an integral of the noise prediction model.
521
523
  elif self.config.algorithm_type == "dpmsolver":
522
524
  if self.config.prediction_type == "epsilon":
523
525
  # DPM-Solver and DPM-Solver++ only need the "mean" output.
524
- if self.config.variance_type in ["learned_range"]:
525
- model_output = model_output[:, :3]
526
- return model_output
526
+ if self.config.variance_type in ["learned", "learned_range"]:
527
+ epsilon = model_output[:, :3]
528
+ else:
529
+ epsilon = model_output
527
530
  elif self.config.prediction_type == "sample":
528
531
  sigma = self.sigmas[self.step_index]
529
532
  alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
530
533
  epsilon = (sample - alpha_t * model_output) / sigma_t
531
- return epsilon
532
534
  elif self.config.prediction_type == "v_prediction":
533
535
  sigma = self.sigmas[self.step_index]
534
536
  alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
535
537
  epsilon = alpha_t * model_output + sigma_t * sample
536
- return epsilon
537
538
  else:
538
539
  raise ValueError(
539
540
  f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
540
541
  " `v_prediction` for the DPMSolverSinglestepScheduler."
541
542
  )
542
543
 
544
+ if self.config.thresholding:
545
+ alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
546
+ x0_pred = (sample - sigma_t * epsilon) / alpha_t
547
+ x0_pred = self._threshold_sample(x0_pred)
548
+ epsilon = (sample - alpha_t * x0_pred) / sigma_t
549
+
550
+ return epsilon
551
+
543
552
  def dpm_solver_first_order_update(
544
553
  self,
545
554
  model_output: torch.Tensor,
546
555
  *args,
547
556
  sample: torch.Tensor = None,
557
+ noise: Optional[torch.Tensor] = None,
548
558
  **kwargs,
549
559
  ) -> torch.Tensor:
550
560
  """
@@ -594,6 +604,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
594
604
  x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
595
605
  elif self.config.algorithm_type == "dpmsolver":
596
606
  x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
607
+ elif self.config.algorithm_type == "sde-dpmsolver++":
608
+ assert noise is not None
609
+ x_t = (
610
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
611
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
612
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
613
+ )
597
614
  return x_t
598
615
 
599
616
  def singlestep_dpm_solver_second_order_update(
@@ -601,6 +618,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
601
618
  model_output_list: List[torch.Tensor],
602
619
  *args,
603
620
  sample: torch.Tensor = None,
621
+ noise: Optional[torch.Tensor] = None,
604
622
  **kwargs,
605
623
  ) -> torch.Tensor:
606
624
  """
@@ -688,6 +706,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
688
706
  - (sigma_t * (torch.exp(h) - 1.0)) * D0
689
707
  - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
690
708
  )
709
+ elif self.config.algorithm_type == "sde-dpmsolver++":
710
+ assert noise is not None
711
+ if self.config.solver_type == "midpoint":
712
+ x_t = (
713
+ (sigma_t / sigma_s1 * torch.exp(-h)) * sample
714
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
715
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
716
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
717
+ )
718
+ elif self.config.solver_type == "heun":
719
+ x_t = (
720
+ (sigma_t / sigma_s1 * torch.exp(-h)) * sample
721
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
722
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
723
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
724
+ )
691
725
  return x_t
692
726
 
693
727
  def singlestep_dpm_solver_third_order_update(
@@ -800,6 +834,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
800
834
  *args,
801
835
  sample: torch.Tensor = None,
802
836
  order: int = None,
837
+ noise: Optional[torch.Tensor] = None,
803
838
  **kwargs,
804
839
  ) -> torch.Tensor:
805
840
  """
@@ -848,9 +883,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
848
883
  )
849
884
 
850
885
  if order == 1:
851
- return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample)
886
+ return self.dpm_solver_first_order_update(model_output_list[-1], sample=sample, noise=noise)
852
887
  elif order == 2:
853
- return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample)
888
+ return self.singlestep_dpm_solver_second_order_update(model_output_list, sample=sample, noise=noise)
854
889
  elif order == 3:
855
890
  return self.singlestep_dpm_solver_third_order_update(model_output_list, sample=sample)
856
891
  else:
@@ -892,8 +927,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
892
927
  def step(
893
928
  self,
894
929
  model_output: torch.Tensor,
895
- timestep: int,
930
+ timestep: Union[int, torch.Tensor],
896
931
  sample: torch.Tensor,
932
+ generator=None,
897
933
  return_dict: bool = True,
898
934
  ) -> Union[SchedulerOutput, Tuple]:
899
935
  """
@@ -929,6 +965,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
929
965
  self.model_outputs[i] = self.model_outputs[i + 1]
930
966
  self.model_outputs[-1] = model_output
931
967
 
968
+ if self.config.algorithm_type == "sde-dpmsolver++":
969
+ noise = randn_tensor(
970
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
971
+ )
972
+ else:
973
+ noise = None
974
+
932
975
  order = self.order_list[self.step_index]
933
976
 
934
977
  # For img2img denoising might start with order>1 which is not possible
@@ -940,9 +983,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
940
983
  if order == 1:
941
984
  self.sample = sample
942
985
 
943
- prev_sample = self.singlestep_dpm_solver_update(self.model_outputs, sample=self.sample, order=order)
986
+ prev_sample = self.singlestep_dpm_solver_update(
987
+ self.model_outputs, sample=self.sample, order=order, noise=noise
988
+ )
944
989
 
945
- # upon completion increase step index by one
990
+ # upon completion increase step index by one, noise=noise
946
991
  self._step_index += 1
947
992
 
948
993
  if not return_dict:
@@ -134,7 +134,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
134
134
 
135
135
  self.timesteps = self.precondition_noise(sigmas)
136
136
 
137
- self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
137
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
138
138
 
139
139
  # setable values
140
140
  self.num_inference_steps = None
@@ -594,7 +594,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
594
594
  def step(
595
595
  self,
596
596
  model_output: torch.Tensor,
597
- timestep: int,
597
+ timestep: Union[int, torch.Tensor],
598
598
  sample: torch.Tensor,
599
599
  generator=None,
600
600
  return_dict: bool = True,
@@ -12,15 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import math
15
16
  from dataclasses import dataclass
16
- from typing import Optional, Tuple, Union
17
+ from typing import List, Optional, Tuple, Union
17
18
 
18
19
  import numpy as np
19
20
  import torch
20
21
 
21
22
  from ..configuration_utils import ConfigMixin, register_to_config
22
23
  from ..utils import BaseOutput, logging
23
- from ..utils.torch_utils import randn_tensor
24
24
  from .scheduling_utils import SchedulerMixin
25
25
 
26
26
 
@@ -66,12 +66,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
66
66
  self,
67
67
  num_train_timesteps: int = 1000,
68
68
  shift: float = 1.0,
69
+ use_dynamic_shifting=False,
70
+ base_shift: Optional[float] = 0.5,
71
+ max_shift: Optional[float] = 1.15,
72
+ base_image_seq_len: Optional[int] = 256,
73
+ max_image_seq_len: Optional[int] = 4096,
69
74
  ):
70
75
  timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
71
76
  timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
72
77
 
73
78
  sigmas = timesteps / num_train_timesteps
74
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
79
+ if not use_dynamic_shifting:
80
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
81
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
75
82
 
76
83
  self.timesteps = sigmas * num_train_timesteps
77
84
 
@@ -114,7 +121,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
114
121
  noise: Optional[torch.FloatTensor] = None,
115
122
  ) -> torch.FloatTensor:
116
123
  """
117
- Foward process in flow-matching
124
+ Forward process in flow-matching
118
125
 
119
126
  Args:
120
127
  sample (`torch.FloatTensor`):
@@ -126,10 +133,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
126
133
  `torch.FloatTensor`:
127
134
  A scaled input sample.
128
135
  """
129
- if self.step_index is None:
130
- self._init_step_index(timestep)
136
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
137
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
138
+
139
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
140
+ # mps does not support float64
141
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
142
+ timestep = timestep.to(sample.device, dtype=torch.float32)
143
+ else:
144
+ schedule_timesteps = self.timesteps.to(sample.device)
145
+ timestep = timestep.to(sample.device)
146
+
147
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
148
+ if self.begin_index is None:
149
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
150
+ elif self.step_index is not None:
151
+ # add_noise is called after first denoising step (for inpainting)
152
+ step_indices = [self.step_index] * timestep.shape[0]
153
+ else:
154
+ # add noise is called before first denoising step to create initial latent(img2img)
155
+ step_indices = [self.begin_index] * timestep.shape[0]
156
+
157
+ sigma = sigmas[step_indices].flatten()
158
+ while len(sigma.shape) < len(sample.shape):
159
+ sigma = sigma.unsqueeze(-1)
131
160
 
132
- sigma = self.sigmas[self.step_index]
133
161
  sample = sigma * noise + (1.0 - sigma) * sample
134
162
 
135
163
  return sample
@@ -137,7 +165,16 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
137
165
  def _sigma_to_t(self, sigma):
138
166
  return sigma * self.config.num_train_timesteps
139
167
 
140
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
168
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
169
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
170
+
171
+ def set_timesteps(
172
+ self,
173
+ num_inference_steps: int = None,
174
+ device: Union[str, torch.device] = None,
175
+ sigmas: Optional[List[float]] = None,
176
+ mu: Optional[float] = None,
177
+ ):
141
178
  """
142
179
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
143
180
 
@@ -147,17 +184,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
147
184
  device (`str` or `torch.device`, *optional*):
148
185
  The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
149
186
  """
150
- self.num_inference_steps = num_inference_steps
151
187
 
152
- timesteps = np.linspace(
153
- self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
154
- )
188
+ if self.config.use_dynamic_shifting and mu is None:
189
+ raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
155
190
 
156
- sigmas = timesteps / self.config.num_train_timesteps
157
- sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
158
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
191
+ if sigmas is None:
192
+ self.num_inference_steps = num_inference_steps
193
+ timesteps = np.linspace(
194
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
195
+ )
196
+
197
+ sigmas = timesteps / self.config.num_train_timesteps
198
+
199
+ if self.config.use_dynamic_shifting:
200
+ sigmas = self.time_shift(mu, 1.0, sigmas)
201
+ else:
202
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
159
203
 
204
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
160
205
  timesteps = sigmas * self.config.num_train_timesteps
206
+
161
207
  self.timesteps = timesteps.to(device=device)
162
208
  self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
163
209
 
@@ -246,32 +292,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
246
292
  sample = sample.to(torch.float32)
247
293
 
248
294
  sigma = self.sigmas[self.step_index]
295
+ sigma_next = self.sigmas[self.step_index + 1]
249
296
 
250
- gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
251
-
252
- noise = randn_tensor(
253
- model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
254
- )
255
-
256
- eps = noise * s_noise
257
- sigma_hat = sigma * (gamma + 1)
258
-
259
- if gamma > 0:
260
- sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
261
-
262
- # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
263
- # NOTE: "original_sample" should not be an expected prediction_type but is left in for
264
- # backwards compatibility
265
-
266
- # if self.config.prediction_type == "vector_field":
267
-
268
- denoised = sample - model_output * sigma
269
- # 2. Convert to an ODE derivative
270
- derivative = (sample - denoised) / sigma_hat
271
-
272
- dt = self.sigmas[self.step_index + 1] - sigma_hat
297
+ prev_sample = sample + (sigma_next - sigma) * model_output
273
298
 
274
- prev_sample = sample + derivative * dt
275
299
  # Cast sample back to model compatible dtype
276
300
  prev_sample = prev_sample.to(model_output.dtype)
277
301