monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/schedulers/ddim.py
@@ -0,0 +1,294 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # =========================================================================
+ # Adapted from https://github.com/huggingface/diffusers
+ # which has the following license:
+ # https://github.com/huggingface/diffusers/blob/main/LICENSE
+ #
+ # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========================================================================
+
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+
+ from .ddpm import DDPMPredictionType
+ from .scheduler import Scheduler
+
+ DDIMPredictionType = DDPMPredictionType
+
+
+ class DDIMScheduler(Scheduler):
+     """
+     The denoising diffusion implicit model (DDIM) scheduler extends the denoising procedure introduced in
+     denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising
+     Diffusion Implicit Models" https://arxiv.org/abs/2010.02502
+
+     Args:
+         num_train_timesteps: number of diffusion steps used to train the model.
+         schedule: member of NoiseSchedules, name of noise schedule function in component store
+         clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
+         set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
+             For the final step there is no previous alpha. When this option is `True` the previous alpha product is
+             fixed to `1`, otherwise it uses the value of alpha at step 0.
+         steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
+             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+             stable diffusion.
+         prediction_type: member of DDPMPredictionType
+         clip_sample_min: minimum clipping value when clip_sample equals True
+         clip_sample_max: maximum clipping value when clip_sample equals True
+         schedule_args: arguments to pass to the schedule function
+
+     """
+
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         schedule: str = "linear_beta",
+         clip_sample: bool = True,
+         set_alpha_to_one: bool = True,
+         steps_offset: int = 0,
+         prediction_type: str = DDIMPredictionType.EPSILON,
+         clip_sample_min: float = -1.0,
+         clip_sample_max: float = 1.0,
+         **schedule_args,
+     ) -> None:
+         super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+         if prediction_type not in DDIMPredictionType.__members__.values():
+             raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")
+
+         self.prediction_type = prediction_type
+
+         # At every step in ddim, we are looking into the previous alphas_cumprod
+         # For the final step, there is no previous alphas_cumprod because we are already at 0
+         # `set_alpha_to_one` decides whether we set this parameter simply to one or
+         # whether we use the final alpha of the "non-previous" one.
+         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+         # standard deviation of the initial noise distribution
+         self.init_noise_sigma = 1.0
+
+         self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
+
+         self.clip_sample = clip_sample
+         self.clip_sample_values = [clip_sample_min, clip_sample_max]
+         self.steps_offset = steps_offset
+
+         # default the number of inference timesteps to the number of train steps
+         self.num_inference_steps: int
+         self.set_timesteps(self.num_train_timesteps)
+
+     def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+         """
+         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+         Args:
+             num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+             device: target device to put the data.
+         """
+         if num_inference_steps > self.num_train_timesteps:
+             raise ValueError(
+                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+                 f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                 f" maximal {self.num_train_timesteps} timesteps."
+             )
+
+         self.num_inference_steps = num_inference_steps
+         step_ratio = self.num_train_timesteps // self.num_inference_steps
+         if self.steps_offset >= step_ratio:
+             raise ValueError(
+                 f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
+                 f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
+                 f" the max train timestep."
+             )
+
+         # creates integer timesteps by multiplying by ratio
+         # casting to int to avoid issues when num_inference_step is power of 3
+         timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+         self.timesteps = torch.from_numpy(timesteps).to(device)
+         self.timesteps += self.steps_offset
+
+     def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+         return variance
+
+     def step(
+         self,
+         model_output: torch.Tensor,
+         timestep: int,
+         sample: torch.Tensor,
+         eta: float = 0.0,
+         generator: torch.Generator | None = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output: direct output from learned diffusion model.
+             timestep: current discrete timestep in the diffusion chain.
+             sample: current instance of sample being created by diffusion process.
+             eta: weight of noise for added noise in diffusion step.
+             generator: random number generator.
+
+         Returns:
+             pred_prev_sample: Predicted previous sample
+             pred_original_sample: Predicted original sample
+         """
+         # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+         # Ideally, read the DDIM paper in detail to understand the following.
+
+         # Notation (<variable name> -> <name in paper>)
+         # - model_output -> e_theta(x_t, t)
+         # - pred_original_sample -> f_theta(x_t, t) or x_0
+         # - std_dev_t -> sigma_t
+         # - eta -> η
+         # - pred_sample_direction -> "direction pointing to x_t"
+         # - pred_prev_sample -> "x_t-1"
+
+         # 1. get previous step value (=t-1)
+         prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+
+         # 2. compute alphas, betas
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+         beta_prod_t = 1 - alpha_prod_t
+
+         # predefinitions to satisfy pylint/mypy; these defaults are overwritten below
+         pred_original_sample = sample
+         pred_epsilon = model_output
+
+         # 3. compute predicted original sample from predicted noise, also called
+         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         if self.prediction_type == DDIMPredictionType.EPSILON:
+             pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
+             pred_epsilon = model_output
+         elif self.prediction_type == DDIMPredictionType.SAMPLE:
+             pred_original_sample = model_output
+             pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
+         elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
+             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+         # 4. Clip "predicted x_0"
+         if self.clip_sample:
+             pred_original_sample = torch.clamp(
+                 pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+             )
+
+         # 5. compute variance: "sigma_t(η)" -> see formula (16)
+         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+         variance = self._get_variance(timestep, prev_timestep)
+         std_dev_t = eta * variance**0.5
+
+         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
+
+         # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
+
+         if eta > 0:
+             # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
+             device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu")
+             noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
+             variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise
+
+             pred_prev_sample = pred_prev_sample + variance
+
+         return pred_prev_sample, pred_original_sample
+
+     def reversed_step(
+         self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output: direct output from learned diffusion model.
+             timestep: current discrete timestep in the diffusion chain.
+             sample: current instance of sample being created by diffusion process.
+
+         Returns:
+             pred_post_sample: Predicted next sample
+             pred_original_sample: Predicted original sample
+         """
+         # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
+
+         # Notation (<variable name> -> <name in paper>)
+         # - model_output -> e_theta(x_t, t)
+         # - pred_original_sample -> f_theta(x_t, t) or x_0
+         # - std_dev_t -> sigma_t
+         # - eta -> η
+         # - pred_sample_direction -> "direction pointing to x_t"
+         # - pred_post_sample -> "x_t+1"
+
+         # 1. get next step value (=t+1)
+         prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+         # 2. compute alphas, betas at timestep t+1
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+         beta_prod_t = 1 - alpha_prod_t
+
+         # predefinitions to satisfy pylint/mypy; these defaults are overwritten below
+         pred_original_sample = sample
+         pred_epsilon = model_output
+
+         # 3. compute predicted original sample from predicted noise, also called
+         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+
+         if self.prediction_type == DDIMPredictionType.EPSILON:
+             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+             pred_epsilon = model_output
+         elif self.prediction_type == DDIMPredictionType.SAMPLE:
+             pred_original_sample = model_output
+             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+         elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
+             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+         # 4. Clip "predicted x_0"
+         if self.clip_sample:
+             pred_original_sample = torch.clamp(
+                 pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+             )
+
+         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+
+         # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+         pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+         return pred_post_sample, pred_original_sample
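
For orientation, below is a minimal sampling-loop sketch using the DDIMScheduler added above. The `model` callable and the sample shape are placeholders for illustration only and are not part of this diff; a trained noise-prediction network (such as the DiffusionModelUNet also added in this release) would be used in practice.

import torch
from monai.networks.schedulers.ddim import DDIMScheduler

# stand-in for a trained epsilon-prediction network (hypothetical, for illustration only)
model = lambda x, t: torch.randn_like(x)

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)  # 50 evenly spaced inference steps

sample = torch.randn(1, 1, 64, 64)  # start from pure noise; shape is illustrative
for t in scheduler.timesteps:
    with torch.no_grad():
        model_output = model(sample, t)
    # eta=0.0 gives the deterministic DDIM update; eta > 0 adds stochastic noise
    sample, pred_x0 = scheduler.step(model_output, int(t), sample, eta=0.0)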
monai/networks/schedulers/ddpm.py
@@ -0,0 +1,250 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # =========================================================================
+ # Adapted from https://github.com/huggingface/diffusers
+ # which has the following license:
+ # https://github.com/huggingface/diffusers/blob/main/LICENSE
+ #
+ # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========================================================================
+
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+
+ from monai.utils import StrEnum
+
+ from .scheduler import Scheduler
+
+
+ class DDPMVarianceType(StrEnum):
+     """
+     Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise
+     to the denoised sample.
+     """
+
+     FIXED_SMALL = "fixed_small"
+     FIXED_LARGE = "fixed_large"
+     LEARNED = "learned"
+     LEARNED_RANGE = "learned_range"
+
+
+ class DDPMPredictionType(StrEnum):
+     """
+     Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument.
+
+     epsilon: predicting the noise of the diffusion process
+     sample: directly predicting the noisy sample
+     v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+     """
+
+     EPSILON = "epsilon"
+     SAMPLE = "sample"
+     V_PREDICTION = "v_prediction"
+
+
+ class DDPMScheduler(Scheduler):
+     """
+     Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+     Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models"
+     https://arxiv.org/abs/2006.11239
+
+     Args:
+         num_train_timesteps: number of diffusion steps used to train the model.
+         schedule: member of NoiseSchedules, name of noise schedule function in component store
+         variance_type: member of DDPMVarianceType
+         clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
+         prediction_type: member of DDPMPredictionType
+         clip_sample_min: minimum clipping value when clip_sample equals True
+         clip_sample_max: maximum clipping value when clip_sample equals True
+         schedule_args: arguments to pass to the schedule function
+     """
+
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         schedule: str = "linear_beta",
+         variance_type: str = DDPMVarianceType.FIXED_SMALL,
+         clip_sample: bool = True,
+         prediction_type: str = DDPMPredictionType.EPSILON,
+         clip_sample_min: float = -1.0,
+         clip_sample_max: float = 1.0,
+         **schedule_args,
+     ) -> None:
+         super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+         if variance_type not in DDPMVarianceType.__members__.values():
+             raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`")
+
+         if prediction_type not in DDPMPredictionType.__members__.values():
+             raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")
+
+         self.clip_sample = clip_sample
+         self.clip_sample_values = [clip_sample_min, clip_sample_max]
+         self.variance_type = variance_type
+         self.prediction_type = prediction_type
+
+     def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+         """
+         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+         Args:
+             num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+             device: target device to put the data.
+         """
+         if num_inference_steps > self.num_train_timesteps:
+             raise ValueError(
+                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+                 f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                 f" maximal {self.num_train_timesteps} timesteps."
+             )
+
+         self.num_inference_steps = num_inference_steps
+         step_ratio = self.num_train_timesteps // self.num_inference_steps
+         # creates integer timesteps by multiplying by ratio
+         # casting to int to avoid issues when num_inference_step is power of 3
+         timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
+         self.timesteps = torch.from_numpy(timesteps).to(device)
+
+     def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
+         """
+         Compute the mean of the posterior at timestep t.
+
+         Args:
+             timestep: current timestep.
+             x_0: the noise-free input.
+             x_t: the input noised to timestep t.
+
+         Returns:
+             Returns the mean
+         """
+         # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0),
+         # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf)
+         alpha_t = self.alphas[timestep]
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+
+         x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
+         x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
+
+         mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t
+
+         return mean
+
+     def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor:
+         """
+         Compute the variance of the posterior at timestep t.
+
+         Args:
+             timestep: current timestep.
+             predicted_variance: variance predicted by the model.
+
+         Returns:
+             Returns the variance
+         """
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+
+         # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+         # and sample from it to get previous sample
+         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+         variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
+         # hacks - were probably added for training stability
+         if self.variance_type == DDPMVarianceType.FIXED_SMALL:
+             variance = torch.clamp(variance, min=1e-20)
+         elif self.variance_type == DDPMVarianceType.FIXED_LARGE:
+             variance = self.betas[timestep]
+         elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None:
+             return predicted_variance
+         elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None:
+             min_log = variance
+             max_log = self.betas[timestep]
+             frac = (predicted_variance + 1) / 2
+             variance = frac * max_log + (1 - frac) * min_log
+
+         return variance
+
+     def step(
+         self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output: direct output from learned diffusion model.
+             timestep: current discrete timestep in the diffusion chain.
+             sample: current instance of sample being created by diffusion process.
+             generator: random number generator.
+
+         Returns:
+             pred_prev_sample: Predicted previous sample
+             pred_original_sample: Predicted original sample
+         """
+         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+         else:
+             predicted_variance = None
+
+         # 1. compute alphas, betas
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         # 2. compute predicted original sample from predicted noise, also called
+         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+         if self.prediction_type == DDPMPredictionType.EPSILON:
+             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+         elif self.prediction_type == DDPMPredictionType.SAMPLE:
+             pred_original_sample = model_output
+         elif self.prediction_type == DDPMPredictionType.V_PREDICTION:
+             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+
+         # 3. Clip "predicted x_0"
+         if self.clip_sample:
+             pred_original_sample = torch.clamp(
+                 pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+             )
+
+         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+         pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
+         current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+         # 5. Compute predicted previous sample µ_t
+         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+         # 6. Add noise
+         variance = 0
+         if timestep > 0:
+             noise = torch.randn(
+                 model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
+             ).to(model_output.device)
+             variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
+
+         pred_prev_sample = pred_prev_sample + variance
+
+         return pred_prev_sample, pred_original_sample
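
Similarly, a minimal ancestral-sampling sketch for the DDPMScheduler above; `model` again stands in for a trained noise-prediction network and, like the sample shape, is an assumption for illustration rather than part of this diff.

import torch
from monai.networks.schedulers.ddpm import DDPMScheduler

# stand-in for a trained epsilon-prediction network (hypothetical, for illustration only)
model = lambda x, t: torch.randn_like(x)

scheduler = DDPMScheduler(num_train_timesteps=1000, variance_type="fixed_small")
scheduler.set_timesteps(num_inference_steps=1000)  # DDPM sampling typically uses every training step

sample = torch.randn(1, 1, 64, 64)  # start from pure noise; shape is illustrative
generator = torch.Generator().manual_seed(0)
for t in scheduler.timesteps:
    with torch.no_grad():
        model_output = model(sample, t)
    # fresh noise scaled by the posterior variance is added at every step except t == 0
    sample, pred_x0 = scheduler.step(model_output, int(t), sample, generator=generator)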