monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,294 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
#
|
12
|
+
# =========================================================================
|
13
|
+
# Adapted from https://github.com/huggingface/diffusers
|
14
|
+
# which has the following license:
|
15
|
+
# https://github.com/huggingface/diffusers/blob/main/LICENSE
|
16
|
+
#
|
17
|
+
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
18
|
+
#
|
19
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
20
|
+
# you may not use this file except in compliance with the License.
|
21
|
+
# You may obtain a copy of the License at
|
22
|
+
#
|
23
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
24
|
+
#
|
25
|
+
# Unless required by applicable law or agreed to in writing, software
|
26
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
27
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
28
|
+
# See the License for the specific language governing permissions and
|
29
|
+
# limitations under the License.
|
30
|
+
# =========================================================================
|
31
|
+
|
32
|
+
from __future__ import annotations
|
33
|
+
|
34
|
+
import numpy as np
|
35
|
+
import torch
|
36
|
+
|
37
|
+
from .ddpm import DDPMPredictionType
|
38
|
+
from .scheduler import Scheduler
|
39
|
+
|
40
|
+
# DDIM accepts exactly the same prediction types as DDPM, so the DDPM enum is reused under a DDIM-specific name.
DDIMPredictionType = DDPMPredictionType
|
41
|
+
|
42
|
+
|
43
|
+
class DDIMScheduler(Scheduler):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion
    Implicit Models" https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
        schedule: member of NoiseSchedules, name of noise schedule function in component store
        clip_sample: option to clip predicted sample between `clip_sample_min` and `clip_sample_max` for
            numerical stability.
        set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
            For the final step there is no previous alpha. When this option is `True` the previous alpha product is
            fixed to `1`, otherwise it uses the value of alpha at step 0.
        steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion. The offset must be smaller than `num_train_timesteps // num_inference_steps`, so
            when `steps_offset > 0` the default schedule built at construction uses
            `num_train_timesteps // (steps_offset + 1)` inference steps instead of `num_train_timesteps`.
            Call `set_timesteps` to choose a different number of inference steps.
        prediction_type: member of DDPMPredictionType
        clip_sample_min: minimum clipping value when clip_sample equals True
        clip_sample_max: maximum clipping value when clip_sample equals True
        schedule_args: arguments to pass to the schedule function

    Raises:
        ValueError: if `prediction_type` is not a member of DDIMPredictionType, or if `steps_offset` is not
            smaller than `num_train_timesteps` (no valid inference schedule exists in that case).
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        schedule: str = "linear_beta",
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = DDIMPredictionType.EPSILON,
        clip_sample_min: float = -1.0,
        clip_sample_max: float = 1.0,
        **schedule_args,
    ) -> None:
        super().__init__(num_train_timesteps, schedule, **schedule_args)

        if prediction_type not in DDIMPredictionType.__members__.values():
            raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")

        self.prediction_type = prediction_type

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        self.clip_sample = clip_sample
        self.clip_sample_values = [clip_sample_min, clip_sample_max]
        self.steps_offset = steps_offset

        # Build a default inference schedule. `set_timesteps` requires
        # steps_offset < num_train_timesteps // num_inference_steps, so defaulting to num_train_timesteps steps
        # (step ratio 1) would reject every positive offset. Use the largest number of steps whose ratio exceeds the
        # offset instead; for steps_offset <= 0 this is num_train_timesteps, exactly as before. If
        # steps_offset >= num_train_timesteps this yields 1 step and `set_timesteps` raises its usual ValueError.
        self.num_inference_steps: int
        default_inference_steps = max(1, self.num_train_timesteps // (max(self.steps_offset, 0) + 1))
        self.set_timesteps(default_inference_steps)

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Timesteps are evenly spaced by `num_train_timesteps // num_inference_steps`, stored in descending order as
        int64, and shifted by `steps_offset`.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.

        Raises:
            ValueError: if `num_inference_steps` exceeds `num_train_timesteps`, or if `steps_offset` is not smaller
                than `num_train_timesteps // num_inference_steps`.
        """
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        if self.steps_offset >= step_ratio:
            raise ValueError(
                f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
                f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
                f" the max train timestep."
            )

        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.steps_offset

    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
        """
        Compute sigma_t^2 / eta^2 of formula (16) in the DDIM paper for the transition timestep -> prev_timestep.

        Args:
            timestep: current timestep.
            prev_timestep: timestep being stepped to; a negative value selects `final_alpha_cumprod`.

        Returns:
            (1 - alpha_prod_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_prev)
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        eta: float = 0.0,
        generator: torch.Generator | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.
            eta: weight of noise for added noise in diffusion step.
            generator: random number generator.

        Returns:
            pred_prev_sample: Predicted previous sample
            pred_original_sample: Predicted original sample
        """
        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - model_output -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # predefinitions satisfy pylint/mypy, these values won't be ultimately used
        pred_original_sample = sample
        pred_epsilon = model_output

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.prediction_type == DDIMPredictionType.EPSILON:
            pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
            pred_epsilon = model_output
        elif self.prediction_type == DDIMPredictionType.SAMPLE:
            pred_original_sample = model_output
            pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

        # 4. Clip "predicted x_0"
        if self.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance**0.5

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon

        # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction

        if eta > 0:
            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
            device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu")
            noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
            # std_dev_t already equals eta * sqrt(variance); reuse it rather than recomputing the variance
            pred_prev_sample = pred_prev_sample + std_dev_t * noise

        return pred_prev_sample, pred_original_sample

    def reversed_step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        This is the deterministic (eta = 0) DDIM update run towards higher noise levels, e.g. for DDIM inversion.

        Args:
            model_output: direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.

        Returns:
            pred_post_sample: Predicted next (noisier) sample
            pred_original_sample: Predicted original sample
        """
        # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf

        # Notation (<variable name> -> <name in paper>
        # - model_output -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_post_sample -> "x_t+1"

        # 1. get next step value (=t+1)
        # NOTE(review): no upper-bound check; a next timestep >= num_train_timesteps raises IndexError below.
        next_timestep = timestep + self.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas at timestep t+1
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # predefinitions satisfy pylint/mypy, these values won't be ultimately used
        pred_original_sample = sample
        pred_epsilon = model_output

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.prediction_type == DDIMPredictionType.EPSILON:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.prediction_type == DDIMPredictionType.SAMPLE:
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

        # 4. Clip "predicted x_0"
        if self.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
            )

        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_next) ** (0.5) * pred_epsilon

        # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_post_sample = alpha_prod_t_next ** (0.5) * pred_original_sample + pred_sample_direction

        return pred_post_sample, pred_original_sample
|
@@ -0,0 +1,250 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
#
|
12
|
+
# =========================================================================
|
13
|
+
# Adapted from https://github.com/huggingface/diffusers
|
14
|
+
# which has the following license:
|
15
|
+
# https://github.com/huggingface/diffusers/blob/main/LICENSE
|
16
|
+
#
|
17
|
+
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
18
|
+
#
|
19
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
20
|
+
# you may not use this file except in compliance with the License.
|
21
|
+
# You may obtain a copy of the License at
|
22
|
+
#
|
23
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
24
|
+
#
|
25
|
+
# Unless required by applicable law or agreed to in writing, software
|
26
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
27
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
28
|
+
# See the License for the specific language governing permissions and
|
29
|
+
# limitations under the License.
|
30
|
+
# =========================================================================
|
31
|
+
|
32
|
+
from __future__ import annotations
|
33
|
+
|
34
|
+
import numpy as np
|
35
|
+
import torch
|
36
|
+
|
37
|
+
from monai.utils import StrEnum
|
38
|
+
|
39
|
+
from .scheduler import Scheduler
|
40
|
+
|
41
|
+
|
42
|
+
class DDPMVarianceType(StrEnum):
    """
    Valid names for DDPM Scheduler's `variance_type` argument. They select how the variance of the noise added at
    each reverse step (when producing the previous sample from the denoised estimate) is computed.
    """

    # posterior variance of q(x_{t-1} | x_t, x_0), clamped from below at 1e-20
    FIXED_SMALL = "fixed_small"
    # beta_t taken directly from the noise schedule
    FIXED_LARGE = "fixed_large"
    # variance predicted by the model, used as-is (falls back to the unclamped posterior variance if none is given)
    LEARNED = "learned"
    # model output (presumably in [-1, 1], mapped to [0, 1]) interpolates between the posterior variance and beta_t;
    # falls back to the unclamped posterior variance if no prediction is given
    LEARNED_RANGE = "learned_range"
|
52
|
+
|
53
|
+
|
54
|
+
class DDPMPredictionType(StrEnum):
    """
    Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. The value tells the
    scheduler how to interpret the diffusion model's output when estimating the original (noise-free) sample.

    epsilon: predicting the noise of the diffusion process
    sample: directly predicting the original (denoised) sample x_0, not the noisy sample x_t
    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
    """

    EPSILON = "epsilon"
    SAMPLE = "sample"
    V_PREDICTION = "v_prediction"
|
66
|
+
|
67
|
+
|
68
|
+
class DDPMScheduler(Scheduler):
    """
    Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
    Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models"
    https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps: number of diffusion steps used to train the model.
        schedule: member of NoiseSchedules, name of noise schedule function in component store
        variance_type: member of DDPMVarianceType
        clip_sample: option to clip predicted sample between `clip_sample_min` and `clip_sample_max` for
            numerical stability.
        prediction_type: member of DDPMPredictionType
        clip_sample_min: minimum clipping value when clip_sample equals True
        clip_sample_max: maximum clipping value when clip_sample equals True
        schedule_args: arguments to pass to the schedule function

    Raises:
        ValueError: if `variance_type` is not a member of DDPMVarianceType or `prediction_type` is not a member
            of DDPMPredictionType.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        schedule: str = "linear_beta",
        variance_type: str = DDPMVarianceType.FIXED_SMALL,
        clip_sample: bool = True,
        prediction_type: str = DDPMPredictionType.EPSILON,
        clip_sample_min: float = -1.0,
        clip_sample_max: float = 1.0,
        **schedule_args,
    ) -> None:
        super().__init__(num_train_timesteps, schedule, **schedule_args)

        if variance_type not in DDPMVarianceType.__members__.values():
            raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`")

        if prediction_type not in DDPMPredictionType.__members__.values():
            raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")

        self.clip_sample = clip_sample
        self.clip_sample_values = [clip_sample_min, clip_sample_max]
        self.variance_type = variance_type
        self.prediction_type = prediction_type

    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Timesteps are evenly spaced by `num_train_timesteps // num_inference_steps` and stored in descending order
        as int64.

        Args:
            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
            device: target device to put the data.

        Raises:
            ValueError: if `num_inference_steps` exceeds `num_train_timesteps`.
        """
        if num_inference_steps > self.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)

    def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        """
        Compute the mean of the posterior q(x_{t-1} | x_t, x_0) at timestep t.

        Args:
            timestep: current timestep.
            x_0: the noise-free input.
            x_t: the input noised to timestep t.

        Returns:
            Returns the mean
        """
        # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0),
        # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf)
        alpha_t = self.alphas[timestep]
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one

        x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
        x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)

        mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t

        return mean

    def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor:
        """
        Compute the variance of the posterior at timestep t.

        Args:
            timestep: current timestep.
            predicted_variance: variance predicted by the model.

        Returns:
            Returns the variance
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
        # hacks - were probably added for training stability
        if self.variance_type == DDPMVarianceType.FIXED_SMALL:
            variance = torch.clamp(variance, min=1e-20)
        elif self.variance_type == DDPMVarianceType.FIXED_LARGE:
            variance = self.betas[timestep]
        elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None:
            return predicted_variance
        elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None:
            # NOTE(review): despite the `*_log` names, this interpolates linearly in variance space; Improved DDPM
            # (Nichol & Dhariwal) interpolates log-variances. Confirm intent before relying on models trained
            # elsewhere with a learned range.
            min_log = variance
            max_log = self.betas[timestep]
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output: direct output from learned diffusion model. For learned variance types it may carry twice
                the channels of `sample` (dim 1), in which case the second half is taken as the predicted variance.
            timestep: current discrete timestep in the diffusion chain.
            sample: current instance of sample being created by diffusion process.
            generator: random number generator.

        Returns:
            pred_prev_sample: Predicted previous sample
            pred_original_sample: Predicted original sample
        """
        learned_variance_types = (DDPMVarianceType.LEARNED, DDPMVarianceType.LEARNED_RANGE)
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in learned_variance_types:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.prediction_type == DDPMPredictionType.EPSILON:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.prediction_type == DDPMPredictionType.SAMPLE:
            pred_original_sample = model_output
        elif self.prediction_type == DDPMPredictionType.V_PREDICTION:
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output

        # 3. Clip "predicted x_0"
        if self.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
        current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (no noise is added at the final step, timestep 0)
        variance = 0
        if timestep > 0:
            # noise is drawn on the generator's device, then moved next to the model output
            noise = torch.randn(
                model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
            ).to(model_output.device)
            variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise

        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample, pred_original_sample
|