monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/networks/nets/diffusion_model_unet.py (new file)
@@ -0,0 +1,1913 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# =========================================================================
# Adapted from https://github.com/huggingface/diffusers
# which has the following license:
# https://github.com/huggingface/diffusers/blob/main/LICENSE
#
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================

from __future__ import annotations

import math
from collections.abc import Sequence

import torch
from torch import nn

from monai.networks.blocks import Convolution, CrossAttentionBlock, MLPBlock, SABlock, SpatialAttentionBlock, Upsample
from monai.networks.layers.factories import Pool
from monai.utils import ensure_tuple_rep, optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

__all__ = ["DiffusionModelUNet"]


def zero_module(module: nn.Module) -> nn.Module:
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

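# --- Reviewer annotation (illustrative sketch, not part of the packaged file) ---
# zero_module is applied below to the final projection of residual branches so that,
# at initialisation, each branch contributes nothing and the block starts as an
# identity mapping. A minimal check, assuming the definitions above are in scope:
def _example_zero_module() -> None:
    proj = zero_module(nn.Conv2d(8, 8, kernel_size=1))
    assert all(float(p.abs().sum()) == 0.0 for p in proj.parameters())
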
class DiffusionUNetTransformerBlock(nn.Module):
    """
    A Transformer block that allows for the input dimension to differ from the hidden dimension.

    Args:
        num_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        dropout: dropout probability to use.
        cross_attention_dim: size of the context vector for cross attention.
        upcast_attention: if True, upcast attention operations to full precision.

    """

    def __init__(
        self,
        num_channels: int,
        num_attention_heads: int,
        num_head_channels: int,
        dropout: float = 0.0,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
    ) -> None:
        super().__init__()
        self.attn1 = SABlock(
            hidden_size=num_attention_heads * num_head_channels,
            hidden_input_size=num_channels,
            num_heads=num_attention_heads,
            dim_head=num_head_channels,
            dropout_rate=dropout,
            attention_dtype=torch.float if upcast_attention else None,
        )
        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
        self.attn2 = CrossAttentionBlock(
            hidden_size=num_attention_heads * num_head_channels,
            num_heads=num_attention_heads,
            hidden_input_size=num_channels,
            context_input_size=cross_attention_dim,
            dim_head=num_head_channels,
            dropout_rate=dropout,
            attention_dtype=torch.float if upcast_attention else None,
        )
        self.norm1 = nn.LayerNorm(num_channels)
        self.norm2 = nn.LayerNorm(num_channels)
        self.norm3 = nn.LayerNorm(num_channels)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        # 1. Self-Attention
        x = self.attn1(self.norm1(x)) + x

        # 2. Cross-Attention
        x = self.attn2(self.norm2(x), context=context) + x

        # 3. Feed-forward
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer action. Finally, reshape to image.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        num_layers: number of layers of Transformer blocks to use.
        dropout: dropout probability to use.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_attention_heads: int,
        num_head_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        inner_dim = num_attention_heads * num_head_channels

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)

        self.proj_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=inner_dim,
            strides=1,
            kernel_size=1,
            padding=0,
            conv_only=True,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                DiffusionUNetTransformerBlock(
                    num_channels=inner_dim,
                    num_attention_heads=num_attention_heads,
                    num_head_channels=num_head_channels,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                )
                for _ in range(num_layers)
            ]
        )

        self.proj_out = zero_module(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=inner_dim,
                out_channels=in_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
            )
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        # note: if no context is given, cross-attention defaults to self-attention
        batch = channel = height = width = depth = -1
        if self.spatial_dims == 2:
            batch, channel, height, width = x.shape
        if self.spatial_dims == 3:
            batch, channel, height, width, depth = x.shape

        residual = x
        x = self.norm(x)
        x = self.proj_in(x)

        inner_dim = x.shape[1]

        if self.spatial_dims == 2:
            x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        if self.spatial_dims == 3:
            x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim)

        for block in self.transformer_blocks:
            x = block(x, context=context)

        if self.spatial_dims == 2:
            x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        if self.spatial_dims == 3:
            x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous()

        x = self.proj_out(x)
        return x + residual

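# --- Reviewer annotation (illustrative sketch, not part of the packaged file) ---
# SpatialTransformer flattens a feature map to (batch, tokens, channels), runs the
# transformer blocks on the token sequence, restores the spatial layout, and adds the
# input back as a residual. A hypothetical 2D example with made-up sizes, assuming the
# definitions above are in scope:
def _example_spatial_transformer() -> None:
    block = SpatialTransformer(
        spatial_dims=2,
        in_channels=32,
        num_attention_heads=4,
        num_head_channels=8,
        norm_num_groups=8,
        cross_attention_dim=64,
    )
    x = torch.randn(1, 32, 16, 16)
    context = torch.randn(1, 10, 64)  # e.g. text or label embeddings
    out = block(x, context=context)
    assert out.shape == x.shape  # the spatial layout is preserved
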
def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
    Models" https://arxiv.org/abs/2006.11239.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        embedding_dim: the dimension of the output.
        max_period: controls the minimum frequency of the embeddings.
    """
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(exponent / half_dim)

    args = timesteps[:, None].float() * freqs[None, :]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))

    return embedding

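# --- Reviewer annotation (illustrative sketch, not part of the packaged file) ---
# get_timestep_embedding maps one integer timestep per batch element to a dense
# sinusoidal vector that is later refined by the UNet's time-embedding MLP. Assuming
# the definitions above are in scope:
def _example_timestep_embedding() -> None:
    timesteps = torch.randint(0, 1000, (4,))  # one diffusion step per batch element
    emb = get_timestep_embedding(timesteps, embedding_dim=32)
    assert emb.shape == (4, 32)  # half cosine features, half sine features
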
class DiffusionUnetDownsample(nn.Module):
    """
    Downsampling layer.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of input channels.
        use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is
            False, the number of output channels must be the same as the number of input channels.
        out_channels: number of output channels.
        padding: controls the amount of implicit zero-paddings on both sides for padding number of points
            for each dimension.
    """

    def __init__(
        self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.out_channels = out_channels or num_channels
        self.use_conv = use_conv
        if use_conv:
            self.op = Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.num_channels,
                out_channels=self.out_channels,
                strides=2,
                kernel_size=3,
                padding=padding,
                conv_only=True,
            )
        else:
            if self.num_channels != self.out_channels:
                raise ValueError("num_channels and out_channels must be equal when use_conv=False")
            self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        del emb
        if x.shape[1] != self.num_channels:
            raise ValueError(
                f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
                f"({self.num_channels})"
            )
        output: torch.Tensor = self.op(x)
        return output


class WrappedUpsample(Upsample):
    """
    Wraps MONAI upsample block to allow for calling with timestep embeddings.
    """

    def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
        del emb
        upsampled: torch.Tensor = super().forward(x)
        return upsampled


class DiffusionUNetResnetBlock(nn.Module):
    """
    Residual block with timestep conditioning.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        out_channels: number of output channels.
        up: if True, performs upsampling.
        down: if True, performs downsampling.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        out_channels: int | None = None,
        up: bool = False,
        down: bool = False,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.channels = in_channels
        self.emb_channels = temb_channels
        self.out_channels = out_channels or in_channels
        self.up = up
        self.down = down

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
        self.nonlinearity = nn.SiLU()
        self.conv1 = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        self.upsample = self.downsample = None
        if self.up:
            self.upsample = WrappedUpsample(
                spatial_dims=spatial_dims,
                mode="nontrainable",
                in_channels=in_channels,
                out_channels=in_channels,
                interp_mode="nearest",
                scale_factor=2.0,
                align_corners=None,
            )
        elif down:
            self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False)

        self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)
        self.conv2 = zero_module(
            Convolution(
                spatial_dims=spatial_dims,
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=3,
                padding=1,
                conv_only=True,
            )
        )
        self.skip_connection: nn.Module
        if self.out_channels == in_channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=self.out_channels,
                strides=1,
                kernel_size=1,
                padding=0,
                conv_only=True,
            )

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = self.nonlinearity(h)

        if self.upsample is not None:
            x = self.upsample(x)
            h = self.upsample(h)
        elif self.downsample is not None:
            x = self.downsample(x)
            h = self.downsample(h)

        h = self.conv1(h)

        if self.spatial_dims == 2:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
        else:
            temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
        h = h + temb

        h = self.norm2(h)
        h = self.nonlinearity(h)
        h = self.conv2(h)
        output: torch.Tensor = self.skip_connection(x) + h
        return output

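# --- Reviewer annotation (illustrative sketch, not part of the packaged file) ---
# DiffusionUNetResnetBlock conditions on the timestep by projecting the time embedding
# and adding it to the feature map between its two convolutions. A hypothetical 2D
# example with made-up sizes, assuming the definitions above are in scope:
def _example_resnet_block() -> None:
    block = DiffusionUNetResnetBlock(
        spatial_dims=2, in_channels=32, temb_channels=128, out_channels=64, norm_num_groups=8
    )
    x = torch.randn(2, 32, 16, 16)
    temb = torch.randn(2, 128)  # e.g. output of the time-embedding MLP
    out = block(x, temb)
    assert out.shape == (2, 64, 16, 16)  # channels change, spatial size is kept
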
class DownBlock(nn.Module):
    """
    Unet's down block containing resnet and downsamplers blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []

        for i in range(num_res_blocks):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsampler: nn.Module | None
            if resblock_updown:
                self.downsampler = DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    down=True,
                )
            else:
                self.downsampler = DiffusionUnetDownsample(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    padding=downsample_padding,
                )
        else:
            self.downsampler = None

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        del context
        output_states = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            output_states.append(hidden_states)

        return hidden_states, output_states


class AttnDownBlock(nn.Module):
    """
    Unet's down block containing resnet, downsamplers and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
        num_head_channels: number of channels in each attention head.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
        num_head_channels: int = 1,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attentions.append(
                SpatialAttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.downsampler: nn.Module | None
        if add_downsample:
            if resblock_updown:
                self.downsampler = DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    down=True,
                )
            else:
                self.downsampler = DiffusionUnetDownsample(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    padding=downsample_padding,
                )
        else:
            self.downsampler = None

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        del context
        output_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states).contiguous()
            output_states.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            output_states.append(hidden_states)

        return hidden_states, output_states


class CrossAttnDownBlock(nn.Module):
    """
    Unet's down block containing resnet, downsamplers and cross-attention blocks.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_downsample: bool = True,
        resblock_updown: bool = False,
        downsample_padding: int = 1,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        dropout_cattn: float = 0.0,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

            attentions.append(
                SpatialTransformer(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    num_attention_heads=out_channels // num_head_channels,
                    num_head_channels=num_head_channels,
                    num_layers=transformer_num_layers,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    dropout=dropout_cattn,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.downsampler: nn.Module | None
        if add_downsample:
            if resblock_updown:
                self.downsampler = DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    down=True,
                )
            else:
                self.downsampler = DiffusionUnetDownsample(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    use_conv=True,
                    out_channels=out_channels,
                    padding=downsample_padding,
                )
        else:
            self.downsampler = None

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        output_states = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, context=context).contiguous()
            output_states.append(hidden_states)

        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states, temb)
            output_states.append(hidden_states)

        return hidden_states, output_states

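# --- Reviewer annotation (illustrative sketch, not part of the packaged file) ---
# Every down block returns its final hidden states together with the list of
# intermediate states; the UNet later concatenates these into the matching up blocks
# as skip connections. A hypothetical example, assuming the definitions above are in
# scope:
def _example_down_block() -> None:
    block = DownBlock(
        spatial_dims=2,
        in_channels=32,
        out_channels=64,
        temb_channels=128,
        num_res_blocks=2,
        norm_num_groups=8,
        add_downsample=True,
    )
    x = torch.randn(1, 32, 16, 16)
    temb = torch.randn(1, 128)
    hidden, skips = block(x, temb)
    assert hidden.shape == (1, 64, 8, 8)  # halved by the strided downsampling conv
    assert len(skips) == 3  # one entry per resnet block plus the downsampler output
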
class AttnMidBlock(nn.Module):
    """
    Unet's mid block containing resnet and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        num_head_channels: int = 1,
    ) -> None:
        super().__init__()

        self.resnet_1 = DiffusionUNetResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )
        self.attention = SpatialAttentionBlock(
            spatial_dims=spatial_dims,
            num_channels=in_channels,
            num_head_channels=num_head_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )

        self.resnet_2 = DiffusionUNetResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        del context
        hidden_states = self.resnet_1(hidden_states, temb)
        hidden_states = self.attention(hidden_states).contiguous()
        hidden_states = self.resnet_2(hidden_states, temb)

        return hidden_states


class CrossAttnMidBlock(nn.Module):
    """
    Unet's mid block containing resnet and cross-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        temb_channels: int,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        dropout_cattn: float = 0.0,
    ) -> None:
        super().__init__()

        self.resnet_1 = DiffusionUNetResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )
        self.attention = SpatialTransformer(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            num_attention_heads=in_channels // num_head_channels,
            num_head_channels=num_head_channels,
            num_layers=transformer_num_layers,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            dropout=dropout_cattn,
        )
        self.resnet_2 = DiffusionUNetResnetBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
        )

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        hidden_states = self.resnet_1(hidden_states, temb)
        hidden_states = self.attention(hidden_states, context=context)
        hidden_states = self.resnet_2(hidden_states, temb)

        return hidden_states


class UpBlock(nn.Module):
    """
    Unet's up block containing resnet and upsamplers blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown
        resnets = []

        for i in range(num_res_blocks):
            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        self.upsampler: nn.Module | None
        if add_upsample:
            if resblock_updown:
                self.upsampler = DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    up=True,
                )
            else:
                post_conv = Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
                self.upsampler = WrappedUpsample(
                    spatial_dims=spatial_dims,
                    mode="nontrainable",
                    in_channels=out_channels,
                    out_channels=out_channels,
                    interp_mode="nearest",
                    scale_factor=2.0,
                    post_conv=post_conv,
                    align_corners=None,
                )

        else:
            self.upsampler = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        del context
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_list[-1]
            res_hidden_states_list = res_hidden_states_list[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states


class AttnUpBlock(nn.Module):
    """
    Unet's up block containing resnet, upsamplers, and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
        num_head_channels: int = 1,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attentions.append(
                SpatialAttentionBlock(
                    spatial_dims=spatial_dims,
                    num_channels=out_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        self.upsampler: nn.Module | None
        if add_upsample:
            if resblock_updown:
                self.upsampler = DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    up=True,
                )
            else:

                post_conv = Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
                self.upsampler = WrappedUpsample(
                    spatial_dims=spatial_dims,
                    mode="nontrainable",
                    in_channels=out_channels,
                    out_channels=out_channels,
                    interp_mode="nearest",
                    scale_factor=2.0,
                    post_conv=post_conv,
                    align_corners=None,
                )
        else:
            self.upsampler = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        del context
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_list[-1]
            res_hidden_states_list = res_hidden_states_list[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states).contiguous()

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states


class CrossAttnUpBlock(nn.Module):
    """
    Unet's up block containing resnet, upsamplers, and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        num_res_blocks: int = 1,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        add_upsample: bool = True,
        resblock_updown: bool = False,
        num_head_channels: int = 1,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        upcast_attention: bool = False,
        dropout_cattn: float = 0.0,
    ) -> None:
        super().__init__()
        self.resblock_updown = resblock_updown

        resnets = []
        attentions = []

        for i in range(num_res_blocks):
            res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                )
            )
            attentions.append(
                SpatialTransformer(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    num_attention_heads=out_channels // num_head_channels,
                    num_head_channels=num_head_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    num_layers=transformer_num_layers,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                    dropout=dropout_cattn,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        self.upsampler: nn.Module | None
        if add_upsample:
            if resblock_updown:
                self.upsampler = DiffusionUNetResnetBlock(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    up=True,
                )
            else:

                post_conv = Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels,
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
                self.upsampler = WrappedUpsample(
                    spatial_dims=spatial_dims,
                    mode="nontrainable",
                    in_channels=out_channels,
                    out_channels=out_channels,
                    interp_mode="nearest",
                    scale_factor=2.0,
                    post_conv=post_conv,
                    align_corners=None,
                )
        else:
            self.upsampler = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_list: list[torch.Tensor],
        temb: torch.Tensor,
        context: torch.Tensor | None = None,
    ) -> torch.Tensor:
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_list[-1]
            res_hidden_states_list = res_hidden_states_list[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, context=context)

        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states, temb)

        return hidden_states


def get_down_block(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    temb_channels: int,
    num_res_blocks: int,
    norm_num_groups: int,
    norm_eps: float,
    add_downsample: bool,
    resblock_updown: bool,
    with_attn: bool,
    with_cross_attn: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    dropout_cattn: float = 0.0,
) -> nn.Module:
    if with_attn:
        return AttnDownBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_downsample=add_downsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
        )
    elif with_cross_attn:
        return CrossAttnDownBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_downsample=add_downsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            dropout_cattn=dropout_cattn,
        )
    else:
        return DownBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_downsample=add_downsample,
            resblock_updown=resblock_updown,
        )


def get_mid_block(
    spatial_dims: int,
    in_channels: int,
    temb_channels: int,
    norm_num_groups: int,
    norm_eps: float,
    with_conditioning: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    dropout_cattn: float = 0.0,
) -> nn.Module:
    if with_conditioning:
        return CrossAttnMidBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            dropout_cattn=dropout_cattn,
        )
    else:
        return AttnMidBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            temb_channels=temb_channels,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            num_head_channels=num_head_channels,
        )


def get_up_block(
    spatial_dims: int,
    in_channels: int,
    prev_output_channel: int,
    out_channels: int,
    temb_channels: int,
    num_res_blocks: int,
    norm_num_groups: int,
    norm_eps: float,
    add_upsample: bool,
    resblock_updown: bool,
    with_attn: bool,
    with_cross_attn: bool,
    num_head_channels: int,
    transformer_num_layers: int,
    cross_attention_dim: int | None,
    upcast_attention: bool = False,
    dropout_cattn: float = 0.0,
) -> nn.Module:
    if with_attn:
        return AttnUpBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_upsample=add_upsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
        )
    elif with_cross_attn:
        return CrossAttnUpBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_upsample=add_upsample,
            resblock_updown=resblock_updown,
            num_head_channels=num_head_channels,
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            dropout_cattn=dropout_cattn,
        )
    else:
        return UpBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            prev_output_channel=prev_output_channel,
            out_channels=out_channels,
            temb_channels=temb_channels,
            num_res_blocks=num_res_blocks,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            add_upsample=add_upsample,
            resblock_updown=resblock_updown,
        )

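# --- Reviewer annotation (illustrative sketch, not part of the packaged file) ---
# The get_*_block helpers pick the block variant from two flags: with_attn selects the
# self-attention variants, with_cross_attn the cross-attention (conditioned) ones, and
# with neither flag the plain resnet blocks are used. A hypothetical call, assuming the
# definitions above are in scope:
def _example_block_factory() -> None:
    block = get_down_block(
        spatial_dims=2,
        in_channels=32,
        out_channels=64,
        temb_channels=128,
        num_res_blocks=1,
        norm_num_groups=8,
        norm_eps=1e-6,
        add_downsample=True,
        resblock_updown=False,
        with_attn=False,
        with_cross_attn=True,
        num_head_channels=8,
        transformer_num_layers=1,
        cross_attention_dim=64,
    )
    assert isinstance(block, CrossAttnDownBlock)
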
class DiffusionModelUNet(nn.Module):
    """
    Unet network with timestep embedding and attention mechanisms for conditioning based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
        channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
        dropout_cattn: float = 0.0,
    ) -> None:
        super().__init__()
        if with_conditioning is True and cross_attention_dim is None:
            raise ValueError(
                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
                "when using with_conditioning."
            )
        if cross_attention_dim is not None and with_conditioning is False:
            raise ValueError(
                "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
            )
        if dropout_cattn > 1.0 or dropout_cattn < 0.0:
            raise ValueError("Dropout cannot be negative or >1.0!")

        # All number of channels should be multiple of num_groups
        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
            raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")

        if len(channels) != len(attention_levels):
            raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")

        if isinstance(num_head_channels, int):
            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

        if len(num_head_channels) != len(attention_levels):
            raise ValueError(
                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
                " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
            )

        if isinstance(num_res_blocks, int):
            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))

        if len(num_res_blocks) != len(channels):
            raise ValueError(
                "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
                "`num_channels`."
            )

        self.in_channels = in_channels
        self.block_out_channels = channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_levels = attention_levels
        self.num_head_channels = num_head_channels
        self.with_conditioning = with_conditioning

        # input
        self.conv_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=channels[0],
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        # time
        time_embed_dim = channels[0] * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
        )

        # class embedding
        self.num_class_embeds = num_class_embeds
        if num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

        # down
        self.down_blocks = nn.ModuleList([])
        output_channel = channels[0]
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_final_block = i == len(channels) - 1

            down_block = get_down_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_res_blocks=num_res_blocks[i],
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                add_downsample=not is_final_block,
                resblock_updown=resblock_updown,
                with_attn=(attention_levels[i] and not with_conditioning),
                with_cross_attn=(attention_levels[i] and with_conditioning),
                num_head_channels=num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                dropout_cattn=dropout_cattn,
            )

            self.down_blocks.append(down_block)

        # mid
        self.middle_block = get_mid_block(
            spatial_dims=spatial_dims,
            in_channels=channels[-1],
            temb_channels=time_embed_dim,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            with_conditioning=with_conditioning,
            num_head_channels=num_head_channels[-1],
            transformer_num_layers=transformer_num_layers,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            dropout_cattn=dropout_cattn,
        )

        # up
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(channels))
        reversed_num_res_blocks = list(reversed(num_res_blocks))
        reversed_attention_levels = list(reversed(attention_levels))
        reversed_num_head_channels = list(reversed(num_head_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]

            is_final_block = i == len(channels) - 1

            up_block = get_up_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                prev_output_channel=prev_output_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_res_blocks=reversed_num_res_blocks[i] + 1,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                add_upsample=not is_final_block,
                resblock_updown=resblock_updown,
                with_attn=(reversed_attention_levels[i] and not with_conditioning),
                with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
                num_head_channels=reversed_num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
                dropout_cattn=dropout_cattn,
            )

            self.up_blocks.append(up_block)

        # out
        self.out = nn.Sequential(
            nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),
            nn.SiLU(),
            zero_module(
                Convolution(
                    spatial_dims=spatial_dims,
                    in_channels=channels[0],
                    out_channels=out_channels,
                    strides=1,
                    kernel_size=3,
                    padding=1,
                    conv_only=True,
                )
            ),
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
        down_block_additional_residuals: tuple[torch.Tensor] | None = None,
        mid_block_additional_residual: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            context: context tensor (N, 1, ContextDim).
            class_labels: class labels tensor (N,).
            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
        """
        # 1. time
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. class
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        # 3. initial convolution
        h = self.conv_in(x)

        # 4. down
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        down_block_res_samples: list[torch.Tensor] = [h]
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
            for residual in res_samples:
                down_block_res_samples.append(residual)

        # Additional residual connections for ControlNets
        if down_block_additional_residuals is not None:
            new_down_block_res_samples: list[torch.Tensor] = []
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += [down_block_res_sample]

            down_block_res_samples = new_down_block_res_samples

        # 5. mid
        h = self.middle_block(hidden_states=h, temb=emb, context=context)

        # Additional residual connections for ControlNets
        if mid_block_additional_residual is not None:
            h = h + mid_block_additional_residual

        # 6. up
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
            h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

        # 7. output block
        output: torch.Tensor = self.out(h)

        return output

    def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
        """
        Load a state dict from a DiffusionModelUNet trained with
        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

        Args:
            old_state_dict: state dict from the old DiffusionModelUNet model.
        """

        new_state_dict = self.state_dict()
        # if all keys match, just load the state dict
        if all(k in new_state_dict for k in old_state_dict):
            print("All keys match, loading state dict.")
            self.load_state_dict(old_state_dict)
            return

        if verbose:
            # print all new_state_dict keys that are not in old_state_dict
            for k in new_state_dict:
                if k not in old_state_dict:
                    print(f"key {k} not found in old state dict")
            # and vice versa
            print("----------------------------------------------")
            for k in old_state_dict:
                if k not in new_state_dict:
                    print(f"key {k} not found in new state dict")

        # copy over all matching keys
        for k in new_state_dict:
            if k in old_state_dict:
                new_state_dict[k] = old_state_dict[k]

        # fix the attention blocks
        attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k]
        for block in attention_blocks:
            new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat(
                [
                    old_state_dict[f"{block}.attn1.to_q.weight"],
                    old_state_dict[f"{block}.attn1.to_k.weight"],
                    old_state_dict[f"{block}.attn1.to_v.weight"],
                ],
                dim=0,
            )

            # projection
            new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"]
            new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"]

            new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"]
            new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"]
        # fix the upsample conv blocks which were renamed postconv
        for k in new_state_dict:
            if "postconv" in k:
                old_name = k.replace("postconv", "conv")
                new_state_dict[k] = old_state_dict[old_name]
        self.load_state_dict(new_state_dict)

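A minimal usage sketch for the class defined above (not part of the packaged file): the import path assumes the module added in this diff, and the shapes and hyper-parameters are illustrative only.

    import torch

    from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet

    # Unconditional 2-D model using the constructor defaults shown above
    # (all channel counts must be multiples of norm_num_groups).
    model = DiffusionModelUNet(spatial_dims=2, in_channels=1, out_channels=1)
    x = torch.randn(2, 1, 64, 64)             # (N, C, H, W)
    t = torch.randint(0, 1000, (2,))          # one timestep per batch element
    noise_pred = model(x, timesteps=t)        # same shape as x: (2, 1, 64, 64)

    # Cross-attention conditioning: with_conditioning=True requires cross_attention_dim,
    # and the context is passed as (N, sequence_length, cross_attention_dim).
    cond_model = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        with_conditioning=True,
        cross_attention_dim=16,
    )
    context = torch.randn(2, 1, 16)
    noise_pred = cond_model(x, timesteps=t, context=context)

    # Checkpoints trained with the older MONAI Generative code base can be converted
    # via model.load_old_state_dict(old_state_dict), as defined above.
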
class DiffusionModelEncoder(nn.Module):
    """
    Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on
    Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306).

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
        channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
        upcast_attention: if True, upcast attention operations to full precision.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
    ) -> None:
        super().__init__()
        if with_conditioning is True and cross_attention_dim is None:
            raise ValueError(
                "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
                "when using with_conditioning."
            )
        if cross_attention_dim is not None and with_conditioning is False:
            raise ValueError(
                "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
            )

        # All number of channels should be multiple of num_groups
        if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
            raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups")
        if len(channels) != len(attention_levels):
            raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels")

        if isinstance(num_head_channels, int):
            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

        if isinstance(num_res_blocks, int):
            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))

        if len(num_head_channels) != len(attention_levels):
            raise ValueError(
                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
                " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
            )

        self.in_channels = in_channels
        self.block_out_channels = channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_levels = attention_levels
        self.num_head_channels = num_head_channels
        self.with_conditioning = with_conditioning

        # input
        self.conv_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=channels[0],
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        # time
        time_embed_dim = channels[0] * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
        )

        # class embedding
        self.num_class_embeds = num_class_embeds
        if num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

        # down
        self.down_blocks = nn.ModuleList([])
        output_channel = channels[0]
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_final_block = i == len(channels)  # - 1

            down_block = get_down_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_res_blocks=num_res_blocks[i],
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                add_downsample=not is_final_block,
                resblock_updown=resblock_updown,
                with_attn=(attention_levels[i] and not with_conditioning),
                with_cross_attn=(attention_levels[i] and with_conditioning),
                num_head_channels=num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
            )

            self.down_blocks.append(down_block)

        self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            context: context tensor (N, 1, ContextDim).
            class_labels: class labels tensor (N,).
        """
        # 1. time
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. class
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        # 3. initial convolution
        h = self.conv_in(x)

        # 4. down
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        for downsample_block in self.down_blocks:
            h, _ = downsample_block(hidden_states=h, temb=emb, context=context)

        h = h.reshape(h.shape[0], -1)
        output: torch.Tensor = self.out(h)

        return output
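
A minimal usage sketch for the encoder/classifier above (not part of the packaged file). The spatial size is chosen to match the hard-coded `nn.Linear(4096, 512)` head: with the default `channels=(32, 64, 64, 64)` every level downsamples (the `is_final_block` test above never evaluates to True), so a 2-D 128x128 input flattens to 64 x 8 x 8 = 4096 features. Import path and shapes are assumptions based on this diff.

    import torch

    from monai.networks.nets.diffusion_model_unet import DiffusionModelEncoder

    encoder = DiffusionModelEncoder(spatial_dims=2, in_channels=1, out_channels=2)
    x = torch.randn(4, 1, 128, 128)       # (N, C, H, W); 128 -> 8 after four downsamplings
    t = torch.randint(0, 1000, (4,))      # one timestep per batch element
    logits = encoder(x, timesteps=t)      # (4, 2) class logits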