monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/schedulers/pndm.py ADDED
@@ -0,0 +1,316 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # =========================================================================
+ # Adapted from https://github.com/huggingface/diffusers
+ # which has the following license:
+ # https://github.com/huggingface/diffusers/blob/main/LICENSE
+ #
+ # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========================================================================
+
+ from __future__ import annotations
+
+ from typing import Any
+
+ import numpy as np
+ import torch
+
+ from monai.utils import StrEnum
+
+ from .scheduler import Scheduler
+
+
+ class PNDMPredictionType(StrEnum):
+     """
+     Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.
+
+     epsilon: predicting the noise of the diffusion process
+     v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+     """
+
+     EPSILON = "epsilon"
+     V_PREDICTION = "v_prediction"
+
+
+ class PNDMScheduler(Scheduler):
+     """
+     Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+     namely the Runge-Kutta method and a linear multi-step method. Based on: Liu et al.,
+     "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778
+
+     Args:
+         num_train_timesteps: number of diffusion steps used to train the model.
+         schedule: member of NoiseSchedules, name of noise schedule function in component store
+         skip_prk_steps:
+             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being
+             required before PLMS steps.
+         set_alpha_to_one:
+             each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+             step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+             otherwise it uses the value of alpha at step 0.
+         prediction_type: member of PNDMPredictionType
+         steps_offset:
+             an offset added to the inference steps. You can use a combination of `offset=1` and
+             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+             stable diffusion.
+         schedule_args: arguments to pass to the schedule function
+     """
+
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         schedule: str = "linear_beta",
+         skip_prk_steps: bool = False,
+         set_alpha_to_one: bool = False,
+         prediction_type: str = PNDMPredictionType.EPSILON,
+         steps_offset: int = 0,
+         **schedule_args,
+     ) -> None:
+         super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+         if prediction_type not in PNDMPredictionType.__members__.values():
+             raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType")
+
+         self.prediction_type = prediction_type
+
+         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+         # standard deviation of the initial noise distribution
+         self.init_noise_sigma = 1.0
+
+         # For now we only support F-PNDM, i.e. the Runge-Kutta method.
+         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+         # mainly at formulas (9), (12), (13) and Algorithm 2.
+         self.pndm_order = 4
+
+         self.skip_prk_steps = skip_prk_steps
+         self.steps_offset = steps_offset
+
+         # running values
+         self.cur_model_output = torch.Tensor()
+         self.counter = 0
+         self.cur_sample = torch.Tensor()
+         self.ets: list = []
+
+         # default the number of inference timesteps to the number of train steps
+         self.set_timesteps(num_train_timesteps)
+
+     def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+         """
+         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+         Args:
+             num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+             device: target device to put the data.
+         """
+         if num_inference_steps > self.num_train_timesteps:
+             raise ValueError(
+                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+                 f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                 f" maximal {self.num_train_timesteps} timesteps."
+             )
+
+         self.num_inference_steps = num_inference_steps
+         step_ratio = self.num_train_timesteps // self.num_inference_steps
+         # creates integer timesteps by multiplying by ratio
+         # casting to int to avoid issues when num_inference_step is power of 3
+         self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
+         self._timesteps += self.steps_offset
+
+         if self.skip_prk_steps:
+             # for some models like stable diffusion the prk steps can/should be skipped to
+             # produce better results. When using PNDM with `self.skip_prk_steps` the implementation
+             # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+             self.prk_timesteps = np.array([])
+             self.plms_timesteps = self._timesteps[::-1]
+
+         else:
+             prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+                 np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+             )
+             self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+             self.plms_timesteps = self._timesteps[:-3][
+                 ::-1
+             ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+         timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+         self.timesteps = torch.from_numpy(timesteps).to(device)
+         # update num_inference_steps - necessary if we use prk steps
+         self.num_inference_steps = len(self.timesteps)
+
+         self.ets = []
+         self.counter = 0
+
+     def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]:
+         """
+         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+         process from the learned model outputs (most often the predicted noise).
+         This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+         Args:
+             model_output: direct output from learned diffusion model.
+             timestep: current discrete timestep in the diffusion chain.
+             sample: current instance of sample being created by diffusion process.
+         Returns:
+             pred_prev_sample: Predicted previous sample
+         """
+         # return a tuple for consistency with samplers that return (previous pred, original sample pred)
+
+         if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps:
+             return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None
+         else:
+             return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None
+
+     def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor:
+         """
+         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+         solution to the differential equation.
+
+         Args:
+             model_output: direct output from learned diffusion model.
+             timestep: current discrete timestep in the diffusion chain.
+             sample: current instance of sample being created by diffusion process.
+
+         Returns:
+             pred_prev_sample: Predicted previous sample
+         """
+         if self.num_inference_steps is None:
+             raise ValueError(
+                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+             )
+
+         diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2
+         prev_timestep = timestep - diff_to_prev
+         timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+         if self.counter % 4 == 0:
+             self.cur_model_output = 1 / 6 * model_output
+             self.ets.append(model_output)
+             self.cur_sample = sample
+         elif (self.counter - 1) % 4 == 0:
+             self.cur_model_output += 1 / 3 * model_output
+         elif (self.counter - 2) % 4 == 0:
+             self.cur_model_output += 1 / 3 * model_output
+         elif (self.counter - 3) % 4 == 0:
+             model_output = self.cur_model_output + 1 / 6 * model_output
+             self.cur_model_output = torch.Tensor()
+
+         # cur_sample should not be an empty torch.Tensor()
+         cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample
+
+         prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+         self.counter += 1
+
+         return prev_sample
+
+     def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any:
+         """
+         Step function propagating the sample with the linear multi-step method. This uses a single forward pass
+         multiple times to approximate the solution.
+
+         Args:
+             model_output: direct output from learned diffusion model.
+             timestep: current discrete timestep in the diffusion chain.
+             sample: current instance of sample being created by diffusion process.
+
+         Returns:
+             pred_prev_sample: Predicted previous sample
+         """
+         if self.num_inference_steps is None:
+             raise ValueError(
+                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+             )
+
+         if not self.skip_prk_steps and len(self.ets) < 3:
+             raise ValueError(
+                 f"{self.__class__} can only be run AFTER scheduler has been run "
+                 "in 'prk' mode for at least 12 iterations "
+             )
+
+         prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+
+         if self.counter != 1:
+             self.ets = self.ets[-3:]
+             self.ets.append(model_output)
+         else:
+             prev_timestep = timestep
+             timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+         if len(self.ets) == 1 and self.counter == 0:
+             model_output = model_output
+             self.cur_sample = sample
+         elif len(self.ets) == 1 and self.counter == 1:
+             model_output = (model_output + self.ets[-1]) / 2
+             sample = self.cur_sample
+             self.cur_sample = torch.Tensor()
+         elif len(self.ets) == 2:
+             model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+         elif len(self.ets) == 3:
+             model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+         else:
+             model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+         prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+         self.counter += 1
+
+         return prev_sample
+
+     def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor):
+         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+         # this function computes x_(t−δ) using the formula of (9)
+         # Note that x_t needs to be added to both sides of the equation
+
+         # Notation (<variable name> -> <name in paper>):
+         # alpha_prod_t -> α_t
+         # alpha_prod_t_prev -> α_(t−δ)
+         # beta_prod_t -> (1 - α_t)
+         # beta_prod_t_prev -> (1 - α_(t−δ))
+         # sample -> x_t
+         # model_output -> e_θ(x_t, t)
+         # prev_sample -> x_(t−δ)
+         alpha_prod_t = self.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+         beta_prod_t = 1 - alpha_prod_t
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         if self.prediction_type == PNDMPredictionType.V_PREDICTION:
+             model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+         # corresponds to (α_(t−δ) - α_t) divided by
+         # denominator of x_t in formula (9) and plus 1
+         # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+         # sqrt(α_(t−δ)) / sqrt(α_t)
+         sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+         # corresponds to denominator of e_θ(x_t, t) in formula (9)
+         model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+             alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+         ) ** (0.5)
+
+         # full formula (9)
+         prev_sample = (
+             sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+         )
+
+         return prev_sample
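
As a quick orientation, the new `PNDMScheduler` follows the usual set_timesteps/step pattern. A minimal sampling loop might look like the sketch below; the stand-in model, tensor shapes, and step count are illustrative assumptions, and the import assumes `PNDMScheduler` is re-exported from `monai.networks.schedulers` (whose new `__init__.py` also appears in this release):

    import torch

    from monai.networks.schedulers import PNDMScheduler

    def dummy_model(x: torch.Tensor, t: int) -> torch.Tensor:
        # stand-in for a trained denoising network (e.g. a diffusion UNet):
        # any callable mapping (x_t, t) to a predicted-noise tensor of the same shape
        return torch.zeros_like(x)

    scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True)
    scheduler.set_timesteps(50)  # 50 PLMS inference steps

    sample = torch.randn(1, 1, 64, 64)  # start from pure noise
    for t in scheduler.timesteps:
        model_output = dummy_model(sample, int(t))
        # step() returns (previous sample, None) for API consistency
        sample, _ = scheduler.step(model_output, int(t), sample)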
monai/networks/schedulers/scheduler.py ADDED
@@ -0,0 +1,205 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+ # =========================================================================
+ # Adapted from https://github.com/huggingface/diffusers
+ # which has the following license:
+ # https://github.com/huggingface/diffusers/blob/main/LICENSE
+ #
+ # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # =========================================================================
+
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from monai.utils import ComponentStore, unsqueeze_right
+
+ NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules")
+
+
+ @NoiseSchedules.add_def("linear_beta", "Linear beta schedule")
+ def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
+     """
+     Linear beta noise schedule function.
+
+     Args:
+         num_train_timesteps: number of timesteps
+         beta_start: start of beta range, default 1e-4
+         beta_end: end of beta range, default 2e-2
+
+     Returns:
+         betas: beta schedule tensor
+     """
+     return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+
+
+ @NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule")
+ def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
+     """
+     Scaled linear beta noise schedule function.
+
+     Args:
+         num_train_timesteps: number of timesteps
+         beta_start: start of beta range, default 1e-4
+         beta_end: end of beta range, default 2e-2
+
+     Returns:
+         betas: beta schedule tensor
+     """
+     return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+
+
+ @NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule")
+ def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6):
+     """
+     Sigmoid beta noise schedule function.
+
+     Args:
+         num_train_timesteps: number of timesteps
+         beta_start: start of beta range, default 1e-4
+         beta_end: end of beta range, default 2e-2
+         sig_range: pos/neg range of sigmoid input, default 6
+
+     Returns:
+         betas: beta schedule tensor
+     """
+     betas = torch.linspace(-sig_range, sig_range, num_train_timesteps)
+     return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
+
+ @NoiseSchedules.add_def("cosine", "Cosine schedule")
+ def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
+     """
+     Cosine noise schedule, see https://arxiv.org/abs/2102.09672
+
+     Args:
+         num_train_timesteps: number of timesteps
+         s: smoothing factor, default 8e-3 (see referenced paper)
+
+     Returns:
+         (betas, alphas, alphas_cumprod) values
+     """
+     x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
+     alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+     alphas_cumprod /= alphas_cumprod[0].item()
+     alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
+     betas = 1.0 - alphas
+     return betas, alphas, alphas_cumprod[:-1]
+
+
+ class Scheduler(nn.Module):
+     """
+     Base class for other schedulers based on a noise schedule function.
+
+     This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here
+     the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`,
+     which is the name of a component in NoiseSchedules. These components must all be callables which return either
+     the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions
+     can be provided by using NoiseSchedules.add_def, for example:
+
+     .. code-block:: python
+
+         from monai.networks.schedulers import NoiseSchedules, DDPMScheduler
+
+         @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function")
+         def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
+             return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+
+         scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule")
+
+     All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of
+     timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through
+     the constructor's `schedule_args` value. To see what noise functions are available, print the object
+     NoiseSchedules to get a listing of stored objects with their docstring descriptions.
+
+     Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule
+     type, this is now replaced with `schedule` and most names used with the previous argument now have "_beta"
+     appended to them, e.g. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end`
+     arguments are still used for some schedules but these are now provided as keyword arguments.
+
+     Args:
+         num_train_timesteps: number of diffusion steps used to train the model.
+         schedule: member of NoiseSchedules,
+             a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple
+         schedule_args: arguments to pass to the schedule function
+     """
+
+     def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None:
+         super().__init__()
+         schedule_args["num_train_timesteps"] = num_train_timesteps
+         noise_sched = NoiseSchedules[schedule](**schedule_args)
+
+         # set betas, alphas, alphas_cumprod based off return value from noise function
+         if isinstance(noise_sched, tuple):
+             self.betas, self.alphas, self.alphas_cumprod = noise_sched
+         else:
+             self.betas = noise_sched
+             self.alphas = 1.0 - self.betas
+             self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+         self.num_train_timesteps = num_train_timesteps
+         self.one = torch.tensor(1.0)
+
+         # settable values
+         self.num_inference_steps: int | None = None
+         self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)
+
+     def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+         """
+         Add noise to the original samples.
+
+         Args:
+             original_samples: original samples
+             noise: noise to add to samples
+             timesteps: timesteps tensor indicating the timestep to be computed for each sample.
+
+         Returns:
+             noisy_samples: sample with added noise
+         """
+         # make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
+         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+         timesteps = timesteps.to(original_samples.device)
+
+         sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim)
+         sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(
+             (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim
+         )
+
+         noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
+         return noisy_samples
+
+     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+         # make sure alphas_cumprod and timesteps have the same device and dtype as sample
+         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+         timesteps = timesteps.to(sample.device)
+
+         sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim)
+         sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(
+             (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim
+         )
+
+         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+         return velocity
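
For training-side use, the base class's `add_noise` and `get_velocity` shown above can be exercised directly. A small sketch, with illustrative batch shapes, using the `DDPMScheduler` subclass named in the docstring example above:

    import torch

    from monai.networks.schedulers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta")

    x0 = torch.randn(4, 1, 32, 32)            # clean training batch
    noise = torch.randn_like(x0)              # sampled noise
    timesteps = torch.randint(0, 1000, (4,))  # one random timestep per sample

    noisy = scheduler.add_noise(x0, noise, timesteps)        # forward process q(x_t | x_0)
    velocity = scheduler.get_velocity(x0, noise, timesteps)  # regression target for "v_prediction"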
monai/networks/utils.py CHANGED
@@ -42,6 +42,7 @@ __all__ = [
      "predict_segmentation",
      "normalize_transform",
      "to_norm_affine",
+     "CastTempType",
      "normal_init",
      "icnr_init",
      "pixelshuffle",
@@ -1167,3 +1168,24 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None):
              warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.")
  
      logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.")
+
+
+ class CastTempType(nn.Module):
+     """
+     Cast the input tensor to a temporary type before applying the submodule, and then cast it back to the initial type.
+     """
+
+     def __init__(self, initial_type, temporary_type, submodule):
+         super().__init__()
+         self.initial_type = initial_type
+         self.temporary_type = temporary_type
+         self.submodule = submodule
+
+     def forward(self, x):
+         dtype = x.dtype
+         if dtype == self.initial_type:
+             x = x.to(self.temporary_type)
+         x = self.submodule(x)
+         if dtype == self.initial_type:
+             x = x.to(self.initial_type)
+         return x
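
A small usage sketch for the new `CastTempType` (the dtypes and wrapped module are illustrative): it lets a numerically sensitive submodule run in a wider dtype inside an otherwise half-precision model:

    import torch
    import torch.nn as nn

    from monai.networks.utils import CastTempType

    # run LayerNorm in float32 even when the surrounding activations are float16
    norm = CastTempType(initial_type=torch.float16, temporary_type=torch.float32, submodule=nn.LayerNorm(8))

    x = torch.randn(2, 8, dtype=torch.float16)
    y = norm(x)
    print(y.dtype)  # torch.float16, cast back to the input dtype after the submodule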

monai/transforms/croppad/array.py CHANGED
@@ -362,10 +362,10 @@ class Crop(InvertibleTransform, LazyTransform):
  
      @staticmethod
      def compute_slices(
-         roi_center: Sequence[int] | NdarrayOrTensor | None = None,
-         roi_size: Sequence[int] | NdarrayOrTensor | None = None,
-         roi_start: Sequence[int] | NdarrayOrTensor | None = None,
-         roi_end: Sequence[int] | NdarrayOrTensor | None = None,
+         roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,
+         roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,
+         roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,
+         roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,
          roi_slices: Sequence[slice] | None = None,
      ) -> tuple[slice]:
          """
@@ -459,10 +459,10 @@ class SpatialCrop(Crop):
  
      def __init__(
          self,
-         roi_center: Sequence[int] | NdarrayOrTensor | None = None,
-         roi_size: Sequence[int] | NdarrayOrTensor | None = None,
-         roi_start: Sequence[int] | NdarrayOrTensor | None = None,
-         roi_end: Sequence[int] | NdarrayOrTensor | None = None,
+         roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,
+         roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,
+         roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,
+         roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,
          roi_slices: Sequence[slice] | None = None,
          lazy: bool = False,
      ) -> None:

monai/transforms/croppad/dictionary.py CHANGED
@@ -438,10 +438,10 @@ class SpatialCropd(Cropd):
      def __init__(
          self,
          keys: KeysCollection,
-         roi_center: Sequence[int] | None = None,
-         roi_size: Sequence[int] | None = None,
-         roi_start: Sequence[int] | None = None,
-         roi_end: Sequence[int] | None = None,
+         roi_center: Sequence[int] | int | None = None,
+         roi_size: Sequence[int] | int | None = None,
+         roi_start: Sequence[int] | int | None = None,
+         roi_end: Sequence[int] | int | None = None,
          roi_slices: Sequence[slice] | None = None,
          allow_missing_keys: bool = False,
          lazy: bool = False,

monai/transforms/croppad/functional.py CHANGED
@@ -48,7 +48,7 @@ def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **k
              warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.")
          img_np = img.detach().cpu().numpy()
      else:
-         img_np = img
+         img_np = np.asarray(img)
      mode = convert_pad_mode(dst=img_np, mode=mode).value
      if mode == "constant" and "value" in kwargs:
          kwargs["constant_values"] = kwargs.pop("value")

monai/transforms/regularization/array.py CHANGED
@@ -87,12 +87,14 @@ class MixUp(Mixer):
  
      def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
          data_t = convert_to_tensor(data, track_meta=get_track_meta())
+         labels_t = data_t  # will not stay this value, needed to satisfy pylint/mypy
          if labels is not None:
              labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
          if randomize:
              self.randomize()
          if labels is None:
              return convert_to_dst_type(self.apply(data_t), dst=data)[0]
+
          return (
              convert_to_dst_type(self.apply(data_t), dst=data)[0],
              convert_to_dst_type(self.apply(labels_t), dst=labels)[0],
@@ -149,11 +151,13 @@ class CutMix(Mixer):
  
      def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
          data_t = convert_to_tensor(data, track_meta=get_track_meta())
+         augmented_label = None
          if labels is not None:
              labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
          if randomize:
              self.randomize(data)
          augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0]
+
          if labels is not None:
              augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0]
          return (augmented, augmented_label) if labels is not None else augmented
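
The `__call__` changes above make the two return shapes explicit: a tensor pair with labels, a single tensor without. A hedged usage sketch; the `MixUp(batch_size, alpha)` constructor and tensor shapes are assumptions based on the surrounding module, not shown in this diff:

    import torch

    from monai.transforms import MixUp

    mixer = MixUp(batch_size=4, alpha=1.0)  # assumed constructor signature
    images = torch.rand(4, 1, 32, 32)
    labels = torch.rand(4, 1)

    mixed_images, mixed_labels = mixer(images, labels)  # same mixing applied to data and labels
    images_only = mixer(images)                         # without labels a single tensor is returned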

monai/transforms/spatial/array.py CHANGED
@@ -3441,7 +3441,7 @@ class RandGridPatch(GridPatch, RandomizableTransform, MultiSampleTrait):
              idx = self.R.permutation(image_np.shape[0])
              idx = idx[: self.num_patches]
              idx_np = convert_data_type(idx, np.ndarray)[0]
-             image_np = image_np[idx]
+             image_np = image_np[idx]  # type: ignore[index]
              locations = locations[idx_np]
              return image_np, locations
          elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX):

monai/transforms/utils_create_transform_ims.py CHANGED
@@ -269,11 +269,9 @@ def update_docstring(code_path, transform_name):
  
  
  def pre_process_data(data, ndim, is_map, is_post):
-     """If transform requires 2D data, then convert to 2D"""
+     """If transform requires 2D data, then convert to 2D by selecting the middle of the last dimension."""
      if ndim == 2:
-         for k in keys:
-             data[k] = data[k][..., data[k].shape[-1] // 2]
-
+         data = {k: v[..., v.shape[-1] // 2] for k, v in data.items()}
      if is_map:
          return data
      return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE]
monai/utils/__init__.py CHANGED
@@ -126,6 +126,7 @@ from .module import (
      version_leq,
  )
  from .nvtx import Range
+ from .ordering import Ordering
  from .profiling import (
      PerfContext,
      ProfileHandler,