monai-weekly 1.5.dev2511__py3-none-any.whl → 1.5.dev2513__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/data/utils.py +1 -1
- monai/metrics/meandice.py +132 -76
- monai/networks/blocks/__init__.py +2 -1
- monai/networks/blocks/cablock.py +182 -0
- monai/networks/blocks/downsample.py +241 -2
- monai/networks/nets/restormer.py +337 -0
- monai/networks/utils.py +44 -1
- monai/utils/__init__.py +1 -0
- monai/utils/enums.py +13 -0
- monai/utils/misc.py +1 -1
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/METADATA +3 -2
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/RECORD +22 -17
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/WHEEL +1 -1
- tests/metrics/test_compute_meandice.py +3 -3
- tests/networks/blocks/test_CABlock.py +150 -0
- tests/networks/blocks/test_downsample_block.py +184 -0
- tests/networks/nets/test_restormer.py +147 -0
- tests/networks/utils/test_pixelunshuffle.py +51 -0
- tests/integration/test_downsample_block.py +0 -50
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info/licenses}/LICENSE +0 -0
- {monai_weekly-1.5.dev2511.dist-info → monai_weekly-1.5.dev2513.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2025-03-
|
11
|
+
"date": "2025-03-30T02:32:28+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.5.
|
14
|
+
"full-revisionid": "ef083a32ccc13ee3937a4bd8acc12b9cdc174e18",
|
15
|
+
"version": "1.5.dev2513"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
monai/data/utils.py
CHANGED
@@ -1473,7 +1473,7 @@ def convert_tables_to_dicts(
|
|
1473
1473
|
# parse row indices
|
1474
1474
|
rows: list[int | str] = []
|
1475
1475
|
if row_indices is None:
|
1476
|
-
rows =
|
1476
|
+
rows = df.index.tolist()
|
1477
1477
|
else:
|
1478
1478
|
for i in row_indices:
|
1479
1479
|
if isinstance(i, (tuple, list)):
|
monai/metrics/meandice.py
CHANGED
@@ -14,7 +14,7 @@ from __future__ import annotations
|
|
14
14
|
import torch
|
15
15
|
|
16
16
|
from monai.metrics.utils import do_metric_reduction
|
17
|
-
from monai.utils import MetricReduction
|
17
|
+
from monai.utils import MetricReduction, deprecated_arg
|
18
18
|
|
19
19
|
from .metric import CumulativeIterationMetric
|
20
20
|
|
@@ -23,35 +23,76 @@ __all__ = ["DiceMetric", "compute_dice", "DiceHelper"]
|
|
23
23
|
|
24
24
|
class DiceMetric(CumulativeIterationMetric):
|
25
25
|
"""
|
26
|
-
|
26
|
+
Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
|
27
|
+
or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
|
28
|
+
and multi-label tasks.
|
27
29
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
30
|
+
If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one-
|
31
|
+
hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps
|
32
|
+
and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
|
33
|
+
this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
|
34
|
+
for binary overlap measurement (ie. either was expected to be one-hot formatted). Soft labels are thus permitted by
|
35
|
+
this metric. Typically this implies that raw predictions from a network must first be activated and possibly made
|
36
|
+
into label maps, eg. for a multi-class prediction tensor softmax and then argmax should be applied over the channel
|
37
|
+
dimensions to produce a label map.
|
38
|
+
|
39
|
+
The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
|
40
|
+
is by convention assumed to be background. If the non-background segmentations are small compared to the total
|
41
|
+
image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
|
42
|
+
and ground truth is BCHW[D].
|
43
|
+
|
44
|
+
The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
|
45
|
+
|
46
|
+
Further information can be found in the official
|
47
|
+
`MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`_.
|
48
|
+
|
49
|
+
Example:
|
50
|
+
|
51
|
+
.. code-block:: python
|
52
|
+
|
53
|
+
import torch
|
54
|
+
from monai.metrics import DiceMetric
|
55
|
+
from monai.losses import DiceLoss
|
56
|
+
from monai.networks import one_hot
|
57
|
+
|
58
|
+
batch_size, n_classes, h, w = 7, 5, 128, 128
|
59
|
+
|
60
|
+
y_pred = torch.rand(batch_size, n_classes, h, w) # network predictions
|
61
|
+
y_pred = torch.argmax(y_pred, 1, True) # convert to label map
|
62
|
+
|
63
|
+
# ground truth as label map
|
64
|
+
y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
|
65
|
+
|
66
|
+
dm = DiceMetric(
|
67
|
+
reduction="mean_batch", return_with_label=True, num_classes=n_classes
|
68
|
+
)
|
69
|
+
|
70
|
+
raw_scores = dm(y_pred, y)
|
71
|
+
print(dm.aggregate())
|
72
|
+
|
73
|
+
# now compute the Dice loss which should be the same as 1 - raw_scores
|
74
|
+
dl = DiceLoss(to_onehot_y=True, reduction="none")
|
75
|
+
loss = dl(one_hot(y_pred, n_classes), y).squeeze()
|
76
|
+
|
77
|
+
print(1.0 - loss) # same as raw_scores
|
36
78
|
|
37
|
-
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
|
38
79
|
|
39
80
|
Args:
|
40
|
-
include_background: whether to include Dice computation on the first channel of
|
41
|
-
|
42
|
-
reduction:
|
43
|
-
available reduction modes
|
44
|
-
|
45
|
-
get_not_nans: whether to return the `not_nans` count
|
46
|
-
|
47
|
-
ignore_empty: whether to ignore empty ground truth cases during calculation.
|
48
|
-
|
49
|
-
|
50
|
-
num_classes: number of input channels (always including the background). When this is None
|
81
|
+
include_background: whether to include Dice computation on the first channel/category of the prediction and
|
82
|
+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
|
83
|
+
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
|
84
|
+
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none" is
|
85
|
+
selected, the metric will not do reduction.
|
86
|
+
get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where
|
87
|
+
`not_nans` counts the number of valid values in the result, and will have the same shape.
|
88
|
+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
|
89
|
+
set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
|
90
|
+
are also empty.
|
91
|
+
num_classes: number of input channels (always including the background). When this is ``None``,
|
51
92
|
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
|
52
93
|
single-channel class indices and the number of classes is not automatically inferred from data.
|
53
94
|
return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
|
54
|
-
If `True`, use "label_{index}" as the key corresponding to C channels; if
|
95
|
+
If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
|
55
96
|
the index begins at "0", otherwise at "1". It can also take a list of label names.
|
56
97
|
The outcome will then be returned as a dictionary.
|
57
98
|
|
@@ -77,22 +118,21 @@ class DiceMetric(CumulativeIterationMetric):
|
|
77
118
|
include_background=self.include_background,
|
78
119
|
reduction=MetricReduction.NONE,
|
79
120
|
get_not_nans=False,
|
80
|
-
|
121
|
+
apply_argmax=False,
|
81
122
|
ignore_empty=self.ignore_empty,
|
82
123
|
num_classes=self.num_classes,
|
83
124
|
)
|
84
125
|
|
85
126
|
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
|
86
127
|
"""
|
128
|
+
Compute the dice value using ``DiceHelper``.
|
129
|
+
|
87
130
|
Args:
|
88
|
-
y_pred:
|
89
|
-
|
90
|
-
should be binarized.
|
91
|
-
y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
|
92
|
-
in the one-hot format.
|
131
|
+
y_pred: prediction value, see class docstring for format definition.
|
132
|
+
y: ground truth label.
|
93
133
|
|
94
134
|
Raises:
|
95
|
-
ValueError: when `y_pred` has
|
135
|
+
ValueError: when `y_pred` has fewer than three dimensions.
|
96
136
|
"""
|
97
137
|
dims = y_pred.ndimension()
|
98
138
|
if dims < 3:
|
@@ -107,10 +147,8 @@ class DiceMetric(CumulativeIterationMetric):
|
|
107
147
|
Execute reduction and aggregation logic for the output of `compute_dice`.
|
108
148
|
|
109
149
|
Args:
|
110
|
-
reduction:
|
111
|
-
|
112
|
-
``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
|
113
|
-
|
150
|
+
reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
|
151
|
+
By default this will do no reduction.
|
114
152
|
"""
|
115
153
|
data = self.get_buffer()
|
116
154
|
if not isinstance(data, torch.Tensor):
|
@@ -138,18 +176,20 @@ def compute_dice(
|
|
138
176
|
ignore_empty: bool = True,
|
139
177
|
num_classes: int | None = None,
|
140
178
|
) -> torch.Tensor:
|
141
|
-
"""
|
179
|
+
"""
|
180
|
+
Computes Dice score metric for a batch of predictions. This performs the same computation as
|
181
|
+
:py:class:`monai.metrics.DiceMetric`, which is preferable to use over this function. For input formats, see the
|
182
|
+
documentation for that class.
|
142
183
|
|
143
184
|
Args:
|
144
185
|
y_pred: input data to compute, typical segmentation model output.
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
num_classes: number of input channels (always including the background). When this is None,
|
186
|
+
y: ground truth to compute mean dice metric.
|
187
|
+
include_background: whether to include Dice computation on the first channel/category of the prediction and
|
188
|
+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
|
189
|
+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
|
190
|
+
set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
|
191
|
+
are also empty.
|
192
|
+
num_classes: number of input channels (always including the background). When this is ``None``,
|
153
193
|
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
|
154
194
|
single-channel class indices and the number of classes is not automatically inferred from data.
|
155
195
|
|
@@ -161,7 +201,7 @@ def compute_dice(
|
|
161
201
|
include_background=include_background,
|
162
202
|
reduction=MetricReduction.NONE,
|
163
203
|
get_not_nans=False,
|
164
|
-
|
204
|
+
apply_argmax=False,
|
165
205
|
ignore_empty=ignore_empty,
|
166
206
|
num_classes=num_classes,
|
167
207
|
)(y_pred=y_pred, y=y)
|
@@ -169,8 +209,8 @@ def compute_dice(
|
|
169
209
|
|
170
210
|
class DiceHelper:
|
171
211
|
"""
|
172
|
-
Compute Dice score between two tensors
|
173
|
-
|
212
|
+
Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
|
213
|
+
see the documentation for that class for input formats.
|
174
214
|
|
175
215
|
Example:
|
176
216
|
|
@@ -188,49 +228,65 @@ class DiceHelper:
|
|
188
228
|
score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
|
189
229
|
print(score, not_nans)
|
190
230
|
|
231
|
+
Args:
|
232
|
+
include_background: whether to include Dice computation on the first channel/category of the prediction and
|
233
|
+
ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
|
234
|
+
threshold: if ``True``, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
|
235
|
+
apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
|
236
|
+
get the discrete prediction. Defaults to the value of ``not threshold``.
|
237
|
+
activate: if this and ``threshold`` are ``True``, sigmoid activation is applied to ``y_pred`` before
|
238
|
+
thresholding. Defaults to False.
|
239
|
+
get_not_nans: whether to return the number of not-nan values.
|
240
|
+
reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
|
241
|
+
available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none" is
|
242
|
+
selected, the metric will not do reduction.
|
243
|
+
ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
|
244
|
+
set for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
|
245
|
+
are also empty.
|
246
|
+
num_classes: number of input channels (always including the background). When this is ``None``,
|
247
|
+
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
|
248
|
+
single-channel class indices and the number of classes is not automatically inferred from data.
|
191
249
|
"""
|
192
250
|
|
251
|
+
@deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
|
252
|
+
@deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold")
|
193
253
|
def __init__(
|
194
254
|
self,
|
195
255
|
include_background: bool | None = None,
|
196
|
-
|
197
|
-
|
256
|
+
threshold: bool = False,
|
257
|
+
apply_argmax: bool | None = None,
|
198
258
|
activate: bool = False,
|
199
259
|
get_not_nans: bool = True,
|
200
260
|
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
|
201
261
|
ignore_empty: bool = True,
|
202
262
|
num_classes: int | None = None,
|
263
|
+
sigmoid: bool | None = None,
|
264
|
+
softmax: bool | None = None,
|
203
265
|
) -> None:
|
204
|
-
|
266
|
+
# handling deprecated arguments
|
267
|
+
if sigmoid is not None:
|
268
|
+
threshold = sigmoid
|
269
|
+
if softmax is not None:
|
270
|
+
apply_argmax = softmax
|
205
271
|
|
206
|
-
|
207
|
-
include_background: whether to include the score on the first channel
|
208
|
-
(default to the value of `sigmoid`, False).
|
209
|
-
sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
|
210
|
-
will be performed to get the discrete prediction. Defaults to False.
|
211
|
-
softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
|
212
|
-
get the discrete prediction. Defaults to the value of ``not sigmoid``.
|
213
|
-
activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
|
214
|
-
This option is only valid when ``sigmoid`` is True.
|
215
|
-
get_not_nans: whether to return the number of not-nan values.
|
216
|
-
reduction: define mode of reduction to the metrics
|
217
|
-
ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
|
218
|
-
If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
|
219
|
-
num_classes: number of input channels (always including the background). When this is None,
|
220
|
-
``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
|
221
|
-
single-channel class indices and the number of classes is not automatically inferred from data.
|
222
|
-
"""
|
223
|
-
self.sigmoid = sigmoid
|
272
|
+
self.threshold = threshold
|
224
273
|
self.reduction = reduction
|
225
274
|
self.get_not_nans = get_not_nans
|
226
|
-
self.include_background =
|
227
|
-
self.
|
275
|
+
self.include_background = threshold if include_background is None else include_background
|
276
|
+
self.apply_argmax = not threshold if apply_argmax is None else apply_argmax
|
228
277
|
self.activate = activate
|
229
278
|
self.ignore_empty = ignore_empty
|
230
279
|
self.num_classes = num_classes
|
231
280
|
|
232
281
|
def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
233
|
-
"""
|
282
|
+
"""
|
283
|
+
Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
|
284
|
+
for each batch item and for each channel of those items.
|
285
|
+
|
286
|
+
Args:
|
287
|
+
y_pred: input predictions with shape HW[D].
|
288
|
+
y: ground truth with shape HW[D].
|
289
|
+
"""
|
234
290
|
y_o = torch.sum(y)
|
235
291
|
if y_o > 0:
|
236
292
|
return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
|
@@ -243,25 +299,25 @@ class DiceHelper:
|
|
243
299
|
|
244
300
|
def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
245
301
|
"""
|
302
|
+
Compute the metric for the given prediction and ground truth.
|
246
303
|
|
247
304
|
Args:
|
248
305
|
y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
|
249
306
|
the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
|
250
307
|
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
|
251
308
|
"""
|
252
|
-
|
309
|
+
_apply_argmax, _threshold = self.apply_argmax, self.threshold
|
253
310
|
if self.num_classes is None:
|
254
311
|
n_pred_ch = y_pred.shape[1] # y_pred is in one-hot format or multi-channel scores
|
255
312
|
else:
|
256
313
|
n_pred_ch = self.num_classes
|
257
314
|
if y_pred.shape[1] == 1 and self.num_classes > 1: # y_pred is single-channel class indices
|
258
|
-
|
315
|
+
_apply_argmax = _threshold = False
|
259
316
|
|
260
|
-
if
|
261
|
-
|
262
|
-
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
|
317
|
+
if _apply_argmax and n_pred_ch > 1:
|
318
|
+
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
|
263
319
|
|
264
|
-
elif
|
320
|
+
elif _threshold:
|
265
321
|
if self.activate:
|
266
322
|
y_pred = torch.sigmoid(y_pred)
|
267
323
|
y_pred = y_pred > 0.5
|
@@ -15,12 +15,13 @@ from .acti_norm import ADN
|
|
15
15
|
from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish
|
16
16
|
from .aspp import SimpleASPP
|
17
17
|
from .backbone_fpn_utils import BackboneWithFPN
|
18
|
+
from .cablock import CABlock, FeedForward
|
18
19
|
from .convolutions import Convolution, ResidualUnit
|
19
20
|
from .crf import CRF
|
20
21
|
from .crossattention import CrossAttentionBlock
|
21
22
|
from .denseblock import ConvDenseBlock, DenseBlock
|
22
23
|
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
|
23
|
-
from .downsample import MaxAvgPool
|
24
|
+
from .downsample import DownSample, Downsample, MaxAvgPool, SubpixelDownsample, SubpixelDownSample, Subpixeldownsample
|
24
25
|
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
|
25
26
|
from .encoder import BaseEncoder
|
26
27
|
from .fcn import FCN, GCN, MCFCN, Refine
|
@@ -0,0 +1,182 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
from __future__ import annotations
|
12
|
+
|
13
|
+
from typing import cast
|
14
|
+
|
15
|
+
import torch
|
16
|
+
import torch.nn as nn
|
17
|
+
import torch.nn.functional as F
|
18
|
+
|
19
|
+
from monai.networks.blocks.convolutions import Convolution
|
20
|
+
from monai.utils import optional_import
|
21
|
+
|
22
|
+
rearrange, _ = optional_import("einops", name="rearrange")
|
23
|
+
|
24
|
+
__all__ = ["FeedForward", "CABlock"]
|
25
|
+
|
26
|
+
|
27
|
+
class FeedForward(nn.Module):
|
28
|
+
"""Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
|
29
|
+
Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.
|
30
|
+
|
31
|
+
Args:
|
32
|
+
spatial_dims: Number of spatial dimensions (2D or 3D)
|
33
|
+
dim: Number of input channels
|
34
|
+
ffn_expansion_factor: Factor to expand hidden features dimension
|
35
|
+
bias: Whether to use bias in convolution layers
|
36
|
+
"""
|
37
|
+
|
38
|
+
def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):
|
39
|
+
super().__init__()
|
40
|
+
hidden_features = int(dim * ffn_expansion_factor)
|
41
|
+
|
42
|
+
self.project_in = Convolution(
|
43
|
+
spatial_dims=spatial_dims,
|
44
|
+
in_channels=dim,
|
45
|
+
out_channels=hidden_features * 2,
|
46
|
+
kernel_size=1,
|
47
|
+
bias=bias,
|
48
|
+
conv_only=True,
|
49
|
+
)
|
50
|
+
|
51
|
+
self.dwconv = Convolution(
|
52
|
+
spatial_dims=spatial_dims,
|
53
|
+
in_channels=hidden_features * 2,
|
54
|
+
out_channels=hidden_features * 2,
|
55
|
+
kernel_size=3,
|
56
|
+
strides=1,
|
57
|
+
padding=1,
|
58
|
+
groups=hidden_features * 2,
|
59
|
+
bias=bias,
|
60
|
+
conv_only=True,
|
61
|
+
)
|
62
|
+
|
63
|
+
self.project_out = Convolution(
|
64
|
+
spatial_dims=spatial_dims,
|
65
|
+
in_channels=hidden_features,
|
66
|
+
out_channels=dim,
|
67
|
+
kernel_size=1,
|
68
|
+
bias=bias,
|
69
|
+
conv_only=True,
|
70
|
+
)
|
71
|
+
|
72
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
73
|
+
x = self.project_in(x)
|
74
|
+
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
75
|
+
return cast(torch.Tensor, self.project_out(F.gelu(x1) * x2))
|
76
|
+
|
77
|
+
|
78
|
+
class CABlock(nn.Module):
|
79
|
+
"""Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
|
80
|
+
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
|
81
|
+
convolutions for local mixing before attention, achieving linear complexity vs quadratic
|
82
|
+
in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>
|
83
|
+
|
84
|
+
Args:
|
85
|
+
spatial_dims: Number of spatial dimensions (2D or 3D)
|
86
|
+
dim: Number of input channels
|
87
|
+
num_heads: Number of attention heads
|
88
|
+
bias: Whether to use bias in convolution layers
|
89
|
+
flash_attention: Whether to use flash attention optimization. Defaults to False.
|
90
|
+
|
91
|
+
Raises:
|
92
|
+
ValueError: If flash attention is not available in current PyTorch version
|
93
|
+
ValueError: If spatial_dims is greater than 3
|
94
|
+
"""
|
95
|
+
|
96
|
+
def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
|
97
|
+
super().__init__()
|
98
|
+
if flash_attention and not hasattr(F, "scaled_dot_product_attention"):
|
99
|
+
raise ValueError("Flash attention not available")
|
100
|
+
if spatial_dims > 3:
|
101
|
+
raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}")
|
102
|
+
self.spatial_dims = spatial_dims
|
103
|
+
self.num_heads = num_heads
|
104
|
+
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
|
105
|
+
self.flash_attention = flash_attention
|
106
|
+
|
107
|
+
self.qkv = Convolution(
|
108
|
+
spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True
|
109
|
+
)
|
110
|
+
|
111
|
+
self.qkv_dwconv = Convolution(
|
112
|
+
spatial_dims=spatial_dims,
|
113
|
+
in_channels=dim * 3,
|
114
|
+
out_channels=dim * 3,
|
115
|
+
kernel_size=3,
|
116
|
+
strides=1,
|
117
|
+
padding=1,
|
118
|
+
groups=dim * 3,
|
119
|
+
bias=bias,
|
120
|
+
conv_only=True,
|
121
|
+
)
|
122
|
+
|
123
|
+
self.project_out = Convolution(
|
124
|
+
spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True
|
125
|
+
)
|
126
|
+
|
127
|
+
self._attention_fn = self._get_attention_fn()
|
128
|
+
|
129
|
+
def _get_attention_fn(self):
|
130
|
+
if self.flash_attention:
|
131
|
+
return self._flash_attention
|
132
|
+
return self._normal_attention
|
133
|
+
|
134
|
+
def _flash_attention(self, q, k, v):
|
135
|
+
"""Flash attention implementation using scaled dot-product attention."""
|
136
|
+
scale = float(self.temperature.mean())
|
137
|
+
out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False)
|
138
|
+
return out
|
139
|
+
|
140
|
+
def _normal_attention(self, q, k, v):
|
141
|
+
"""Attention matrix multiplication with depth-wise convolutions."""
|
142
|
+
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
143
|
+
attn = attn.softmax(dim=-1)
|
144
|
+
return attn @ v
|
145
|
+
|
146
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
147
|
+
"""Forward pass for MDTA attention.
|
148
|
+
1. Apply depth-wise convolutions to Q, K, V
|
149
|
+
2. Reshape Q, K, V for multi-head attention
|
150
|
+
3. Compute attention matrix using flash or normal attention
|
151
|
+
4. Reshape and project out attention output"""
|
152
|
+
spatial_dims = x.shape[2:]
|
153
|
+
|
154
|
+
# Project and mix
|
155
|
+
qkv = self.qkv_dwconv(self.qkv(x))
|
156
|
+
q, k, v = qkv.chunk(3, dim=1)
|
157
|
+
|
158
|
+
# Select attention
|
159
|
+
if self.spatial_dims == 2:
|
160
|
+
qkv_to_multihead = "b (head c) h w -> b head c (h w)"
|
161
|
+
multihead_to_qkv = "b head c (h w) -> b (head c) h w"
|
162
|
+
else: # dims == 3
|
163
|
+
qkv_to_multihead = "b (head c) d h w -> b head c (d h w)"
|
164
|
+
multihead_to_qkv = "b head c (d h w) -> b (head c) d h w"
|
165
|
+
|
166
|
+
# Reconstruct and project feature map
|
167
|
+
q = rearrange(q, qkv_to_multihead, head=self.num_heads)
|
168
|
+
k = rearrange(k, qkv_to_multihead, head=self.num_heads)
|
169
|
+
v = rearrange(v, qkv_to_multihead, head=self.num_heads)
|
170
|
+
|
171
|
+
q = torch.nn.functional.normalize(q, dim=-1)
|
172
|
+
k = torch.nn.functional.normalize(k, dim=-1)
|
173
|
+
|
174
|
+
out = self._attention_fn(q, k, v)
|
175
|
+
out = rearrange(
|
176
|
+
out,
|
177
|
+
multihead_to_qkv,
|
178
|
+
head=self.num_heads,
|
179
|
+
**dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)),
|
180
|
+
)
|
181
|
+
|
182
|
+
return cast(torch.Tensor, self.project_out(out))
|