monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/nets/vqvae.py (new file)

@@ -0,0 +1,472 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from typing import Tuple

import torch
import torch.nn as nn

from monai.networks.blocks import Convolution
from monai.networks.layers import Act
from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer
from monai.utils import ensure_tuple_rep

__all__ = ["VQVAE"]


class VQVAEResidualUnit(nn.Module):
    """
    Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving
    Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf).

    The original implementation can be found at
    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150.

    Args:
        spatial_dims: number of spatial dimensions of the input data.
        in_channels: number of input channels.
        num_res_channels: number of channels in the residual layers.
        act: activation type and arguments. Defaults to RELU.
        dropout: dropout ratio. Defaults to no dropout.
        bias: whether to have a bias term. Defaults to True.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_res_channels: int,
        act: tuple | str | None = Act.RELU,
        dropout: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.num_res_channels = num_res_channels
        self.act = act
        self.dropout = dropout
        self.bias = bias

        self.conv1 = Convolution(
            spatial_dims=self.spatial_dims,
            in_channels=self.in_channels,
            out_channels=self.num_res_channels,
            adn_ordering="DA",
            act=self.act,
            dropout=self.dropout,
            bias=self.bias,
        )

        self.conv2 = Convolution(
            spatial_dims=self.spatial_dims,
            in_channels=self.num_res_channels,
            out_channels=self.in_channels,
            bias=self.bias,
            conv_only=True,
        )

    def forward(self, x):
        return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True)
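
Not part of the diff: a minimal usage sketch of this residual unit, assuming illustrative channel counts and an input shape chosen purely for demonstration. The unit maps in_channels -> num_res_channels -> in_channels, so its output shape matches its input.

import torch
from monai.networks.nets.vqvae import VQVAEResidualUnit

unit = VQVAEResidualUnit(spatial_dims=3, in_channels=8, num_res_channels=16)
x = torch.randn(2, 8, 16, 16, 16)  # (batch, channels, D, H, W) -- illustrative shape
y = unit(x)
print(y.shape)  # torch.Size([2, 8, 16, 16, 16]); the residual path preserves the input shape
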
class Encoder(nn.Module):
    """
    Encoder module for VQ-VAE.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of channels in the latent space (embedding_dim).
        channels: sequence containing the number of channels at each level of the encoder.
        num_res_layers: number of sequential residual layers at each level.
        num_res_channels: number of channels in the residual layers at each level.
        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
            following information: stride (int), kernel_size (int), dilation (int) and padding (int).
        dropout: dropout ratio.
        act: activation type and arguments.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        num_res_layers: int,
        num_res_channels: Sequence[int],
        downsample_parameters: Sequence[Tuple[int, int, int, int]],
        dropout: float,
        act: tuple | str | None,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.num_res_layers = num_res_layers
        self.num_res_channels = num_res_channels
        self.downsample_parameters = downsample_parameters
        self.dropout = dropout
        self.act = act

        blocks: list[nn.Module] = []

        for i in range(len(self.channels)):
            blocks.append(
                Convolution(
                    spatial_dims=self.spatial_dims,
                    in_channels=self.in_channels if i == 0 else self.channels[i - 1],
                    out_channels=self.channels[i],
                    strides=self.downsample_parameters[i][0],
                    kernel_size=self.downsample_parameters[i][1],
                    adn_ordering="DA",
                    act=self.act,
                    dropout=None if i == 0 else self.dropout,
                    dropout_dim=1,
                    dilation=self.downsample_parameters[i][2],
                    padding=self.downsample_parameters[i][3],
                )
            )

            for _ in range(self.num_res_layers):
                blocks.append(
                    VQVAEResidualUnit(
                        spatial_dims=self.spatial_dims,
                        in_channels=self.channels[i],
                        num_res_channels=self.num_res_channels[i],
                        act=self.act,
                        dropout=self.dropout,
                    )
                )

        blocks.append(
            Convolution(
                spatial_dims=self.spatial_dims,
                in_channels=self.channels[len(self.channels) - 1],
                out_channels=self.out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x
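
Not part of the diff: a sketch of an Encoder configuration, with assumed sizes. Each (2, 4, 1, 1) entry in downsample_parameters is (stride, kernel_size, dilation, padding), so each level halves the spatial size; the final stride-1 convolution projects to out_channels (the embedding dimension).

import torch
from monai.networks.nets.vqvae import Encoder

enc = Encoder(
    spatial_dims=3,
    in_channels=1,
    out_channels=64,  # latent (embedding) channels
    channels=(96, 192),
    num_res_layers=2,
    num_res_channels=(96, 192),
    downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),
    dropout=0.0,
    act="RELU",
)
z = enc(torch.randn(1, 1, 32, 32, 32))
print(z.shape)  # torch.Size([1, 64, 8, 8, 8]) -- two stride-2 levels: 32 -> 16 -> 8
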
class Decoder(nn.Module):
    """
    Decoder module for VQ-VAE.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of channels in the latent space (embedding_dim).
        out_channels: number of output channels.
        channels: sequence containing the number of channels at each level of the decoder.
        num_res_layers: number of sequential residual layers at each level.
        num_res_channels: number of channels in the residual layers at each level.
        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
            following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
        dropout: dropout ratio.
        act: activation type and arguments.
        output_act: activation type and arguments for the output.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        num_res_layers: int,
        num_res_channels: Sequence[int],
        upsample_parameters: Sequence[Tuple[int, int, int, int, int]],
        dropout: float,
        act: tuple | str | None,
        output_act: tuple | str | None,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.num_res_layers = num_res_layers
        self.num_res_channels = num_res_channels
        self.upsample_parameters = upsample_parameters
        self.dropout = dropout
        self.act = act
        self.output_act = output_act

        reversed_num_channels = list(reversed(self.channels))

        blocks: list[nn.Module] = []
        blocks.append(
            Convolution(
                spatial_dims=self.spatial_dims,
                in_channels=self.in_channels,
                out_channels=reversed_num_channels[0],
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )

        reversed_num_res_channels = list(reversed(self.num_res_channels))
        for i in range(len(self.channels)):
            for _ in range(self.num_res_layers):
                blocks.append(
                    VQVAEResidualUnit(
                        spatial_dims=self.spatial_dims,
                        in_channels=reversed_num_channels[i],
                        num_res_channels=reversed_num_res_channels[i],
                        act=self.act,
                        dropout=self.dropout,
                    )
                )

            blocks.append(
                Convolution(
                    spatial_dims=self.spatial_dims,
                    in_channels=reversed_num_channels[i],
                    out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1],
                    strides=self.upsample_parameters[i][0],
                    kernel_size=self.upsample_parameters[i][1],
                    adn_ordering="DA",
                    act=self.act,
                    dropout=self.dropout if i != len(self.channels) - 1 else None,
                    norm=None,
                    dilation=self.upsample_parameters[i][2],
                    conv_only=i == len(self.channels) - 1,
                    is_transposed=True,
                    padding=self.upsample_parameters[i][3],
                    output_padding=self.upsample_parameters[i][4],
                )
            )

        if self.output_act:
            blocks.append(Act[self.output_act]())

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x
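
Not part of the diff: the mirrored Decoder sketch for the encoder example above, again with assumed sizes. upsample_parameters entries are (stride, kernel_size, dilation, padding, output_padding) and drive transposed convolutions, so each level doubles the spatial size.

import torch
from monai.networks.nets.vqvae import Decoder

dec = Decoder(
    spatial_dims=3,
    in_channels=64,  # latent (embedding) channels
    out_channels=1,
    channels=(96, 192),
    num_res_layers=2,
    num_res_channels=(96, 192),
    upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
    dropout=0.0,
    act="RELU",
    output_act=None,
)
y = dec(torch.randn(1, 64, 8, 8, 8))
print(y.shape)  # torch.Size([1, 1, 32, 32, 32]) -- two stride-2 transposed levels: 8 -> 16 -> 32
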
class VQVAE(nn.Module):
    """
    Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative
    Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf)

    The original implementation can be found at
    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
            following information: stride (int), kernel_size (int), dilation (int) and padding (int).
        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
            following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
        num_res_layers: number of sequential residual layers at each level.
        channels: number of channels at each level.
        num_res_channels: number of channels in the residual layers at each level.
        num_embeddings: VectorQuantization number of atomic elements in the codebook.
        embedding_dim: VectorQuantization number of channels of the input and atomic elements.
        commitment_cost: VectorQuantization commitment_cost.
        decay: VectorQuantization decay.
        epsilon: VectorQuantization epsilon.
        act: activation type and arguments.
        dropout: dropout ratio.
        output_act: activation type and arguments for the output.
        ddp_sync: whether to synchronize the codebook across processes.
        use_checkpointing: if True, use activation checkpointing to save memory.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int] = (96, 96, 192),
        num_res_layers: int = 3,
        num_res_channels: Sequence[int] | int = (96, 96, 192),
        downsample_parameters: Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int] = (
            (2, 4, 1, 1),
            (2, 4, 1, 1),
            (2, 4, 1, 1),
        ),
        upsample_parameters: Sequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int] = (
            (2, 4, 1, 1, 0),
            (2, 4, 1, 1, 0),
            (2, 4, 1, 1, 0),
        ),
        num_embeddings: int = 32,
        embedding_dim: int = 64,
        embedding_init: str = "normal",
        commitment_cost: float = 0.25,
        decay: float = 0.5,
        epsilon: float = 1e-5,
        dropout: float = 0.0,
        act: tuple | str | None = Act.RELU,
        output_act: tuple | str | None = None,
        ddp_sync: bool = True,
        use_checkpointing: bool = False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.spatial_dims = spatial_dims
        self.channels = channels
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.use_checkpointing = use_checkpointing

        if isinstance(num_res_channels, int):
            num_res_channels = ensure_tuple_rep(num_res_channels, len(channels))

        if len(num_res_channels) != len(channels):
            raise ValueError(
                "`num_res_channels` should be a single integer or a tuple of integers with the same length as "
                "`channels`."
            )
        if all(isinstance(values, int) for values in upsample_parameters):
            upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels)
        else:
            upsample_parameters_tuple = upsample_parameters

        if all(isinstance(values, int) for values in downsample_parameters):
            downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels)
        else:
            downsample_parameters_tuple = downsample_parameters

        if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple):
            raise ValueError("`downsample_parameters` should be a single tuple of integer or a tuple of tuples.")

        # check if upsample_parameters is a tuple of ints or a tuple of tuples of ints
        if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple):
            raise ValueError("`upsample_parameters` should be a single tuple of integer or a tuple of tuples.")

        for parameter in downsample_parameters_tuple:
            if len(parameter) != 4:
                raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.")

        for parameter in upsample_parameters_tuple:
            if len(parameter) != 5:
                raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.")

        if len(downsample_parameters_tuple) != len(channels):
            raise ValueError(
                "`downsample_parameters` should be a tuple of tuples with the same length as `channels`."
            )

        if len(upsample_parameters_tuple) != len(channels):
            raise ValueError(
                "`upsample_parameters` should be a tuple of tuples with the same length as `channels`."
            )

        self.num_res_layers = num_res_layers
        self.num_res_channels = num_res_channels

        self.encoder = Encoder(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=embedding_dim,
            channels=channels,
            num_res_layers=num_res_layers,
            num_res_channels=num_res_channels,
            downsample_parameters=downsample_parameters_tuple,
            dropout=dropout,
            act=act,
        )

        self.decoder = Decoder(
            spatial_dims=spatial_dims,
            in_channels=embedding_dim,
            out_channels=out_channels,
            channels=channels,
            num_res_layers=num_res_layers,
            num_res_channels=num_res_channels,
            upsample_parameters=upsample_parameters_tuple,
            dropout=dropout,
            act=act,
            output_act=output_act,
        )

        self.quantizer = VectorQuantizer(
            quantizer=EMAQuantizer(
                spatial_dims=spatial_dims,
                num_embeddings=num_embeddings,
                embedding_dim=embedding_dim,
                commitment_cost=commitment_cost,
                decay=decay,
                epsilon=epsilon,
                embedding_init=embedding_init,
                ddp_sync=ddp_sync,
            )
        )

    def encode(self, images: torch.Tensor) -> torch.Tensor:
        output: torch.Tensor
        if self.use_checkpointing:
            output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False)
        else:
            output = self.encoder(images)
        return output

    def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x_loss, x = self.quantizer(encodings)
        return x, x_loss

    def decode(self, quantizations: torch.Tensor) -> torch.Tensor:
        output: torch.Tensor

        if self.use_checkpointing:
            output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False)
        else:
            output = self.decoder(quantizations)
        return output

    def index_quantize(self, images: torch.Tensor) -> torch.Tensor:
        return self.quantizer.quantize(self.encode(images=images))

    def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor:
        return self.decode(self.quantizer.embed(embedding_indices))

    def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        quantizations, quantization_losses = self.quantize(self.encode(images))
        reconstruction = self.decode(quantizations)

        return reconstruction, quantization_losses

    def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
        z = self.encode(x)
        e, _ = self.quantize(z)
        return e

    def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
        e, _ = self.quantize(z)
        image = self.decode(e)
        return image

monai/networks/schedulers/__init__.py (new file)

@@ -0,0 +1,17 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from .ddim import DDIMScheduler
from .ddpm import DDPMScheduler
from .pndm import PNDMScheduler
from .scheduler import NoiseSchedules, Scheduler
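
Not part of the diff: the new monai.networks.schedulers package re-exports the scheduler classes, so they can be imported directly from the package. The constructor argument below is an assumption for illustration, not taken from this diff.

from monai.networks.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)  # num_train_timesteps is an assumed argument
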