diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
@@ -47,7 +47,11 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
|
|
47
47
|
|
48
48
|
|
49
49
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
50
|
-
def betas_for_alpha_bar(
|
50
|
+
def betas_for_alpha_bar(
|
51
|
+
num_diffusion_timesteps,
|
52
|
+
max_beta=0.999,
|
53
|
+
alpha_transform_type="cosine",
|
54
|
+
):
|
51
55
|
"""
|
52
56
|
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
53
57
|
(1-beta) over time from t = [0,1].
|
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
|
60
64
|
num_diffusion_timesteps (`int`): the number of betas to produce.
|
61
65
|
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
62
66
|
prevent singularities.
|
67
|
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
68
|
+
Choose from `cosine` or `exp`
|
63
69
|
|
64
70
|
Returns:
|
65
71
|
betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
|
66
72
|
"""
|
73
|
+
if alpha_transform_type == "cosine":
|
67
74
|
|
68
|
-
|
69
|
-
|
75
|
+
def alpha_bar_fn(t):
|
76
|
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
77
|
+
|
78
|
+
elif alpha_transform_type == "exp":
|
79
|
+
|
80
|
+
def alpha_bar_fn(t):
|
81
|
+
return math.exp(t * -12.0)
|
82
|
+
|
83
|
+
else:
|
84
|
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
70
85
|
|
71
86
|
betas = []
|
72
87
|
for i in range(num_diffusion_timesteps):
|
73
88
|
t1 = i / num_diffusion_timesteps
|
74
89
|
t2 = (i + 1) / num_diffusion_timesteps
|
75
|
-
betas.append(min(1 -
|
90
|
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
76
91
|
return torch.tensor(betas, dtype=torch.float32)
|
77
92
|
|
78
93
|
|
@@ -107,6 +122,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
107
122
|
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
|
108
123
|
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
|
109
124
|
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
|
125
|
+
timestep_spacing (`str`, default `"linspace"`):
|
126
|
+
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
127
|
+
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
|
128
|
+
steps_offset (`int`, default `0`):
|
129
|
+
An offset added to the inference steps. You can use a combination of `steps_offset=1` and
|
130
|
+
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
131
|
+
stable diffusion.
|
110
132
|
"""
|
111
133
|
|
112
134
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -123,6 +145,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
123
145
|
prediction_type: str = "epsilon",
|
124
146
|
interpolation_type: str = "linear",
|
125
147
|
use_karras_sigmas: Optional[bool] = False,
|
148
|
+
timestep_spacing: str = "linspace",
|
149
|
+
steps_offset: int = 0,
|
126
150
|
):
|
127
151
|
if trained_betas is not None:
|
128
152
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
@@ -146,9 +170,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
146
170
|
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
|
147
171
|
self.sigmas = torch.from_numpy(sigmas)
|
148
172
|
|
149
|
-
# standard deviation of the initial noise distribution
|
150
|
-
self.init_noise_sigma = self.sigmas.max()
|
151
|
-
|
152
173
|
# setable values
|
153
174
|
self.num_inference_steps = None
|
154
175
|
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
|
@@ -156,6 +177,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
156
177
|
self.is_scale_input_called = False
|
157
178
|
self.use_karras_sigmas = use_karras_sigmas
|
158
179
|
|
180
|
+
@property
|
181
|
+
def init_noise_sigma(self):
|
182
|
+
# standard deviation of the initial noise distribution
|
183
|
+
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
184
|
+
return self.sigmas.max()
|
185
|
+
|
186
|
+
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
187
|
+
|
159
188
|
def scale_model_input(
|
160
189
|
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
161
190
|
) -> torch.FloatTensor:
|
@@ -191,7 +220,28 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
191
220
|
"""
|
192
221
|
self.num_inference_steps = num_inference_steps
|
193
222
|
|
194
|
-
|
223
|
+
# "linspace", "leading", "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
|
224
|
+
if self.config.timestep_spacing == "linspace":
|
225
|
+
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
|
226
|
+
::-1
|
227
|
+
].copy()
|
228
|
+
elif self.config.timestep_spacing == "leading":
|
229
|
+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
230
|
+
# creates integer timesteps by multiplying by ratio
|
231
|
+
# rounding (the result is kept as float) to avoid issues when num_inference_steps is a power of 3
|
232
|
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
233
|
+
timesteps += self.config.steps_offset
|
234
|
+
elif self.config.timestep_spacing == "trailing":
|
235
|
+
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
|
236
|
+
# creates integer timesteps by multiplying by ratio
|
237
|
+
# rounding (the result is kept as float) to avoid issues when num_inference_steps is a power of 3
|
238
|
+
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
|
239
|
+
timesteps -= 1
|
240
|
+
else:
|
241
|
+
raise ValueError(
|
242
|
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
243
|
+
)
|
244
|
+
|
195
245
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
196
246
|
log_sigmas = np.log(sigmas)
|
197
247
|
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import math
|
16
|
+
from collections import defaultdict
|
16
17
|
from typing import List, Optional, Tuple, Union
|
17
18
|
|
18
19
|
import numpy as np
|
@@ -23,7 +24,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
|
23
24
|
|
24
25
|
|
25
26
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
26
|
-
def betas_for_alpha_bar(
|
27
|
+
def betas_for_alpha_bar(
|
28
|
+
num_diffusion_timesteps,
|
29
|
+
max_beta=0.999,
|
30
|
+
alpha_transform_type="cosine",
|
31
|
+
):
|
27
32
|
"""
|
28
33
|
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
29
34
|
(1-beta) over time from t = [0,1].
|
@@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
|
36
41
|
num_diffusion_timesteps (`int`): the number of betas to produce.
|
37
42
|
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
38
43
|
prevent singularities.
|
44
|
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
45
|
+
Choose from `cosine` or `exp`
|
39
46
|
|
40
47
|
Returns:
|
41
48
|
betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
|
42
49
|
"""
|
50
|
+
if alpha_transform_type == "cosine":
|
43
51
|
|
44
|
-
|
45
|
-
|
52
|
+
def alpha_bar_fn(t):
|
53
|
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
54
|
+
|
55
|
+
elif alpha_transform_type == "exp":
|
56
|
+
|
57
|
+
def alpha_bar_fn(t):
|
58
|
+
return math.exp(t * -12.0)
|
59
|
+
|
60
|
+
else:
|
61
|
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
46
62
|
|
47
63
|
betas = []
|
48
64
|
for i in range(num_diffusion_timesteps):
|
49
65
|
t1 = i / num_diffusion_timesteps
|
50
66
|
t2 = (i + 1) / num_diffusion_timesteps
|
51
|
-
betas.append(min(1 -
|
67
|
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
52
68
|
return torch.tensor(betas, dtype=torch.float32)
|
53
69
|
|
54
70
|
|
@@ -74,10 +90,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
74
90
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
75
91
|
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
|
76
92
|
https://imagen.research.google/video/paper.pdf).
|
93
|
+
clip_sample (`bool`, default `True`):
|
94
|
+
option to clip predicted sample for numerical stability.
|
95
|
+
clip_sample_range (`float`, default `1.0`):
|
96
|
+
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
77
97
|
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
78
98
|
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
|
79
99
|
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
|
80
100
|
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
|
101
|
+
timestep_spacing (`str`, default `"linspace"`):
|
102
|
+
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
103
|
+
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
|
104
|
+
steps_offset (`int`, default `0`):
|
105
|
+
An offset added to the inference steps. You can use a combination of `steps_offset=1` and
|
106
|
+
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
107
|
+
stable diffusion.
|
81
108
|
"""
|
82
109
|
|
83
110
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -93,6 +120,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
93
120
|
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
94
121
|
prediction_type: str = "epsilon",
|
95
122
|
use_karras_sigmas: Optional[bool] = False,
|
123
|
+
clip_sample: Optional[bool] = False,
|
124
|
+
clip_sample_range: float = 1.0,
|
125
|
+
timestep_spacing: str = "linspace",
|
126
|
+
steps_offset: int = 0,
|
96
127
|
):
|
97
128
|
if trained_betas is not None:
|
98
129
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
@@ -105,7 +136,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
105
136
|
)
|
106
137
|
elif beta_schedule == "squaredcos_cap_v2":
|
107
138
|
# Glide cosine schedule
|
108
|
-
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
139
|
+
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
|
140
|
+
elif beta_schedule == "exp":
|
141
|
+
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
|
109
142
|
else:
|
110
143
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
111
144
|
|
@@ -122,12 +155,26 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
122
155
|
|
123
156
|
indices = (schedule_timesteps == timestep).nonzero()
|
124
157
|
|
125
|
-
|
126
|
-
|
158
|
+
# The sigma index that is taken for the **very** first `step`
|
159
|
+
# is always the second index (or the last index if there is only 1)
|
160
|
+
# This way we can ensure we don't accidentally skip a sigma in
|
161
|
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
162
|
+
if len(self._index_counter) == 0:
|
163
|
+
pos = 1 if len(indices) > 1 else 0
|
127
164
|
else:
|
128
|
-
|
165
|
+
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
166
|
+
pos = self._index_counter[timestep_int]
|
167
|
+
|
129
168
|
return indices[pos].item()
|
130
169
|
|
170
|
+
@property
|
171
|
+
def init_noise_sigma(self):
|
172
|
+
# standard deviation of the initial noise distribution
|
173
|
+
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
174
|
+
return self.sigmas.max()
|
175
|
+
|
176
|
+
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
177
|
+
|
131
178
|
def scale_model_input(
|
132
179
|
self,
|
133
180
|
sample: torch.FloatTensor,
|
@@ -166,13 +213,31 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
166
213
|
|
167
214
|
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
|
168
215
|
|
169
|
-
|
216
|
+
# "linspace", "leading", "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
|
217
|
+
if self.config.timestep_spacing == "linspace":
|
218
|
+
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
|
219
|
+
elif self.config.timestep_spacing == "leading":
|
220
|
+
step_ratio = num_train_timesteps // self.num_inference_steps
|
221
|
+
# creates integer timesteps by multiplying by ratio
|
222
|
+
# rounding (the result is kept as float) to avoid issues when num_inference_steps is a power of 3
|
223
|
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
224
|
+
timesteps += self.config.steps_offset
|
225
|
+
elif self.config.timestep_spacing == "trailing":
|
226
|
+
step_ratio = num_train_timesteps / self.num_inference_steps
|
227
|
+
# creates integer timesteps by multiplying by ratio
|
228
|
+
# rounding (the result is kept as float) to avoid issues when num_inference_steps is a power of 3
|
229
|
+
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
|
230
|
+
timesteps -= 1
|
231
|
+
else:
|
232
|
+
raise ValueError(
|
233
|
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
234
|
+
)
|
170
235
|
|
171
236
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
172
237
|
log_sigmas = np.log(sigmas)
|
173
238
|
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
174
239
|
|
175
|
-
if self.use_karras_sigmas:
|
240
|
+
if self.config.use_karras_sigmas:
|
176
241
|
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
177
242
|
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
178
243
|
|
@@ -180,9 +245,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
180
245
|
sigmas = torch.from_numpy(sigmas).to(device=device)
|
181
246
|
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
|
182
247
|
|
183
|
-
# standard deviation of the initial noise distribution
|
184
|
-
self.init_noise_sigma = self.sigmas.max()
|
185
|
-
|
186
248
|
timesteps = torch.from_numpy(timesteps)
|
187
249
|
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
|
188
250
|
|
@@ -196,6 +258,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
196
258
|
self.prev_derivative = None
|
197
259
|
self.dt = None
|
198
260
|
|
261
|
+
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
262
|
+
# we need an index counter
|
263
|
+
self._index_counter = defaultdict(int)
|
264
|
+
|
199
265
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
200
266
|
def _sigma_to_t(self, sigma, log_sigmas):
|
201
267
|
# get log sigma
|
@@ -260,6 +326,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
260
326
|
"""
|
261
327
|
step_index = self.index_for_timestep(timestep)
|
262
328
|
|
329
|
+
# advance index counter by 1
|
330
|
+
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
331
|
+
self._index_counter[timestep_int] += 1
|
332
|
+
|
263
333
|
if self.state_in_first_order:
|
264
334
|
sigma = self.sigmas[step_index]
|
265
335
|
sigma_next = self.sigmas[step_index + 1]
|
@@ -284,12 +354,17 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
284
354
|
sample / (sigma_input**2 + 1)
|
285
355
|
)
|
286
356
|
elif self.config.prediction_type == "sample":
|
287
|
-
|
357
|
+
pred_original_sample = model_output
|
288
358
|
else:
|
289
359
|
raise ValueError(
|
290
360
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
291
361
|
)
|
292
362
|
|
363
|
+
if self.config.clip_sample:
|
364
|
+
pred_original_sample = pred_original_sample.clamp(
|
365
|
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
366
|
+
)
|
367
|
+
|
293
368
|
if self.state_in_first_order:
|
294
369
|
# 2. Convert to an ODE derivative for 1st order
|
295
370
|
derivative = (sample - pred_original_sample) / sigma_hat
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import math
|
16
|
+
from collections import defaultdict
|
16
17
|
from typing import List, Optional, Tuple, Union
|
17
18
|
|
18
19
|
import numpy as np
|
@@ -24,7 +25,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
|
24
25
|
|
25
26
|
|
26
27
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
27
|
-
def betas_for_alpha_bar(
|
28
|
+
def betas_for_alpha_bar(
|
29
|
+
num_diffusion_timesteps,
|
30
|
+
max_beta=0.999,
|
31
|
+
alpha_transform_type="cosine",
|
32
|
+
):
|
28
33
|
"""
|
29
34
|
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
30
35
|
(1-beta) over time from t = [0,1].
|
@@ -37,19 +42,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
|
37
42
|
num_diffusion_timesteps (`int`): the number of betas to produce.
|
38
43
|
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
39
44
|
prevent singularities.
|
45
|
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
46
|
+
Choose from `cosine` or `exp`
|
40
47
|
|
41
48
|
Returns:
|
42
49
|
betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
|
43
50
|
"""
|
51
|
+
if alpha_transform_type == "cosine":
|
44
52
|
|
45
|
-
|
46
|
-
|
53
|
+
def alpha_bar_fn(t):
|
54
|
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
55
|
+
|
56
|
+
elif alpha_transform_type == "exp":
|
57
|
+
|
58
|
+
def alpha_bar_fn(t):
|
59
|
+
return math.exp(t * -12.0)
|
60
|
+
|
61
|
+
else:
|
62
|
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
47
63
|
|
48
64
|
betas = []
|
49
65
|
for i in range(num_diffusion_timesteps):
|
50
66
|
t1 = i / num_diffusion_timesteps
|
51
67
|
t2 = (i + 1) / num_diffusion_timesteps
|
52
|
-
betas.append(min(1 -
|
68
|
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
53
69
|
return torch.tensor(betas, dtype=torch.float32)
|
54
70
|
|
55
71
|
|
@@ -78,6 +94,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
78
94
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
79
95
|
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
|
80
96
|
https://imagen.research.google/video/paper.pdf)
|
97
|
+
timestep_spacing (`str`, default `"linspace"`):
|
98
|
+
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
99
|
+
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
|
100
|
+
steps_offset (`int`, default `0`):
|
101
|
+
An offset added to the inference steps. You can use a combination of `steps_offset=1` and
|
102
|
+
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
103
|
+
stable diffusion.
|
81
104
|
"""
|
82
105
|
|
83
106
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -92,6 +115,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
92
115
|
beta_schedule: str = "linear",
|
93
116
|
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
94
117
|
prediction_type: str = "epsilon",
|
118
|
+
timestep_spacing: str = "linspace",
|
119
|
+
steps_offset: int = 0,
|
95
120
|
):
|
96
121
|
if trained_betas is not None:
|
97
122
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
@@ -121,12 +146,26 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
121
146
|
|
122
147
|
indices = (schedule_timesteps == timestep).nonzero()
|
123
148
|
|
124
|
-
|
125
|
-
|
149
|
+
# The sigma index that is taken for the **very** first `step`
|
150
|
+
# is always the second index (or the last index if there is only 1)
|
151
|
+
# This way we can ensure we don't accidentally skip a sigma in
|
152
|
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
153
|
+
if len(self._index_counter) == 0:
|
154
|
+
pos = 1 if len(indices) > 1 else 0
|
126
155
|
else:
|
127
|
-
|
156
|
+
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
157
|
+
pos = self._index_counter[timestep_int]
|
158
|
+
|
128
159
|
return indices[pos].item()
|
129
160
|
|
161
|
+
@property
|
162
|
+
def init_noise_sigma(self):
|
163
|
+
# standard deviation of the initial noise distribution
|
164
|
+
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
165
|
+
return self.sigmas.max()
|
166
|
+
|
167
|
+
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
168
|
+
|
130
169
|
def scale_model_input(
|
131
170
|
self,
|
132
171
|
sample: torch.FloatTensor,
|
@@ -169,7 +208,25 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
169
208
|
|
170
209
|
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
|
171
210
|
|
172
|
-
|
211
|
+
# "linspace", "leading", "trailing" correspond to the annotations in Table 2 of https://arxiv.org/abs/2305.08891
|
212
|
+
if self.config.timestep_spacing == "linspace":
|
213
|
+
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
|
214
|
+
elif self.config.timestep_spacing == "leading":
|
215
|
+
step_ratio = num_train_timesteps // self.num_inference_steps
|
216
|
+
# creates integer timesteps by multiplying by ratio
|
217
|
+
# rounding (the result is kept as float) to avoid issues when num_inference_steps is a power of 3
|
218
|
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
219
|
+
timesteps += self.config.steps_offset
|
220
|
+
elif self.config.timestep_spacing == "trailing":
|
221
|
+
step_ratio = num_train_timesteps / self.num_inference_steps
|
222
|
+
# creates integer timesteps by multiplying by ratio
|
223
|
+
# rounding (the result is kept as float) to avoid issues when num_inference_steps is a power of 3
|
224
|
+
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
|
225
|
+
timesteps -= 1
|
226
|
+
else:
|
227
|
+
raise ValueError(
|
228
|
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
229
|
+
)
|
173
230
|
|
174
231
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
175
232
|
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
@@ -197,9 +254,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
197
254
|
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
|
198
255
|
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
|
199
256
|
|
200
|
-
# standard deviation of the initial noise distribution
|
201
|
-
self.init_noise_sigma = self.sigmas.max()
|
202
|
-
|
203
257
|
if str(device).startswith("mps"):
|
204
258
|
# mps does not support float64
|
205
259
|
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
@@ -213,6 +267,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
213
267
|
|
214
268
|
self.sample = None
|
215
269
|
|
270
|
+
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
271
|
+
# we need an index counter
|
272
|
+
self._index_counter = defaultdict(int)
|
273
|
+
|
216
274
|
def sigma_to_t(self, sigma):
|
217
275
|
# get log sigma
|
218
276
|
log_sigma = sigma.log()
|
@@ -263,6 +321,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
263
321
|
"""
|
264
322
|
step_index = self.index_for_timestep(timestep)
|
265
323
|
|
324
|
+
# advance index counter by 1
|
325
|
+
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
326
|
+
self._index_counter[timestep_int] += 1
|
327
|
+
|
266
328
|
if self.state_in_first_order:
|
267
329
|
sigma = self.sigmas[step_index]
|
268
330
|
sigma_interpol = self.sigmas_interpol[step_index]
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import math
|
16
|
+
from collections import defaultdict
|
16
17
|
from typing import List, Optional, Tuple, Union
|
17
18
|
|
18
19
|
import numpy as np
|
@@ -23,7 +24,11 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
|
|
23
24
|
|
24
25
|
|
25
26
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
26
|
-
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
27
|
+
def betas_for_alpha_bar(
|
28
|
+
num_diffusion_timesteps,
|
29
|
+
max_beta=0.999,
|
30
|
+
alpha_transform_type="cosine",
|
31
|
+
):
|
27
32
|
"""
|
28
33
|
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
29
34
|
(1-beta) over time from t = [0,1].
|
@@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
|
36
41
|
num_diffusion_timesteps (`int`): the number of betas to produce.
|
37
42
|
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
38
43
|
prevent singularities.
|
44
|
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
45
|
+
Choose from `cosine` or `exp`
|
39
46
|
|
40
47
|
Returns:
|
41
48
|
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
42
49
|
"""
|
50
|
+
if alpha_transform_type == "cosine":
|
43
51
|
|
44
|
-
|
45
|
-
|
52
|
+
def alpha_bar_fn(t):
|
53
|
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
54
|
+
|
55
|
+
elif alpha_transform_type == "exp":
|
56
|
+
|
57
|
+
def alpha_bar_fn(t):
|
58
|
+
return math.exp(t * -12.0)
|
59
|
+
|
60
|
+
else:
|
61
|
+
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
46
62
|
|
47
63
|
betas = []
|
48
64
|
for i in range(num_diffusion_timesteps):
|
49
65
|
t1 = i / num_diffusion_timesteps
|
50
66
|
t2 = (i + 1) / num_diffusion_timesteps
|
51
|
-
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
67
|
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
52
68
|
return torch.tensor(betas, dtype=torch.float32)
|
53
69
|
|
54
70
|
|
@@ -77,6 +93,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
77
93
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
78
94
|
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
79
95
|
https://imagen.research.google/video/paper.pdf)
|
96
|
+
timestep_spacing (`str`, default `"linspace"`):
|
97
|
+
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
98
|
+
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
|
99
|
+
steps_offset (`int`, default `0`):
|
100
|
+
an offset added to the inference steps. You can use a combination of `offset=1` and
|
101
|
+
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
102
|
+
stable diffusion.
|
80
103
|
"""
|
81
104
|
|
82
105
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -91,6 +114,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
91
114
|
beta_schedule: str = "linear",
|
92
115
|
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
93
116
|
prediction_type: str = "epsilon",
|
117
|
+
timestep_spacing: str = "linspace",
|
118
|
+
steps_offset: int = 0,
|
94
119
|
):
|
95
120
|
if trained_betas is not None:
|
96
121
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
@@ -120,12 +145,26 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
120
145
|
|
121
146
|
indices = (schedule_timesteps == timestep).nonzero()
|
122
147
|
|
123
|
-
|
124
|
-
|
148
|
+
# The sigma index that is taken for the **very** first `step`
|
149
|
+
# is always the second index (or the last index if there is only 1)
|
150
|
+
# This way we can ensure we don't accidentally skip a sigma in
|
151
|
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
152
|
+
if len(self._index_counter) == 0:
|
153
|
+
pos = 1 if len(indices) > 1 else 0
|
125
154
|
else:
|
126
|
-
|
155
|
+
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
156
|
+
pos = self._index_counter[timestep_int]
|
157
|
+
|
127
158
|
return indices[pos].item()
|
128
159
|
|
160
|
+
@property
|
161
|
+
def init_noise_sigma(self):
|
162
|
+
# standard deviation of the initial noise distribution
|
163
|
+
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
164
|
+
return self.sigmas.max()
|
165
|
+
|
166
|
+
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
167
|
+
|
129
168
|
def scale_model_input(
|
130
169
|
self,
|
131
170
|
sample: torch.FloatTensor,
|
@@ -168,7 +207,25 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
168
207
|
|
169
208
|
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
|
170
209
|
|
171
|
-
|
210
|
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
211
|
+
if self.config.timestep_spacing == "linspace":
|
212
|
+
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
|
213
|
+
elif self.config.timestep_spacing == "leading":
|
214
|
+
step_ratio = num_train_timesteps // self.num_inference_steps
|
215
|
+
# creates integer timesteps by multiplying by ratio
|
216
|
+
# casting to int to avoid issues when num_inference_step is power of 3
|
217
|
+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
|
218
|
+
timesteps += self.config.steps_offset
|
219
|
+
elif self.config.timestep_spacing == "trailing":
|
220
|
+
step_ratio = num_train_timesteps / self.num_inference_steps
|
221
|
+
# creates integer timesteps by multiplying by ratio
|
222
|
+
# casting to int to avoid issues when num_inference_step is power of 3
|
223
|
+
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
|
224
|
+
timesteps -= 1
|
225
|
+
else:
|
226
|
+
raise ValueError(
|
227
|
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
228
|
+
)
|
172
229
|
|
173
230
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
174
231
|
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
@@ -185,9 +242,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
185
242
|
[sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
|
186
243
|
)
|
187
244
|
|
188
|
-
# standard deviation of the initial noise distribution
|
189
|
-
self.init_noise_sigma = self.sigmas.max()
|
190
|
-
|
191
245
|
if str(device).startswith("mps"):
|
192
246
|
# mps does not support float64
|
193
247
|
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
|
@@ -202,6 +256,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
202
256
|
|
203
257
|
self.sample = None
|
204
258
|
|
259
|
+
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
260
|
+
# we need an index counter
|
261
|
+
self._index_counter = defaultdict(int)
|
262
|
+
|
205
263
|
def sigma_to_t(self, sigma):
|
206
264
|
# get log sigma
|
207
265
|
log_sigma = sigma.log()
|
@@ -251,6 +309,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
251
309
|
"""
|
252
310
|
step_index = self.index_for_timestep(timestep)
|
253
311
|
|
312
|
+
# advance index counter by 1
|
313
|
+
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
314
|
+
self._index_counter[timestep_int] += 1
|
315
|
+
|
254
316
|
if self.state_in_first_order:
|
255
317
|
sigma = self.sigmas[step_index]
|
256
318
|
sigma_interpol = self.sigmas_interpol[step_index + 1]
|