diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (120) hide show
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
@@ -28,9 +28,12 @@ try:
28
28
  except OptionalDependencyNotAvailable:
29
29
  from ..utils.dummy_pt_objects import * # noqa F403
30
30
  else:
31
+ from .scheduling_consistency_models import CMStochasticIterativeScheduler
31
32
  from .scheduling_ddim import DDIMScheduler
32
33
  from .scheduling_ddim_inverse import DDIMInverseScheduler
34
+ from .scheduling_ddim_parallel import DDIMParallelScheduler
33
35
  from .scheduling_ddpm import DDPMScheduler
36
+ from .scheduling_ddpm_parallel import DDPMParallelScheduler
34
37
  from .scheduling_deis_multistep import DEISMultistepScheduler
35
38
  from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
36
39
  from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
@@ -0,0 +1,380 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput, logging, randn_tensor
23
+ from .scheduling_utils import SchedulerMixin
24
+
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
@dataclass
class CMStochasticIterativeSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample computed for the previous timestep (x_{t-1}). Feed it back in as the model input for the next
            iteration of the denoising loop.
    """

    prev_sample: torch.FloatTensor
41
+
42
+
43
class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
    """
    Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the
    paper [1].

    [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
    https://arxiv.org/pdf/2303.01469

    [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        sigma_min (`float`):
            Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation.
        sigma_max (`float`):
            Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation.
        sigma_data (`float`):
            The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the
            original implementation, which is also the original value suggested in the EDM paper.
        s_noise (`float`):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011]. This was set to 1.0 in the original implementation.
        rho (`float`):
            The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was
            set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper.
        clip_denoised (`bool`):
            Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`.

    An explicit timestep schedule is not a constructor argument; pass it to `set_timesteps(timesteps=...)`, which
    requires the values to be in strictly descending order.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 40,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        s_noise: float = 1.0,
        rho: float = 7.0,
        clip_denoised: bool = True,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # Default schedule: one Karras sigma per training timestep, from sigma_max (ramp=0) to sigma_min (ramp=1).
        ramp = np.linspace(0, 1, num_train_timesteps)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        # settable values
        self.num_inference_steps = None
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps)
        self.custom_timesteps = False
        self.is_scale_input_called = False

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Returns the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        The lookup is an exact equality match; `.item()` raises if the timestep matches zero or more than one entry.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()
        return indices.item()

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Scales the consistency model input by dividing it by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM
        model.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        # Get sigma corresponding to timestep
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_idx = self.index_for_timestep(timestep)
        sigma = self.sigmas[step_idx]

        sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)

        # Remembered so that `step` can warn when scaling was skipped.
        self.is_scale_input_called = True
        return sample

    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
        """
        Gets scaled timesteps from the Karras sigmas, for input to the consistency model.

        Args:
            sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas
        Returns:
            `float` or `np.ndarray`: scaled input timestep or scaled input timestep array
        """
        if not isinstance(sigmas, np.ndarray):
            sigmas = np.array(sigmas, dtype=np.float64)

        # The tiny offset keeps log() finite when a sigma is exactly zero.
        timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44)

        return timesteps

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, optional):
                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
                must be `None`. Must be strictly descending and start below `num_train_timesteps`.

        Raises:
            `ValueError`: if neither or both of `num_inference_steps` and `timesteps` are given, if `timesteps` is
                not strictly descending or starts at or above `num_train_timesteps`, or if `num_inference_steps`
                exceeds `num_train_timesteps`.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

        # Follow DDPMScheduler custom timesteps logic
        if timesteps is not None:
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`timesteps` must be in descending order.")

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.custom_timesteps = True
        else:
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            self.num_inference_steps = num_inference_steps

            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            self.custom_timesteps = False

        # Map timesteps to Karras sigmas directly for multistep sampling
        # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
        num_train_timesteps = self.config.num_train_timesteps
        ramp = timesteps[::-1].copy()
        ramp = ramp / (num_train_timesteps - 1)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        # Append sigma_min as the terminal sigma. It is read from the config, like everywhere else in this class;
        # `__init__` never stores it as an instance attribute.
        sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

    # Modified _convert_to_karras implementation that takes in ramp as argument
    def _convert_to_karras(self, ramp):
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = self.config.sigma_min
        sigma_max: float = self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        # ramp=0 yields sigma_max and ramp=1 yields sigma_min.
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def get_scalings(self, sigma):
        """
        Computes the `(c_skip, c_out)` pair for `sigma` without the `sigma_min` offset used by
        `get_scalings_for_boundary_condition`.
        """
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    def get_scalings_for_boundary_condition(self, sigma):
        """
        Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper.
        This enforces the consistency model boundary condition.

        Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min.

        Args:
            sigma (`torch.FloatTensor`):
                The current sigma in the Karras sigma schedule.
        Returns:
            `tuple`:
                A two-element tuple where c_skip (which weights the current sample) is the first element and c_out
                (which weights the consistency model output) is the second element.
        """
        sigma_min = self.config.sigma_min
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
        c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain; must be one of `scheduler.timesteps`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, *optional*): Random number generator.
            return_dict (`bool`):
                option for returning tuple rather than [`CMStochasticIterativeSchedulerOutput`] class
        Returns:
            [`CMStochasticIterativeSchedulerOutput`] or `tuple`:
            [`CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
            tuple, the first element is the sample tensor.

        Raises:
            `ValueError`: if `timestep` is an integer or an integer tensor (i.e. a loop index rather than a value
                taken from `scheduler.timesteps`).
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    f" `{self.__class__}.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        sigma_min = self.config.sigma_min
        sigma_max = self.config.sigma_max

        step_index = self.index_for_timestep(timestep)

        # sigma_next corresponds to next_t in original implementation
        sigma = self.sigmas[step_index]
        # Bound the lookahead by the real length of the sigma table, which differs between the default schedule
        # built in `__init__` and the one built by `set_timesteps` (that one has a trailing sigma_min entry).
        if step_index + 1 < len(self.sigmas):
            sigma_next = self.sigmas[step_index + 1]
        else:
            # No further entry: fall back to the last sigma in the table
            sigma_next = self.sigmas[-1]

        # Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)

        # 1. Denoise model output using boundary conditions
        denoised = c_out * model_output + c_skip * sample
        if self.config.clip_denoised:
            denoised = denoised.clamp(-1, 1)

        # 2. Sample z ~ N(0, s_noise^2 * I)
        # Noise is not used for onestep sampling.
        if len(self.timesteps) > 1:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
        else:
            noise = torch.zeros_like(model_output)
        z = noise * self.config.s_noise

        sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max)

        # 3. Return noisy sample
        # tau = sigma_hat, eps = sigma_min
        prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5

        if not return_dict:
            return (prev_sample,)

        return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
@@ -47,7 +47,11 @@ class DDIMSchedulerOutput(BaseOutput):
47
47
 
48
48
 
49
49
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
50
- def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
50
+ def betas_for_alpha_bar(
51
+ num_diffusion_timesteps,
52
+ max_beta=0.999,
53
+ alpha_transform_type="cosine",
54
+ ):
51
55
  """
52
56
  Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
53
57
  (1-beta) over time from t = [0,1].
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
60
64
  num_diffusion_timesteps (`int`): the number of betas to produce.
61
65
  max_beta (`float`): the maximum beta to use; use values lower than 1 to
62
66
  prevent singularities.
67
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
68
+ Choose from `cosine` or `exp`
63
69
 
64
70
  Returns:
65
71
  betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
66
72
  """
73
+ if alpha_transform_type == "cosine":
67
74
 
68
- def alpha_bar(time_step):
69
- return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
75
+ def alpha_bar_fn(t):
76
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
77
+
78
+ elif alpha_transform_type == "exp":
79
+
80
+ def alpha_bar_fn(t):
81
+ return math.exp(t * -12.0)
82
+
83
+ else:
84
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
70
85
 
71
86
  betas = []
72
87
  for i in range(num_diffusion_timesteps):
73
88
  t1 = i / num_diffusion_timesteps
74
89
  t2 = (i + 1) / num_diffusion_timesteps
75
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
90
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
76
91
  return torch.tensor(betas, dtype=torch.float32)
77
92
 
78
93
 
@@ -302,8 +317,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
302
317
 
303
318
  self.num_inference_steps = num_inference_steps
304
319
 
305
- # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
306
- if self.config.timestep_spacing == "leading":
320
+ # "linspace", "leading", and "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
321
+ if self.config.timestep_spacing == "linspace":
322
+ timesteps = (
323
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
324
+ .round()[::-1]
325
+ .copy()
326
+ .astype(np.int64)
327
+ )
328
+ elif self.config.timestep_spacing == "leading":
307
329
  step_ratio = self.config.num_train_timesteps // self.num_inference_steps
308
330
  # creates integer timesteps by multiplying by ratio
309
331
  # casting to int to avoid issues when num_inference_step is power of 3
@@ -46,7 +46,11 @@ class DDIMSchedulerOutput(BaseOutput):
46
46
 
47
47
 
48
48
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
49
- def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
49
+ def betas_for_alpha_bar(
50
+ num_diffusion_timesteps,
51
+ max_beta=0.999,
52
+ alpha_transform_type="cosine",
53
+ ):
50
54
  """
51
55
  Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
52
56
  (1-beta) over time from t = [0,1].
@@ -59,19 +63,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
59
63
  num_diffusion_timesteps (`int`): the number of betas to produce.
60
64
  max_beta (`float`): the maximum beta to use; use values lower than 1 to
61
65
  prevent singularities.
66
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
67
+ Choose from `cosine` or `exp`
62
68
 
63
69
  Returns:
64
70
  betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
65
71
  """
72
+ if alpha_transform_type == "cosine":
66
73
 
67
- def alpha_bar(time_step):
68
- return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
74
+ def alpha_bar_fn(t):
75
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
76
+
77
+ elif alpha_transform_type == "exp":
78
+
79
+ def alpha_bar_fn(t):
80
+ return math.exp(t * -12.0)
81
+
82
+ else:
83
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
69
84
 
70
85
  betas = []
71
86
  for i in range(num_diffusion_timesteps):
72
87
  t1 = i / num_diffusion_timesteps
73
88
  t2 = (i + 1) / num_diffusion_timesteps
74
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
89
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
75
90
  return torch.tensor(betas, dtype=torch.float32)
76
91
 
77
92