monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/schedulers/pndm.py ADDED
@@ -0,0 +1,316 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import torch
+
+from monai.utils import StrEnum
+
+from .scheduler import Scheduler
+
+
+class PNDMPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.
+
+    epsilon: predicting the noise of the diffusion process
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
+
+    EPSILON = "epsilon"
+    V_PREDICTION = "v_prediction"
+
+
+class PNDMScheduler(Scheduler):
+    """
+    Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+    namely the Runge-Kutta method and a linear multi-step method. Based on: Liu et al.,
+    "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778
+
+    Args:
+        num_train_timesteps: number of diffusion steps used to train the model.
+        schedule: member of NoiseSchedules, name of noise schedule function in component store
+        skip_prk_steps:
+            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+            before PLMS steps.
+        set_alpha_to_one:
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        prediction_type: member of PNDMPredictionType
+        steps_offset:
+            an offset added to the inference steps. You can use a combination of `steps_offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
+        schedule_args: arguments to pass to the schedule function
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        schedule: str = "linear_beta",
+        skip_prk_steps: bool = False,
+        set_alpha_to_one: bool = False,
+        prediction_type: str = PNDMPredictionType.EPSILON,
+        steps_offset: int = 0,
+        **schedule_args,
+    ) -> None:
+        super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+        if prediction_type not in PNDMPredictionType.__members__.values():
+            raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType")
+
+        self.prediction_type = prediction_type
+
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
+        # For now we only support F-PNDM, i.e. the Runge-Kutta method
+        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+        # mainly at formulae (9), (12), (13) and Algorithm 2.
+        self.pndm_order = 4
+
+        self.skip_prk_steps = skip_prk_steps
+        self.steps_offset = steps_offset
+
+        # running values
+        self.cur_model_output = torch.Tensor()
+        self.counter = 0
+        self.cur_sample = torch.Tensor()
+        self.ets: list = []
+
+        # default the number of inference timesteps to the number of train steps
+        self.set_timesteps(num_train_timesteps)
+
+    def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+        """
+        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+            device: target device to put the data.
+        """
+        if num_inference_steps > self.num_train_timesteps:
+            raise ValueError(
+                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+                f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                f" maximal {self.num_train_timesteps} timesteps."
+            )
+
+        self.num_inference_steps = num_inference_steps
+        step_ratio = self.num_train_timesteps // self.num_inference_steps
+        # creates integer timesteps by multiplying by ratio
+        # casting to int to avoid issues when num_inference_steps is a power of 3
+        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
+        self._timesteps += self.steps_offset
+
+        if self.skip_prk_steps:
+            # for some models like stable diffusion the prk steps can/should be skipped to
+            # produce better results. When using PNDM with `self.skip_prk_steps` the implementation
+            # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+            self.prk_timesteps = np.array([])
+            self.plms_timesteps = self._timesteps[::-1]
+
+        else:
+            prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+                np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+            )
+            self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+            self.plms_timesteps = self._timesteps[:-3][
+                ::-1
+            ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
+        # update num_inference_steps - necessary if we use prk steps
+        self.num_inference_steps = len(self.timesteps)
+
+        self.ets = []
+        self.counter = 0
+
+    def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+        This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+        Returns:
+            pred_prev_sample: predicted previous sample
+        """
+        # return a tuple for consistency with samplers that return (previous pred, original sample pred)
+
+        if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps:
+            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None
+        else:
+            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None
+
+    def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor:
+        """
+        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+        solution to the differential equation.
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+
+        Returns:
+            pred_prev_sample: predicted previous sample
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2
+        prev_timestep = timestep - diff_to_prev
+        timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+        if self.counter % 4 == 0:
+            self.cur_model_output = 1 / 6 * model_output
+            self.ets.append(model_output)
+            self.cur_sample = sample
+        elif (self.counter - 1) % 4 == 0:
+            self.cur_model_output += 1 / 3 * model_output
+        elif (self.counter - 2) % 4 == 0:
+            self.cur_model_output += 1 / 3 * model_output
+        elif (self.counter - 3) % 4 == 0:
+            model_output = self.cur_model_output + 1 / 6 * model_output
+            self.cur_model_output = torch.Tensor()
+
+        # cur_sample should not be an empty torch.Tensor()
+        cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample
+
+        prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+        self.counter += 1
+
+        return prev_sample
+
+    def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any:
+        """
+        Step function propagating the sample with the linear multi-step method. This uses one forward pass together
+        with a history of previous model outputs to approximate the solution.
+
+        Args:
+            model_output: direct output from learned diffusion model.
+            timestep: current discrete timestep in the diffusion chain.
+            sample: current instance of sample being created by diffusion process.
+
+        Returns:
+            pred_prev_sample: predicted previous sample
+        """
+        if self.num_inference_steps is None:
+            raise ValueError(
+                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+            )
+
+        if not self.skip_prk_steps and len(self.ets) < 3:
+            raise ValueError(
+                f"{self.__class__} can only be run AFTER scheduler has been run "
+                "in 'prk' mode for at least 12 iterations "
+            )
+
+        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+
+        if self.counter != 1:
+            self.ets = self.ets[-3:]
+            self.ets.append(model_output)
+        else:
+            prev_timestep = timestep
+            timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+        if len(self.ets) == 1 and self.counter == 0:
+            model_output = model_output
+            self.cur_sample = sample
+        elif len(self.ets) == 1 and self.counter == 1:
+            model_output = (model_output + self.ets[-1]) / 2
+            sample = self.cur_sample
+            self.cur_sample = torch.Tensor()
+        elif len(self.ets) == 2:
+            model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+        elif len(self.ets) == 3:
+            model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+        else:
+            model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+        self.counter += 1
+
+        return prev_sample
+
+    def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor):
+        # See formula (9) of the PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+        # this function computes x_(t−δ) using formula (9)
+        # Note that x_t needs to be added to both sides of the equation
+
+        # Notation (<variable name> -> <name in paper>)
+        # alpha_prod_t -> α_t
+        # alpha_prod_t_prev -> α_(t−δ)
+        # beta_prod_t -> (1 - α_t)
+        # beta_prod_t_prev -> (1 - α_(t−δ))
+        # sample -> x_t
+        # model_output -> e_θ(x_t, t)
+        # prev_sample -> x_(t−δ)
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        if self.prediction_type == PNDMPredictionType.V_PREDICTION:
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+        # corresponds to (α_(t−δ) - α_t) divided by
+        # denominator of x_t in formula (9) and plus 1
+        # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) =
+        # sqrt(α_(t−δ)) / sqrt(α_t)
+        sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+        # corresponds to denominator of e_θ(x_t, t) in formula (9)
+        model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+            alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+        ) ** (0.5)
+
+        # full formula (9)
+        prev_sample = (
+            sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+        )
+
+        return prev_sample
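For orientation, here is a minimal sampling loop driving the scheduler above. This is a hedged sketch: the `model` stand-in and the tensor shapes are illustrative assumptions, not part of this release.

```python
# Hedged sketch: one possible denoising loop around PNDMScheduler.
import torch

from monai.networks.schedulers import PNDMScheduler


def model(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.zeros_like(x)  # placeholder for a trained noise predictor


scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True)
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 1, 64, 64)  # start from pure noise
for t in scheduler.timesteps:
    model_output = model(sample, t)
    sample, _ = scheduler.step(model_output, int(t), sample)  # step() returns (prev_sample, None)
```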
monai/networks/schedulers/scheduler.py ADDED
@@ -0,0 +1,205 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from monai.utils import ComponentStore, unsqueeze_right
+
+NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules")
+
+
+@NoiseSchedules.add_def("linear_beta", "Linear beta schedule")
+def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
+    """
+    Linear beta noise schedule function.
+
+    Args:
+        num_train_timesteps: number of timesteps
+        beta_start: start of beta range, default 1e-4
+        beta_end: end of beta range, default 2e-2
+
+    Returns:
+        betas: beta schedule tensor
+    """
+    return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+
+
+@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule")
+def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
+    """
+    Scaled linear beta noise schedule function.
+
+    Args:
+        num_train_timesteps: number of timesteps
+        beta_start: start of beta range, default 1e-4
+        beta_end: end of beta range, default 2e-2
+
+    Returns:
+        betas: beta schedule tensor
+    """
+    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+
+
+@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule")
+def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6):
+    """
+    Sigmoid beta noise schedule function.
+
+    Args:
+        num_train_timesteps: number of timesteps
+        beta_start: start of beta range, default 1e-4
+        beta_end: end of beta range, default 2e-2
+        sig_range: pos/neg range of sigmoid input, default 6
+
+    Returns:
+        betas: beta schedule tensor
+    """
+    betas = torch.linspace(-sig_range, sig_range, num_train_timesteps)
+    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
+
+@NoiseSchedules.add_def("cosine", "Cosine schedule")
+def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
+    """
+    Cosine noise schedule, see https://arxiv.org/abs/2102.09672
+
+    Args:
+        num_train_timesteps: number of timesteps
+        s: smoothing factor, default 8e-3 (see referenced paper)
+
+    Returns:
+        (betas, alphas, alphas_cumprod) values
+    """
+    x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
+    alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+    alphas_cumprod /= alphas_cumprod[0].item()
+    alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
+    betas = 1.0 - alphas
+    return betas, alphas, alphas_cumprod[:-1]
+
+
+class Scheduler(nn.Module):
+    """
+    Base class for other schedulers based on a noise schedule function.
+
+    This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here
+    the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`,
+    which is the name of a component in NoiseSchedules. These components must all be callables which return either
+    the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions
+    can be provided by using the NoiseSchedules.add_def decorator, for example:
+
+    .. code-block:: python
+
+        from monai.networks.schedulers import NoiseSchedules, DDPMScheduler
+
+        @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function")
+        def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
+            return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+
+        scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule")
+
+    All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of
+    timesteps the schedule is for; any other arguments can be given, and these will be passed by keyword through the
+    constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules
+    to get a listing of stored objects with their docstring descriptions.
+
+    Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule
+    type; this is now replaced with `schedule`, and most names used with the previous argument now have "_beta"
+    appended to them, e.g. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end`
+    arguments are still used for some schedules but these are provided as keyword arguments now.
+
+    Args:
+        num_train_timesteps: number of diffusion steps used to train the model.
+        schedule: member of NoiseSchedules,
+            a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple
+        schedule_args: arguments to pass to the schedule function
+    """
+
+    def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None:
+        super().__init__()
+        schedule_args["num_train_timesteps"] = num_train_timesteps
+        noise_sched = NoiseSchedules[schedule](**schedule_args)
+
+        # set betas, alphas, alphas_cumprod based off return value from noise function
+        if isinstance(noise_sched, tuple):
+            self.betas, self.alphas, self.alphas_cumprod = noise_sched
+        else:
+            self.betas = noise_sched
+            self.alphas = 1.0 - self.betas
+            self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        self.num_train_timesteps = num_train_timesteps
+        self.one = torch.tensor(1.0)
+
+        # settable values
+        self.num_inference_steps: int | None = None
+        self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)
+
+    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        """
+        Add noise to the original samples.
+
+        Args:
+            original_samples: original samples
+            noise: noise to add to samples
+            timesteps: timesteps tensor indicating the timestep to be computed for each sample.
+
+        Returns:
+            noisy_samples: sample with added noise
+        """
+        # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
+        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        timesteps = timesteps.to(original_samples.device)
+
+        sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim)
+        sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(
+            (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim
+        )
+
+        noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
+        return noisy_samples
+
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        # Make sure alphas_cumprod and timesteps have the same device and dtype as sample
+        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        timesteps = timesteps.to(sample.device)
+
+        sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim)
+        sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(
+            (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim
+        )
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
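On the training side, the base-class helpers above are typically used to produce noisy inputs and regression targets. A hedged sketch, using the `DDPMScheduler` added alongside this file; the batch shapes are illustrative only.

```python
# Hedged sketch: forward noising and v-prediction targets via the base class.
import torch

from monai.networks.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="cosine")  # "cosine" returns the full triple

images = torch.randn(4, 1, 32, 32)  # clean x_0 batch (illustrative)
noise = torch.randn_like(images)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (4,))

noisy = scheduler.add_noise(images, noise, timesteps)        # sqrt(a_t)*x_0 + sqrt(1-a_t)*noise
v_target = scheduler.get_velocity(images, noise, timesteps)  # target for v-prediction training
```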
monai/networks/utils.py CHANGED
@@ -42,6 +42,7 @@ __all__ = [
     "predict_segmentation",
     "normalize_transform",
     "to_norm_affine",
+    "CastTempType",
     "normal_init",
     "icnr_init",
     "pixelshuffle",
@@ -1167,3 +1168,24 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None):
             warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.")

     logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.")
+
+
+class CastTempType(nn.Module):
+    """
+    Cast the input tensor to a temporary type before applying the submodule, and then cast it back to the initial type.
+    """
+
+    def __init__(self, initial_type, temporary_type, submodule):
+        super().__init__()
+        self.initial_type = initial_type
+        self.temporary_type = temporary_type
+        self.submodule = submodule
+
+    def forward(self, x):
+        dtype = x.dtype
+        if dtype == self.initial_type:
+            x = x.to(self.temporary_type)
+        x = self.submodule(x)
+        if dtype == self.initial_type:
+            x = x.to(self.initial_type)
+        return x
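A short usage sketch for the new `CastTempType` helper; wrapping a softmax so it runs in float32 under float16 input is an assumed use case, not taken from the diff.

```python
# Hedged sketch: run a submodule in float32 when the input is float16.
import torch
import torch.nn as nn

from monai.networks.utils import CastTempType

layer = CastTempType(initial_type=torch.float16, temporary_type=torch.float32, submodule=nn.Softmax(dim=1))

x = torch.randn(2, 3, dtype=torch.float16)
y = layer(x)  # computed in float32 internally, cast back on the way out
assert y.dtype == torch.float16
```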
monai/transforms/croppad/array.py CHANGED
@@ -362,10 +362,10 @@ class Crop(InvertibleTransform, LazyTransform):

     @staticmethod
     def compute_slices(
-        roi_center: Sequence[int] | NdarrayOrTensor | None = None,
-        roi_size: Sequence[int] | NdarrayOrTensor | None = None,
-        roi_start: Sequence[int] | NdarrayOrTensor | None = None,
-        roi_end: Sequence[int] | NdarrayOrTensor | None = None,
+        roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,
+        roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,
+        roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,
+        roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,
         roi_slices: Sequence[slice] | None = None,
     ) -> tuple[slice]:
         """
@@ -459,10 +459,10 @@ class SpatialCrop(Crop):

     def __init__(
         self,
-        roi_center: Sequence[int] | NdarrayOrTensor | None = None,
-        roi_size: Sequence[int] | NdarrayOrTensor | None = None,
-        roi_start: Sequence[int] | NdarrayOrTensor | None = None,
-        roi_end: Sequence[int] | NdarrayOrTensor | None = None,
+        roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,
+        roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,
+        roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,
+        roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,
         roi_slices: Sequence[slice] | None = None,
         lazy: bool = False,
     ) -> None:
monai/transforms/croppad/dictionary.py CHANGED
@@ -438,10 +438,10 @@ class SpatialCropd(Cropd):
     def __init__(
         self,
         keys: KeysCollection,
-        roi_center: Sequence[int] | None = None,
-        roi_size: Sequence[int] | None = None,
-        roi_start: Sequence[int] | None = None,
-        roi_end: Sequence[int] | None = None,
+        roi_center: Sequence[int] | int | None = None,
+        roi_size: Sequence[int] | int | None = None,
+        roi_start: Sequence[int] | int | None = None,
+        roi_end: Sequence[int] | int | None = None,
         roi_slices: Sequence[slice] | None = None,
         allow_missing_keys: bool = False,
         lazy: bool = False,
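These hunks only widen the annotations to admit plain ints. A sketch of the calling pattern this presumably lets type-check, assuming scalar ROI values broadcast at runtime as the widened hints suggest:

```python
# Hedged sketch: a scalar roi_size alongside a per-axis roi_center.
from monai.transforms import SpatialCrop

crop = SpatialCrop(roi_center=[32, 32, 16], roi_size=48)  # int assumed to broadcast to all spatial axes
```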
monai/transforms/croppad/functional.py CHANGED
@@ -48,7 +48,7 @@ def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **k
         warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.")
         img_np = img.detach().cpu().numpy()
     else:
-        img_np = img
+        img_np = np.asarray(img)
     mode = convert_pad_mode(dst=img_np, mode=mode).value
     if mode == "constant" and "value" in kwargs:
         kwargs["constant_values"] = kwargs.pop("value")
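The one-line change coerces the else-branch input, so the padding call downstream always sees a true ndarray even when a CPU tensor reaches `_np_pad`. A sketch of the coercion being relied on (shapes assumed):

```python
# Hedged sketch: np.asarray gives an ndarray view of a CPU tensor (no copy).
import numpy as np
import torch

img = torch.ones(1, 4, 4)  # a CPU tensor reaching the else branch
img_np = np.asarray(img)
padded = np.pad(img_np, ((0, 0), (1, 1), (1, 1)), mode="constant")
print(padded.shape)  # (1, 6, 6)
```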
monai/transforms/regularization/array.py CHANGED
@@ -87,12 +87,14 @@ class MixUp(Mixer):

     def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
         data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        labels_t = data_t  # will not stay this value, needed to satisfy pylint/mypy
         if labels is not None:
             labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
         if randomize:
             self.randomize()
         if labels is None:
             return convert_to_dst_type(self.apply(data_t), dst=data)[0]
+
         return (
             convert_to_dst_type(self.apply(data_t), dst=data)[0],
             convert_to_dst_type(self.apply(labels_t), dst=labels)[0],
@@ -149,11 +151,13 @@ class CutMix(Mixer):

     def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
         data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        augmented_label = None
         if labels is not None:
             labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
         if randomize:
             self.randomize(data)
         augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0]
+
         if labels is not None:
             augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0]
         return (augmented, augmented_label) if labels is not None else augmented
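Both hunks pre-initialize names (`labels_t`, `augmented_label`) that were previously unbound on some code paths. A hedged sketch of the two calling conventions, with illustrative shapes:

```python
# Hedged sketch: the with- and without-labels call paths of MixUp/CutMix.
import torch

from monai.transforms import CutMix, MixUp

batch = torch.rand(8, 1, 32, 32)
masks = torch.randint(0, 2, (8, 1, 32, 32)).float()

mixed = MixUp(batch_size=8, alpha=1.0)(batch)                   # labels=None: a single tensor
aug, aug_masks = CutMix(batch_size=8, alpha=1.0)(batch, masks)  # with labels: a pair
```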
monai/transforms/spatial/array.py CHANGED
@@ -3441,7 +3441,7 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait):
             idx = self.R.permutation(image_np.shape[0])
             idx = idx[: self.num_patches]
             idx_np = convert_data_type(idx, np.ndarray)[0]
-            image_np = image_np[idx]
+            image_np = image_np[idx]  # type: ignore[index]
             locations = locations[idx_np]
             return image_np, locations
         elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX):
monai/transforms/utils_create_transform_ims.py CHANGED
@@ -269,11 +269,9 @@ def update_docstring(code_path, transform_name):


 def pre_process_data(data, ndim, is_map, is_post):
-    """If transform requires 2D data, then convert to 2D"""
+    """If transform requires 2D data, then convert to 2D by selecting the middle of the last dimension."""
     if ndim == 2:
-        for k in
-            data[k] = data[k][..., data[k].shape[-1] // 2]
-
+        data = {k: v[..., v.shape[-1] // 2] for k, v in data.items()}
     if is_map:
         return data
     return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE]
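The rewritten helper replaces the broken loop with a dict comprehension; a toy illustration of the slicing it performs (shapes assumed):

```python
# Hedged sketch: take the middle slice of the last axis for each entry.
import numpy as np

data = {"image": np.zeros((1, 64, 64, 40)), "label": np.zeros((1, 64, 64, 40))}
data = {k: v[..., v.shape[-1] // 2] for k, v in data.items()}
print(data["image"].shape)  # (1, 64, 64)
```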