monai-weekly 1.4.dev2425__py3-none-any.whl → 1.4.dev2427__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/deepedit/transforms.py +1 -1
- monai/apps/deepgrow/transforms.py +1 -1
- monai/apps/generation/__init__.py +10 -0
- monai/apps/generation/maisi/__init__.py +10 -0
- monai/apps/generation/maisi/networks/__init__.py +10 -0
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +975 -0
- monai/apps/generation/maisi/networks/controlnet_maisi.py +178 -0
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +410 -0
- monai/apps/generation/maisi/utils/__init__.py +10 -0
- monai/apps/generation/maisi/utils/morphological_ops.py +170 -0
- monai/apps/nuclick/transforms.py +1 -1
- monai/apps/pathology/transforms/post/array.py +1 -1
- monai/apps/pathology/utils.py +2 -2
- monai/data/torchscript_utils.py +1 -1
- monai/data/ultrasound_confidence_map.py +41 -10
- monai/losses/dice.py +10 -3
- monai/metrics/utils.py +3 -3
- monai/optimizers/lr_finder.py +1 -1
- monai/transforms/intensity/array.py +25 -2
- monai/transforms/signal/array.py +1 -1
- monai/utils/misc.py +20 -2
- monai/utils/module.py +6 -3
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/METADATA +6 -3
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/RECORD +29 -21
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2425.dist-info → monai_weekly-1.4.dev2427.dist-info}/top_level.txt +0 -0
monai/apps/generation/maisi/networks/controlnet_maisi.py
@@ -0,0 +1,178 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, cast

import torch

from monai.utils import optional_import

ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
get_timestep_embedding, has_get_timestep_embedding = optional_import(
    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
)

if TYPE_CHECKING:
    from generative.networks.nets.controlnet import ControlNet as ControlNetType
else:
    ControlNetType = cast(type, ControlNet)


class ControlNetMaisi(ControlNetType):
    """
    Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
    Diffusion Models" (https://arxiv.org/abs/2302.05543)

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
        num_channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
        conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
        conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
        use_checkpointing: if True, use activation checkpointing to save memory.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        num_channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
        conditioning_embedding_in_channels: int = 1,
        conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
        use_checkpointing: bool = True,
    ) -> None:
        super().__init__(
            spatial_dims,
            in_channels,
            num_res_blocks,
            num_channels,
            attention_levels,
            norm_num_groups,
            norm_eps,
            resblock_updown,
            num_head_channels,
            with_conditioning,
            transformer_num_layers,
            cross_attention_dim,
            num_class_embeds,
            upcast_attention,
            use_flash_attention,
            conditioning_embedding_in_channels,
            conditioning_embedding_num_channels,
        )
        self.use_checkpointing = use_checkpointing

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
    ) -> tuple[Sequence[torch.Tensor], torch.Tensor]:
        emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
        h = self._apply_initial_convolution(x)
        if self.use_checkpointing:
            controlnet_cond = torch.utils.checkpoint.checkpoint(
                self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False
            )
        else:
            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        h += controlnet_cond
        down_block_res_samples, h = self._apply_down_blocks(emb, context, h)
        h = self._apply_mid_block(emb, context, h)
        down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples)
        # scaling
        down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        return down_block_res_samples, mid_block_res_sample

    def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
        # 1. time
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. class
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        return emb

    def _apply_initial_convolution(self, x):
        # 3. initial convolution
        h = self.conv_in(x)
        return h

    def _apply_down_blocks(self, emb, context, h):
        # 4. down
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        down_block_res_samples: list[torch.Tensor] = [h]
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
            for residual in res_samples:
                down_block_res_samples.append(residual)

        return down_block_res_samples, h

    def _apply_mid_block(self, emb, context, h):
        # 5. mid
        h = self.middle_block(hidden_states=h, temb=emb, context=context)
        return h

    def _apply_controlnet_blocks(self, h, down_block_res_samples):
        # 6. Control net blocks
        controlnet_down_block_res_samples = []
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples.append(down_block_res_sample)

        mid_block_res_sample = self.controlnet_mid_block(h)

        return controlnet_down_block_res_samples, mid_block_res_sample
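The new class is used like the `generative` package's ControlNet but returns the scaled residuals for a paired diffusion UNet. A minimal usage sketch follows; it is not part of the diff, assumes MONAI plus the optional `generative` package (Project-MONAI GenerativeModels) are installed, and the shapes and channel counts are illustrative only.

# Hypothetical example, not from the package diff.
import torch

from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi

# Small 3D ControlNet; defaults for the other arguments are kept.
controlnet = ControlNetMaisi(
    spatial_dims=3,
    in_channels=4,
    num_channels=(32, 64, 64, 64),
    attention_levels=(False, False, True, True),
    conditioning_embedding_in_channels=1,
    use_checkpointing=False,  # skip activation checkpointing for this toy example
)

latent = torch.randn(1, 4, 16, 16, 16)            # noisy latent, (N, C, D, H, W)
timesteps = torch.randint(0, 1000, (1,)).long()   # one diffusion timestep per batch item
# Conditioning image; assumed here that the default conditioning embedding
# downsamples it to the latent resolution (128 -> 16 in this sketch).
cond = torch.randn(1, 1, 128, 128, 128)

down_samples, mid_sample = controlnet(
    latent, timesteps, controlnet_cond=cond, conditioning_scale=1.0
)
# down_samples is a list of per-level residuals, mid_sample the mid-block residual;
# both are meant to be fed to the diffusion UNet's *_additional_residual(s) arguments.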
monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
@@ -0,0 +1,410 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# =========================================================================
# Adapted from https://github.com/huggingface/diffusers
# which has the following license:
# https://github.com/huggingface/diffusers/blob/main/LICENSE
#
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

from __future__ import annotations

from collections.abc import Sequence

import torch
from torch import nn

from monai.networks.blocks import Convolution
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.type_conversion import convert_to_tensor

get_down_block, has_get_down_block = optional_import(
    "generative.networks.nets.diffusion_model_unet", name="get_down_block"
)
get_mid_block, has_get_mid_block = optional_import(
    "generative.networks.nets.diffusion_model_unet", name="get_mid_block"
)
get_timestep_embedding, has_get_timestep_embedding = optional_import(
    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
)
get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block")
xformers, has_xformers = optional_import("xformers")
zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module")

__all__ = ["DiffusionModelUNetMaisi"]


class DiffusionModelUNetMaisi(nn.Module):
    """
    U-Net network with timestep embedding and attention mechanisms for conditioning based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162

    Args:
        spatial_dims: Number of spatial dimensions.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_res_blocks: Number of residual blocks (see ResnetBlock) per level. Can be a single integer or a sequence of integers.
        num_channels: Tuple of block output channels.
        attention_levels: List of levels to add attention.
        norm_num_groups: Number of groups for the normalization.
        norm_eps: Epsilon for the normalization.
        resblock_updown: If True, use residual blocks for up/downsampling.
        num_head_channels: Number of channels in each attention head. Can be a single integer or a sequence of integers.
        with_conditioning: If True, add spatial transformers to perform conditioning.
        transformer_num_layers: Number of layers of Transformer blocks to use.
        cross_attention_dim: Number of context dimensions to use.
        num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
        upcast_attention: If True, upcast attention operations to full precision.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
        include_top_region_index_input: If True, use top region index input.
        include_bottom_region_index_input: If True, use bottom region index input.
        include_spacing_input: If True, use spacing input.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        num_channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
        dropout_cattn: float = 0.0,
        include_top_region_index_input: bool = False,
        include_bottom_region_index_input: bool = False,
        include_spacing_input: bool = False,
    ) -> None:
        super().__init__()
        if with_conditioning is True and cross_attention_dim is None:
            raise ValueError(
                "DiffusionModelUNetMaisi expects dimension of the cross-attention conditioning (cross_attention_dim) "
                "when using with_conditioning."
            )
        if cross_attention_dim is not None and with_conditioning is False:
            raise ValueError(
                "DiffusionModelUNetMaisi expects with_conditioning=True when specifying the cross_attention_dim."
            )
        if dropout_cattn > 1.0 or dropout_cattn < 0.0:
            raise ValueError("Dropout cannot be negative or >1.0!")

        # All number of channels should be multiple of num_groups
        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
            raise ValueError(
                f"DiffusionModelUNetMaisi expects all num_channels being multiple of norm_num_groups, "
                f"but get num_channels: {num_channels} and norm_num_groups: {norm_num_groups}"
            )

        if len(num_channels) != len(attention_levels):
            raise ValueError(
                f"DiffusionModelUNetMaisi expects num_channels being same size of attention_levels, "
                f"but get num_channels: {len(num_channels)} and attention_levels: {len(attention_levels)}"
            )

        if isinstance(num_head_channels, int):
            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

        if len(num_head_channels) != len(attention_levels):
            raise ValueError(
                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
                " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
            )

        if isinstance(num_res_blocks, int):
            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))

        if len(num_res_blocks) != len(num_channels):
            raise ValueError(
                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
                "`num_channels`."
            )

        if use_flash_attention and not has_xformers:
            raise ValueError("use_flash_attention is True but xformers is not installed.")

        if use_flash_attention is True and not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
            )

        self.in_channels = in_channels
        self.block_out_channels = num_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_levels = attention_levels
        self.num_head_channels = num_head_channels
        self.with_conditioning = with_conditioning

        # input
        self.conv_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=num_channels[0],
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        # time
        time_embed_dim = num_channels[0] * 4
        self.time_embed = self._create_embedding_module(num_channels[0], time_embed_dim)

        # class embedding
        self.num_class_embeds = num_class_embeds
        if num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

        self.include_top_region_index_input = include_top_region_index_input
        self.include_bottom_region_index_input = include_bottom_region_index_input
        self.include_spacing_input = include_spacing_input

        new_time_embed_dim = time_embed_dim
        if self.include_top_region_index_input:
            self.top_region_index_layer = self._create_embedding_module(4, time_embed_dim)
            new_time_embed_dim += time_embed_dim
        if self.include_bottom_region_index_input:
            self.bottom_region_index_layer = self._create_embedding_module(4, time_embed_dim)
            new_time_embed_dim += time_embed_dim
        if self.include_spacing_input:
            self.spacing_layer = self._create_embedding_module(3, time_embed_dim)
            new_time_embed_dim += time_embed_dim

        # down
        self.down_blocks = nn.ModuleList([])
        output_channel = num_channels[0]
        for i in range(len(num_channels)):
            input_channel = output_channel
            output_channel = num_channels[i]
            is_final_block = i == len(num_channels) - 1

            down_block = get_down_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=new_time_embed_dim,
                num_res_blocks=num_res_blocks[i],
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                add_downsample=not is_final_block,
                resblock_updown=resblock_updown,
                with_attn=(attention_levels[i] and not with_conditioning),
                with_cross_attn=(attention_levels[i] and with_conditioning),
                num_head_channels=num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                use_flash_attention=use_flash_attention,
                dropout_cattn=dropout_cattn,
            )

            self.down_blocks.append(down_block)

        # mid
        self.middle_block = get_mid_block(
            spatial_dims=spatial_dims,
            in_channels=num_channels[-1],
            temb_channels=new_time_embed_dim,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            with_conditioning=with_conditioning,
            num_head_channels=num_head_channels[-1],
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_flash_attention=use_flash_attention,
            dropout_cattn=dropout_cattn,
        )

        # up
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(num_channels))
        reversed_num_res_blocks = list(reversed(num_res_blocks))
        reversed_attention_levels = list(reversed(attention_levels))
        reversed_num_head_channels = list(reversed(num_head_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)]

            is_final_block = i == len(num_channels) - 1

            up_block = get_up_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                prev_output_channel=prev_output_channel,
                out_channels=output_channel,
                temb_channels=new_time_embed_dim,
                num_res_blocks=reversed_num_res_blocks[i] + 1,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                add_upsample=not is_final_block,
                resblock_updown=resblock_updown,
                with_attn=(reversed_attention_levels[i] and not with_conditioning),
                with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
                num_head_channels=reversed_num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                use_flash_attention=use_flash_attention,
                dropout_cattn=dropout_cattn,
            )

            self.up_blocks.append(up_block)

        # out
        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),
            nn.SiLU(),
            zero_module(
                Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[0],
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
            ),
        )

    def _create_embedding_module(self, input_dim, embed_dim):
        model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim))
        return model

    def _get_time_and_class_embedding(self, x, timesteps, class_labels):
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb += class_emb
        return emb

    def _get_input_embeddings(self, emb, top_index, bottom_index, spacing):
        if self.include_top_region_index_input:
            _emb = self.top_region_index_layer(top_index)
            emb = torch.cat((emb, _emb), dim=1)
        if self.include_bottom_region_index_input:
            _emb = self.bottom_region_index_layer(bottom_index)
            emb = torch.cat((emb, _emb), dim=1)
        if self.include_spacing_input:
            _emb = self.spacing_layer(spacing)
            emb = torch.cat((emb, _emb), dim=1)
        return emb

    def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals):
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        down_block_res_samples: list[torch.Tensor] = [h]
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
            down_block_res_samples.extend(res_samples)

        # Additional residual connections for Controlnets
        if down_block_additional_residuals is not None:
            new_down_block_res_samples: list[torch.Tensor] = []
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample += down_block_additional_residual
                new_down_block_res_samples.append(down_block_res_sample)

            down_block_res_samples = new_down_block_res_samples
        return h, down_block_res_samples

    def _apply_up_blocks(self, h, emb, context, down_block_res_samples):
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

        return h

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
        down_block_additional_residuals: tuple[torch.Tensor] | None = None,
        mid_block_additional_residual: torch.Tensor | None = None,
        top_region_index_tensor: torch.Tensor | None = None,
        bottom_region_index_tensor: torch.Tensor | None = None,
        spacing_tensor: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass through the UNet model.

        Args:
            x: Input tensor of shape (N, C, SpatialDims).
            timesteps: Timestep tensor of shape (N,).
            context: Context tensor of shape (N, 1, ContextDim).
            class_labels: Class labels tensor of shape (N,).
            down_block_additional_residuals: Additional residual tensors for down blocks of shape (N, C, FeatureMapsDims).
            mid_block_additional_residual: Additional residual tensor for mid block of shape (N, C, FeatureMapsDims).
            top_region_index_tensor: Tensor representing top region index of shape (N, 4).
            bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4).
            spacing_tensor: Tensor representing spacing of shape (N, 3).

        Returns:
            A tensor representing the output of the UNet model.
        """

        emb = self._get_time_and_class_embedding(x, timesteps, class_labels)
        emb = self._get_input_embeddings(emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor)
        h = self.conv_in(x)
        h, _updated_down_block_res_samples = self._apply_down_blocks(h, emb, context, down_block_additional_residuals)
        h = self.middle_block(h, emb, context)

        # Additional residual connections for Controlnets
        if mid_block_additional_residual is not None:
            h += mid_block_additional_residual

        h = self._apply_up_blocks(h, emb, context, _updated_down_block_res_samples)
        h = self.out(h)
        h_tensor: torch.Tensor = convert_to_tensor(h)
        return h_tensor
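The MAISI-specific additions over the base diffusion UNet are the optional top/bottom region index and spacing embeddings, which are concatenated onto the timestep embedding. A minimal usage sketch follows; it is not part of the diff, assumes the optional `generative` package is installed, and its shapes, spacing values, and region indices are illustrative only.

# Hypothetical example, not from the package diff.
import torch

from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi

unet = DiffusionModelUNetMaisi(
    spatial_dims=3,
    in_channels=4,
    out_channels=4,
    num_channels=(32, 64, 64, 64),
    attention_levels=(False, False, True, True),
    include_top_region_index_input=True,
    include_bottom_region_index_input=True,
    include_spacing_input=True,
)

latent = torch.randn(1, 4, 16, 16, 16)            # noisy latent, (N, C, D, H, W)
timesteps = torch.randint(0, 1000, (1,)).long()
top_idx = torch.zeros(1, 4)                        # top region index, shape (N, 4)
bottom_idx = torch.zeros(1, 4)                     # bottom region index, shape (N, 4)
spacing = torch.tensor([[1.0, 1.0, 1.0]])          # voxel spacing, shape (N, 3)

# Predicted noise (or velocity, depending on the scheduler) with the same shape as the latent.
pred = unet(
    latent,
    timesteps,
    top_region_index_tensor=top_idx,
    bottom_region_index_tensor=bottom_idx,
    spacing_tensor=spacing,
)

# When paired with ControlNetMaisi, its outputs would be passed as
# down_block_additional_residuals=... and mid_block_additional_residual=...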
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.