diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/quantizers/torchao/torchao_quantizer.py
@@ -0,0 +1,285 @@
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Adapted from
+ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py
+ """
+
+ import importlib
+ import types
+ from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+ from packaging import version
+
+ from ...utils import get_module_from_name, is_torch_available, is_torchao_available, logging
+ from ..base import DiffusersQuantizer
+
+
+ if TYPE_CHECKING:
+     from ...models.modeling_utils import ModelMixin
+
+
+ if is_torch_available():
+     import torch
+     import torch.nn as nn
+
+     SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
+         # At the moment, only int8 is supported for integer quantization dtypes.
+         # In Torch 2.6, int1-int7 will be introduced, so this can be visited in the future
+         # to support more quantization methods, such as intx_weight_only.
+         torch.int8,
+         torch.float8_e4m3fn,
+         torch.float8_e5m2,
+         torch.uint1,
+         torch.uint2,
+         torch.uint3,
+         torch.uint4,
+         torch.uint5,
+         torch.uint6,
+         torch.uint7,
+     )
+
+ if is_torchao_available():
+     from torchao.quantization import quantize_
+
+
+ logger = logging.get_logger(__name__)
+
+
+ def _quantization_type(weight):
+     from torchao.dtypes import AffineQuantizedTensor
+     from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
+
+     if isinstance(weight, AffineQuantizedTensor):
+         return f"{weight.__class__.__name__}({weight._quantization_type()})"
+
+     if isinstance(weight, LinearActivationQuantizedTensor):
+         return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})"
+
+
+ def _linear_extra_repr(self):
+     weight = _quantization_type(self.weight)
+     if weight is None:
+         return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None"
+     else:
+         return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
+
+
+ class TorchAoHfQuantizer(DiffusersQuantizer):
+     r"""
+     Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/.
+     """
+
+     requires_calibration = False
+     required_packages = ["torchao"]
+
+     def __init__(self, quantization_config, **kwargs):
+         super().__init__(quantization_config, **kwargs)
+
+     def validate_environment(self, *args, **kwargs):
+         if not is_torchao_available():
+             raise ImportError(
+                 "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
+             )
+         torchao_version = version.parse(importlib.metadata.version("torchao"))
+         if torchao_version < version.parse("0.7.0"):
+             raise RuntimeError(
+                 f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
+             )
+
+         self.offload = False
+
+         device_map = kwargs.get("device_map", None)
+         if isinstance(device_map, dict):
+             if "cpu" in device_map.values() or "disk" in device_map.values():
+                 if self.pre_quantized:
+                     raise ValueError(
+                         "You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
+                         "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
+                     )
+                 else:
+                     self.offload = True
+
+         if self.pre_quantized:
+             weights_only = kwargs.get("weights_only", None)
+             if weights_only:
+                 torch_version = version.parse(importlib.metadata.version("torch"))
+                 if torch_version < version.parse("2.5.0"):
+                     # TODO(aryan): TorchAO is compatible with Pytorch >= 2.2 for certain quantization types. Try to see if we can support it in future
+                     raise RuntimeError(
+                         f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
+                     )
+
+     def update_torch_dtype(self, torch_dtype):
+         quant_type = self.quantization_config.quant_type
+
+         if quant_type.startswith("int"):
+             if torch_dtype is not None and torch_dtype != torch.bfloat16:
+                 logger.warning(
+                     f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
+                     f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
+                 )
+
+         if torch_dtype is None:
+             # We need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
+             logger.warning(
+                 "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
+                 "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
+                 "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
+             )
+             torch_dtype = torch.bfloat16
+
+         return torch_dtype
+
+     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+         quant_type = self.quantization_config.quant_type
+
+         if quant_type.startswith("int8") or quant_type.startswith("int4"):
+             # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+             return torch.int8
+         elif quant_type == "uintx_weight_only":
+             return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+         elif quant_type.startswith("uint"):
+             return {
+                 1: torch.uint1,
+                 2: torch.uint2,
+                 3: torch.uint3,
+                 4: torch.uint4,
+                 5: torch.uint5,
+                 6: torch.uint6,
+                 7: torch.uint7,
+             }[int(quant_type[4])]
+         elif quant_type.startswith("float") or quant_type.startswith("fp"):
+             return torch.bfloat16
+
+         if target_dtype in SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION:
+             return target_dtype
+
+         # We need one of the supported dtypes to be selected in order for accelerate to determine
+         # the total size of modules/parameters for auto device placement.
+         possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
+         raise ValueError(
+             f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
+             f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
+             f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+         )
+
+     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+         max_memory = {key: val * 0.9 for key, val in max_memory.items()}
+         return max_memory
+
+     def check_if_quantized_param(
+         self,
+         model: "ModelMixin",
+         param_value: "torch.Tensor",
+         param_name: str,
+         state_dict: Dict[str, Any],
+         **kwargs,
+     ) -> bool:
+         param_device = kwargs.pop("param_device", None)
+         # Check if the param_name is not in self.modules_to_not_convert
+         if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+             return False
+         elif param_device == "cpu" and self.offload:
+             # We don't quantize weights that we offload
+             return False
+         else:
+             # We only quantize the weight of nn.Linear
+             module, tensor_name = get_module_from_name(model, param_name)
+             return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
+
+     def create_quantized_param(
+         self,
+         model: "ModelMixin",
+         param_value: "torch.Tensor",
+         param_name: str,
+         target_device: "torch.device",
+         state_dict: Dict[str, Any],
+         unexpected_keys: List[str],
+     ):
+         r"""
+         Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor,
+         then we move it to the target device. Finally, we quantize the module.
+         """
+         module, tensor_name = get_module_from_name(model, param_name)
+
+         if self.pre_quantized:
+             # If we're loading pre-quantized weights, replace the repr of linear layers for pretty printing info
+             # about AffineQuantizedTensor
+             module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
+             if isinstance(module, nn.Linear):
+                 module.extra_repr = types.MethodType(_linear_extra_repr, module)
+         else:
+             # As we perform quantization here, the repr of linear layers is that of AQT, so we don't have to do it ourselves
+             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+             quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+
+     def _process_model_before_weight_loading(
+         self,
+         model: "ModelMixin",
+         device_map,
+         keep_in_fp32_modules: List[str] = [],
+         **kwargs,
+     ):
+         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
+
+         if not isinstance(self.modules_to_not_convert, list):
+             self.modules_to_not_convert = [self.modules_to_not_convert]
+
+         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+         if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+             keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+             self.modules_to_not_convert.extend(keys_on_cpu)
+
+         # Purge `None`.
+         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+         # in case of diffusion transformer models. For language models and others alike, `lm_head`
+         # and tied modules are usually kept in FP32.
+         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+         model.config.quantization_config = self.quantization_config
+
+     def _process_model_after_weight_loading(self, model: "ModelMixin"):
+         return model
+
+     def is_serializable(self, safe_serialization=None):
+         # TODO(aryan): needs to be tested
+         if safe_serialization:
+             logger.warning(
+                 "torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
+             )
+             return False
+
+         _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
+             "0.25.0"
+         )
+
+         if not _is_torchao_serializable:
+             logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
+
+         if self.offload and self.quantization_config.modules_to_not_convert is None:
+             logger.warning(
+                 "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
+                 "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
+             )
+             return False
+
+         return _is_torchao_serializable
+
+     @property
+     def is_trainable(self):
+         return self.quantization_config.quant_type.startswith("int8")
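A minimal usage sketch for the new quantizer (the checkpoint and quant_type here are illustrative; TorchAoConfig is the quantization config class this release adds under diffusers/quantizers):

    import torch
    from diffusers import FluxTransformer2DModel, TorchAoConfig

    # int8 weight-only quantization; per update_torch_dtype above, the quantizer
    # falls back to bfloat16 for int*/uint* quant types when no torch_dtype is given.
    quantization_config = TorchAoConfig("int8wo")
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
    )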
diffusers/schedulers/scheduling_ddim.py
@@ -463,7 +463,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
              prev_sample = prev_sample + variance

          if not return_dict:
 -            return (prev_sample,)
 +            return (
 +                prev_sample,
 +                pred_original_sample,
 +            )

          return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
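This change (mirrored in the CogVideoX and parallel scheduler hunks below) means `step(..., return_dict=False)` now yields a 2-tuple. A minimal sketch, assuming an already configured scheduler and appropriately shaped tensors:

    # Before this release the tuple had a single element; unpack both now.
    prev_sample, pred_original_sample = scheduler.step(
        model_output, t, sample, return_dict=False
    )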
diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -394,7 +394,10 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
          prev_sample = a_t * sample + b_t * pred_original_sample

          if not return_dict:
 -            return (prev_sample,)
 +            return (
 +                prev_sample,
 +                pred_original_sample,
 +            )

          return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

diffusers/schedulers/scheduling_ddim_parallel.py
@@ -480,7 +480,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
              prev_sample = prev_sample + variance

          if not return_dict:
 -            return (prev_sample,)
 +            return (
 +                prev_sample,
 +                pred_original_sample,
 +            )

          return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

diffusers/schedulers/scheduling_ddpm.py
@@ -492,7 +492,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
          pred_prev_sample = pred_prev_sample + variance

          if not return_dict:
 -            return (pred_prev_sample,)
 +            return (
 +                pred_prev_sample,
 +                pred_original_sample,
 +            )

          return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

@@ -545,16 +548,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
          return self.config.num_train_timesteps

      def previous_timestep(self, timestep):
 -        if self.custom_timesteps:
 +        if self.custom_timesteps or self.num_inference_steps:
              index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
              if index == self.timesteps.shape[0] - 1:
                  prev_t = torch.tensor(-1)
              else:
                  prev_t = self.timesteps[index + 1]
          else:
 -            num_inference_steps = (
 -                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
 -            )
 -            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
 -
 +            prev_t = timestep - 1
          return prev_t
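The rewritten fallback assumes consecutive training timesteps, while any grid set via set_timesteps() now goes through the lookup branch. A small sketch of both behaviors (printed values assume the default 1000 training timesteps):

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler()

    # No inference grid configured: the previous timestep is simply t - 1.
    print(scheduler.previous_timestep(torch.tensor(999)))  # tensor(998)

    # With a grid, the previous timestep is read from scheduler.timesteps,
    # reproducing the old num_train_timesteps // num_inference_steps stride.
    scheduler.set_timesteps(num_inference_steps=50)
    print(scheduler.previous_timestep(scheduler.timesteps[0]))  # tensor(960)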
diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -500,7 +500,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
          pred_prev_sample = pred_prev_sample + variance

          if not return_dict:
 -            return (pred_prev_sample,)
 +            return (
 +                pred_prev_sample,
 +                pred_original_sample,
 +            )

          return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

@@ -636,16 +639,12 @@

      # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
      def previous_timestep(self, timestep):
 -        if self.custom_timesteps:
 +        if self.custom_timesteps or self.num_inference_steps:
              index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
              if index == self.timesteps.shape[0] - 1:
                  prev_t = torch.tensor(-1)
              else:
                  prev_t = self.timesteps[index + 1]
          else:
 -            num_inference_steps = (
 -                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
 -            )
 -            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
 -
 +            prev_t = timestep - 1
          return prev_t
diffusers/schedulers/scheduling_deis_multistep.py
@@ -22,10 +22,14 @@ import numpy as np
  import torch

  from ..configuration_utils import ConfigMixin, register_to_config
 - from ..utils import deprecate
 + from ..utils import deprecate, is_scipy_available
  from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


 + if is_scipy_available():
 +     import scipy.stats
 +
 +
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
  def betas_for_alpha_bar(
      num_diffusion_timesteps,
@@ -111,6 +115,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
          use_karras_sigmas (`bool`, *optional*, defaults to `False`):
              Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
              the sigmas are determined according to a sequence of noise levels {σi}.
 +        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
 +            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
 +        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
 +            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
 +            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
          timestep_spacing (`str`, defaults to `"linspace"`):
              The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
              Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -138,9 +147,19 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
          solver_type: str = "logrho",
          lower_order_final: bool = True,
          use_karras_sigmas: Optional[bool] = False,
 +        use_exponential_sigmas: Optional[bool] = False,
 +        use_beta_sigmas: Optional[bool] = False,
 +        use_flow_sigmas: Optional[bool] = False,
 +        flow_shift: Optional[float] = 1.0,
          timestep_spacing: str = "linspace",
          steps_offset: int = 0,
      ):
 +        if self.config.use_beta_sigmas and not is_scipy_available():
 +            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
 +        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
 +            raise ValueError(
 +                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
 +            )
          if trained_betas is not None:
              self.betas = torch.tensor(trained_betas, dtype=torch.float32)
          elif beta_schedule == "linear":
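A usage sketch of the new flags (values illustrative; use_beta_sigmas requires scipy, and the sigma options are mutually exclusive per the check above):

    from diffusers import DEISMultistepScheduler

    # Enable exactly one alternative sigma schedule; combining them raises the ValueError above.
    scheduler = DEISMultistepScheduler(use_beta_sigmas=True)
    scheduler.set_timesteps(num_inference_steps=25)
    print(scheduler.sigmas[:5])  # highest noise levels of the beta schedule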
@@ -249,12 +268,28 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
          )

          sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
 +        log_sigmas = np.log(sigmas)
          if self.config.use_karras_sigmas:
 -            log_sigmas = np.log(sigmas)
              sigmas = np.flip(sigmas).copy()
              sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
              timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
              sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
 +        elif self.config.use_exponential_sigmas:
 +            sigmas = np.flip(sigmas).copy()
 +            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
 +            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 +            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
 +        elif self.config.use_beta_sigmas:
 +            sigmas = np.flip(sigmas).copy()
 +            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
 +            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 +            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
 +        elif self.config.use_flow_sigmas:
 +            alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
 +            sigmas = 1.0 - alphas
 +            sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
 +            timesteps = (sigmas * self.config.num_train_timesteps).copy()
 +            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
          else:
              sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
              sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -335,8 +370,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

      # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
      def _sigma_to_alpha_sigma_t(self, sigma):
 -        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
 -        sigma_t = sigma * alpha_t
 +        if self.config.use_flow_sigmas:
 +            alpha_t = 1 - sigma
 +            sigma_t = sigma
 +        else:
 +            alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
 +            sigma_t = sigma * alpha_t

          return alpha_t, sigma_t

@@ -366,6 +405,60 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
          sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
          return sigmas

 +    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
 +    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
 +        """Constructs an exponential noise schedule."""
 +
 +        # Hack to make sure that other schedulers which copy this function don't break
 +        # TODO: Add this logic to the other schedulers
 +        if hasattr(self.config, "sigma_min"):
 +            sigma_min = self.config.sigma_min
 +        else:
 +            sigma_min = None
 +
 +        if hasattr(self.config, "sigma_max"):
 +            sigma_max = self.config.sigma_max
 +        else:
 +            sigma_max = None
 +
 +        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
 +        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 +
 +        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
 +        return sigmas
 +
 +    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
 +    def _convert_to_beta(
 +        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
 +    ) -> torch.Tensor:
 +        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
 +
 +        # Hack to make sure that other schedulers which copy this function don't break
 +        # TODO: Add this logic to the other schedulers
 +        if hasattr(self.config, "sigma_min"):
 +            sigma_min = self.config.sigma_min
 +        else:
 +            sigma_min = None
 +
 +        if hasattr(self.config, "sigma_max"):
 +            sigma_max = self.config.sigma_max
 +        else:
 +            sigma_max = None
 +
 +        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
 +        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 +
 +        sigmas = np.array(
 +            [
 +                sigma_min + (ppf * (sigma_max - sigma_min))
 +                for ppf in [
 +                    scipy.stats.beta.ppf(timestep, alpha, beta)
 +                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
 +                ]
 +            ]
 +        )
 +        return sigmas
 +
      def convert_model_output(
          self,
          model_output: torch.Tensor,
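A standalone sketch of the exponential schedule added above, with illustrative endpoints: the sigmas are spaced log-uniformly from sigma_max down to sigma_min.

    import math

    import numpy as np

    sigma_min, sigma_max, steps = 0.03, 14.6, 5
    sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), steps))
    print(sigmas)  # approximately [14.6, 3.11, 0.66, 0.14, 0.03]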
@@ -409,10 +502,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
              x0_pred = model_output
          elif self.config.prediction_type == "v_prediction":
              x0_pred = alpha_t * sample - sigma_t * model_output
 +        elif self.config.prediction_type == "flow_prediction":
 +            sigma_t = self.sigmas[self.step_index]
 +            x0_pred = sample - sigma_t * model_output
          else:
              raise ValueError(
 -                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
 -                " `v_prediction` for the DEISMultistepScheduler."
 +                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
 +                "`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler."
              )

          if self.config.thresholding:
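Combined with use_flow_sigmas, this lets the scheduler drive flow-matching models, recovering x0 as sample - sigma_t * model_output. A configuration sketch with illustrative values:

    from diffusers import DEISMultistepScheduler

    scheduler = DEISMultistepScheduler(
        use_flow_sigmas=True,
        flow_shift=3.0,
        prediction_type="flow_prediction",
    )
    scheduler.set_timesteps(num_inference_steps=25)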