monai-weekly 1.5.dev2512__py3-none-any.whl → 1.5.dev2514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "5c5ca232a1ccf0e7eae99dd386b8b37472d8aa4c"
+__commit_id__ = "4986d7ffd2d351c9d66de0e0329884b1a26d5500"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-03-23T02:31:42+0000",
+ "date": "2025-04-06T02:31:51+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "e4701e24c97d1f8c7ba40777c238cdfe14b04581",
- "version": "1.5.dev2512"
+ "full-revisionid": "a3ea49fc4e600d131daadad61ea340df25fcfdaa",
+ "version": "1.5.dev2514"
 }
 ''' # END VERSION_JSON
 
monai/metrics/meandice.py CHANGED
@@ -14,7 +14,7 @@ from __future__ import annotations
 import torch
 
 from monai.metrics.utils import do_metric_reduction
-from monai.utils import MetricReduction
+from monai.utils import MetricReduction, deprecated_arg
 
 from .metric import CumulativeIterationMetric
 
@@ -23,35 +23,76 @@ __all__ = ["DiceMetric", "compute_dice", "DiceHelper"]
 
 class DiceMetric(CumulativeIterationMetric):
     """
-    Compute average Dice score for a set of pairs of prediction-groundtruth segmentations.
+    Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
+    or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
+    and multi-label tasks.
 
-    It supports both multi-classes and multi-labels tasks.
-    Input `y_pred` is compared with ground truth `y`.
-    `y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the
-    one-hot format. The `include_background` parameter can be set to ``False`` to exclude
-    the first category (channel index 0) which is by convention assumed to be background. If the non-background
-    segmentations are small compared to the total image size they can get overwhelmed by the signal from the
-    background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]),
-    `y` can also be in the format of `B1HW[D]`.
+    If either the prediction ``y_pred`` or the ground truth ``y`` has shape BCHW[D], it is expected to represent
+    one-hot segmentations for C classes. If either has shape B1HW[D], it is expected to be a label map and the
+    number of classes must be specified by the ``num_classes`` parameter. In either case, this metric applies no
+    activations, so non-binary values will produce unexpected results where binary overlap measurement was intended
+    (i.e. where the input was expected to be one-hot formatted). Soft labels are thus permitted by this metric.
+    Typically this implies that raw predictions from a network must first be activated and possibly converted to
+    label maps, e.g. for a multi-class prediction tensor, softmax and then argmax should be applied over the channel
+    dimension to produce a label map.
+
+    The ``include_background`` parameter can be set to ``False`` to exclude the first category (channel index 0),
+    which is by convention assumed to be background. If the non-background segmentations are small compared to the
+    total image size they can get overwhelmed by the signal from the background. This assumes the shape of both
+    prediction and ground truth is BCHW[D].
+
+    The typical execution steps of this metric class follow :py:class:`monai.metrics.metric.Cumulative`.
+
+    Further information can be found in the official
+    `MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`_.
+
+    Example:
+
+    .. code-block:: python
+
+        import torch
+        from monai.metrics import DiceMetric
+        from monai.losses import DiceLoss
+        from monai.networks import one_hot
+
+        batch_size, n_classes, h, w = 7, 5, 128, 128
+
+        y_pred = torch.rand(batch_size, n_classes, h, w)  # network predictions
+        y_pred = torch.argmax(y_pred, 1, True)  # convert to label map
+
+        # ground truth as label map
+        y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))
+
+        dm = DiceMetric(
+            reduction="mean_batch", return_with_label=True, num_classes=n_classes
+        )
+
+        raw_scores = dm(y_pred, y)
+        print(dm.aggregate())
+
+        # now compute the Dice loss, which should be the same as 1 - raw_scores
+        dl = DiceLoss(to_onehot_y=True, reduction="none")
+        loss = dl(one_hot(y_pred, n_classes), y).squeeze()
+
+        print(1.0 - loss)  # same as raw_scores
 
-    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
 
     Args:
-        include_background: whether to include Dice computation on the first channel of
-            the predicted output. Defaults to ``True``.
-        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
-            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
-            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
-        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
-            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.
-        ignore_empty: whether to ignore empty ground truth cases during calculation.
-            If `True`, NaN value will be set for empty ground truth cases.
-            If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
-        num_classes: number of input channels (always including the background). When this is None,
+        include_background: whether to include Dice computation on the first channel/category of the prediction and
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        reduction: defines the mode of reduction to the metrics; reduction is applied only to `not-nan` values. The
+            available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If ``"none"``
+            is selected, no reduction is done.
+        get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where
+            `not_nans` counts the number of valid values in the result and has the same shape.
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, `NaN` will be set
+            for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
+            are also empty.
+        num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
         return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
-            If `True`, use "label_{index}" as the key corresponding to C channels; if 'include_background' is True,
+            If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
             the index begins at "0", otherwise at "1". It can also take a list of label names.
             The outcome will then be returned as a dictionary.
 
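The docstring example above covers the label-map path. The one-hot BCHW[D] path it describes can be sketched similarly; the following is a minimal sketch whose shapes and values are purely illustrative, using only names exported by `monai`:

```python
import torch
from monai.metrics import DiceMetric
from monai.networks import one_hot

batch_size, n_classes, h, w = 2, 3, 8, 8

# convert label maps (B1HW) to one-hot BCHW tensors, the alternative input format
y_pred = one_hot(torch.randint(0, n_classes, (batch_size, 1, h, w)), n_classes)
y = one_hot(torch.randint(0, n_classes, (batch_size, 1, h, w)), n_classes)

dm = DiceMetric(include_background=False, reduction="mean")
dm(y_pred, y)          # accumulates per-item, per-class scores
print(dm.aggregate())  # one mean value over batch and foreground classes
dm.reset()             # clear the buffer before the next evaluation run
```

With multi-channel inputs no ``num_classes`` argument is needed, since the class count is inferred from ``y_pred.shape[1]``.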
@@ -77,22 +118,21 @@ class DiceMetric(CumulativeIterationMetric):
             include_background=self.include_background,
             reduction=MetricReduction.NONE,
             get_not_nans=False,
-            softmax=False,
+            apply_argmax=False,
             ignore_empty=self.ignore_empty,
             num_classes=self.num_classes,
         )
 
     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
         """
+        Compute the dice value using ``DiceHelper``.
+
         Args:
-            y_pred: input data to compute, typical segmentation model output.
-                It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
-                should be binarized.
-            y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
-                in the one-hot format.
+            y_pred: prediction value, see class docstring for format definition.
+            y: ground truth label.
 
         Raises:
-            ValueError: when `y_pred` has less than three dimensions.
+            ValueError: when `y_pred` has fewer than three dimensions.
         """
         dims = y_pred.ndimension()
         if dims < 3:
@@ -107,10 +147,8 @@ class DiceMetric(CumulativeIterationMetric):
         Execute reduction and aggregation logic for the output of `compute_dice`.
 
         Args:
-            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
-                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
-                ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
-
+            reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
+                Defaults to `self.reduction`; if ``"none"``, no reduction is done.
         """
         data = self.get_buffer()
         if not isinstance(data, torch.Tensor):
@@ -138,18 +176,20 @@ def compute_dice(
     ignore_empty: bool = True,
     num_classes: int | None = None,
 ) -> torch.Tensor:
-    """Computes Dice score metric for a batch of predictions.
+    """
+    Computes Dice score metric for a batch of predictions. This performs the same computation as
+    :py:class:`monai.metrics.DiceMetric`, which is preferable to use over this function. For input formats, see the
+    documentation for that class.
 
     Args:
         y_pred: input data to compute, typical segmentation model output.
-            `y_pred` can be single-channel class indices or in the one-hot format.
-        y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
-        include_background: whether to include Dice computation on the first channel of
-            the predicted output. Defaults to True.
-        ignore_empty: whether to ignore empty ground truth cases during calculation.
-            If `True`, NaN value will be set for empty ground truth cases.
-            If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
-        num_classes: number of input channels (always including the background). When this is None,
+        y: ground truth to compute mean dice metric.
+        include_background: whether to include Dice computation on the first channel/category of the prediction and
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, `NaN` will be set
+            for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
+            are also empty.
+        num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
 
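A short usage sketch of this functional form with single-channel class-index inputs, where ``num_classes`` must be supplied explicitly (tensor sizes here are illustrative):

```python
import torch
from monai.metrics import compute_dice

# binary-class label maps in B1HW format, so num_classes must be given explicitly
y_pred = torch.randint(0, 2, (4, 1, 16, 16))
y = torch.randint(0, 2, (4, 1, 16, 16))

scores = compute_dice(y_pred, y, include_background=False, num_classes=2)
print(scores.shape)  # torch.Size([4, 1]): per-item score for the one foreground class
```

Because the function constructs a `DiceHelper` with `reduction=MetricReduction.NONE`, the result keeps one score per batch item and per (non-background) class.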
@@ -161,7 +201,7 @@ def compute_dice(
         include_background=include_background,
         reduction=MetricReduction.NONE,
         get_not_nans=False,
-        softmax=False,
+        apply_argmax=False,
         ignore_empty=ignore_empty,
         num_classes=num_classes,
     )(y_pred=y_pred, y=y)
@@ -169,8 +209,8 @@ def compute_dice(
 
 class DiceHelper:
     """
-    Compute Dice score between two tensors `y_pred` and `y`.
-    `y_pred` and `y` can be single-channel class indices or in the one-hot format.
+    Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`;
+    see the documentation for that class for input formats.
 
     Example:
 
@@ -188,49 +228,65 @@ class DiceHelper:
         score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
         print(score, not_nans)
 
+    Args:
+        include_background: whether to include Dice computation on the first channel/category of the prediction and
+            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
+        threshold: if ``True``, ``y_pred`` will be thresholded at a value of 0.5. Defaults to ``False``.
+        apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
+            get the discrete prediction. Defaults to the value of ``not threshold``.
+        activate: if this and ``threshold`` are ``True``, sigmoid activation is applied to ``y_pred`` before
+            thresholding. Defaults to ``False``.
+        get_not_nans: whether to return the number of not-nan values.
+        reduction: defines the mode of reduction to the metrics; reduction is applied only to `not-nan` values. The
+            available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If ``"none"``
+            is selected, no reduction is done.
+        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, `NaN` will be set
+            for empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
+            are also empty.
+        num_classes: number of input channels (always including the background). When this is ``None``,
+            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
+            single-channel class indices and the number of classes is not automatically inferred from data.
     """
 
+    @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
+    @deprecated_arg("sigmoid", "1.5", "1.7", "Use `threshold` instead.", new_name="threshold")
     def __init__(
         self,
         include_background: bool | None = None,
-        sigmoid: bool = False,
-        softmax: bool | None = None,
+        threshold: bool = False,
+        apply_argmax: bool | None = None,
         activate: bool = False,
         get_not_nans: bool = True,
         reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
         ignore_empty: bool = True,
         num_classes: int | None = None,
+        sigmoid: bool | None = None,
+        softmax: bool | None = None,
     ) -> None:
-        """
+        # handle the deprecated arguments by mapping them onto their new names
+        if sigmoid is not None:
+            threshold = sigmoid
+        if softmax is not None:
+            apply_argmax = softmax
 
-        Args:
-            include_background: whether to include the score on the first channel
-                (default to the value of `sigmoid`, False).
-            sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
-                will be performed to get the discrete prediction. Defaults to False.
-            softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
-                get the discrete prediction. Defaults to the value of ``not sigmoid``.
-            activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
-                This option is only valid when ``sigmoid`` is True.
-            get_not_nans: whether to return the number of not-nan values.
-            reduction: define mode of reduction to the metrics
-            ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
-                If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
-            num_classes: number of input channels (always including the background). When this is None,
-                ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
-                single-channel class indices and the number of classes is not automatically inferred from data.
-        """
-        self.sigmoid = sigmoid
+        self.threshold = threshold
         self.reduction = reduction
         self.get_not_nans = get_not_nans
-        self.include_background = sigmoid if include_background is None else include_background
-        self.softmax = not sigmoid if softmax is None else softmax
+        self.include_background = threshold if include_background is None else include_background
+        self.apply_argmax = not threshold if apply_argmax is None else apply_argmax
         self.activate = activate
         self.ignore_empty = ignore_empty
        self.num_classes = num_classes
 
     def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        """"""
+        """
+        Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
+        for each batch item and for each channel of those items.
+
+        Args:
+            y_pred: input predictions with shape HW[D].
+            y: ground truth with shape HW[D].
+        """
         y_o = torch.sum(y)
         if y_o > 0:
             return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
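The two `deprecated_arg` decorators and the shim at the top of `__init__` keep old call sites working through the 1.5-to-1.7 deprecation window. A sketch of the migration they imply (the assertions simply check the mapping shown in the shim above):

```python
from monai.metrics import DiceHelper

# pre-1.5 keyword names: still accepted, but now emit deprecation warnings
old_style = DiceHelper(include_background=False, sigmoid=True, softmax=False)

# equivalent call using the renamed arguments
new_style = DiceHelper(include_background=False, threshold=True, apply_argmax=False)

assert old_style.threshold == new_style.threshold
assert old_style.apply_argmax == new_style.apply_argmax
```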
@@ -243,25 +299,25 @@ class DiceHelper:
 
     def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
+        Compute the metric for the given prediction and ground truth.
 
         Args:
             y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
                 the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
             y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
         """
-        _softmax, _sigmoid = self.softmax, self.sigmoid
+        _apply_argmax, _threshold = self.apply_argmax, self.threshold
         if self.num_classes is None:
             n_pred_ch = y_pred.shape[1]  # y_pred is in one-hot format or multi-channel scores
         else:
             n_pred_ch = self.num_classes
             if y_pred.shape[1] == 1 and self.num_classes > 1:  # y_pred is single-channel class indices
-                _softmax = _sigmoid = False
+                _apply_argmax = _threshold = False
 
-        if _softmax:
-            if n_pred_ch > 1:
-                y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
+        if _apply_argmax and n_pred_ch > 1:
+            y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
 
-        elif _sigmoid:
+        elif _threshold:
             if self.activate:
                 y_pred = torch.sigmoid(y_pred)
             y_pred = y_pred > 0.5
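The branch above separates an argmax path for multi-class scores from a threshold path for per-channel (multi-label) predictions. A minimal sketch of the threshold path, assuming raw logits and a binary ground truth (all sizes illustrative):

```python
import torch
from monai.metrics import DiceHelper

b, c, h, w = 2, 3, 16, 16

# raw logits in a multi-label setting: one channel per structure, no softmax
y_pred = torch.randn(b, c, h, w)
y = (torch.rand(b, c, h, w) > 0.5).float()

# activate=True applies sigmoid before the 0.5 threshold; argmax is skipped
score, not_nans = DiceHelper(threshold=True, activate=True, apply_argmax=False)(y_pred, y)
print(score)  # one score per channel under the default "mean_batch" reduction
```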
monai/networks/blocks/__init__.py CHANGED
@@ -15,12 +15,13 @@ from .acti_norm import ADN
 from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish
 from .aspp import SimpleASPP
 from .backbone_fpn_utils import BackboneWithFPN
+from .cablock import CABlock, FeedForward
 from .convolutions import Convolution, ResidualUnit
 from .crf import CRF
 from .crossattention import CrossAttentionBlock
 from .denseblock import ConvDenseBlock, DenseBlock
 from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
-from .downsample import MaxAvgPool
+from .downsample import DownSample, Downsample, MaxAvgPool, SubpixelDownsample, SubpixelDownSample, Subpixeldownsample
 from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
 from .encoder import BaseEncoder
 from .fcn import FCN, GCN, MCFCN, Refine
monai/networks/blocks/cablock.py ADDED
@@ -0,0 +1,182 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks.convolutions import Convolution
+from monai.utils import optional_import
+
+rearrange, _ = optional_import("einops", name="rearrange")
+
+__all__ = ["FeedForward", "CABlock"]
+
+
+class FeedForward(nn.Module):
+    """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using a gating mechanism.
+    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        ffn_expansion_factor: Factor to expand hidden features dimension
+        bias: Whether to use bias in convolution layers
+    """
+
+    def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool):
+        super().__init__()
+        hidden_features = int(dim * ffn_expansion_factor)
+
+        self.project_in = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=dim,
+            out_channels=hidden_features * 2,
+            kernel_size=1,
+            bias=bias,
+            conv_only=True,
+        )
+
+        self.dwconv = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=hidden_features * 2,
+            out_channels=hidden_features * 2,
+            kernel_size=3,
+            strides=1,
+            padding=1,
+            groups=hidden_features * 2,
+            bias=bias,
+            conv_only=True,
+        )
+
+        self.project_out = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=hidden_features,
+            out_channels=dim,
+            kernel_size=1,
+            bias=bias,
+            conv_only=True,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.project_in(x)
+        x1, x2 = self.dwconv(x).chunk(2, dim=1)
+        return cast(torch.Tensor, self.project_out(F.gelu(x1) * x2))
+
+
+class CABlock(nn.Module):
+    """Multi-DConv Head Transposed Self-Attention (MDTA): differs from standard self-attention
+    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
+    convolutions for local mixing before attention, achieving linear complexity vs quadratic
+    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>
+
+    Args:
+        spatial_dims: Number of spatial dimensions (2D or 3D)
+        dim: Number of input channels
+        num_heads: Number of attention heads
+        bias: Whether to use bias in convolution layers
+        flash_attention: Whether to use flash attention optimization. Defaults to False.
+
+    Raises:
+        ValueError: If flash attention is not available in the current PyTorch version
+        ValueError: If spatial_dims is greater than 3
+    """
+
+    def __init__(self, spatial_dims: int, dim: int, num_heads: int, bias: bool, flash_attention: bool = False):
+        super().__init__()
+        if flash_attention and not hasattr(F, "scaled_dot_product_attention"):
+            raise ValueError("Flash attention not available")
+        if spatial_dims > 3:
+            raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}")
+        self.spatial_dims = spatial_dims
+        self.num_heads = num_heads
+        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+        self.flash_attention = flash_attention
+
+        self.qkv = Convolution(
+            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True
+        )
+
+        self.qkv_dwconv = Convolution(
+            spatial_dims=spatial_dims,
+            in_channels=dim * 3,
+            out_channels=dim * 3,
+            kernel_size=3,
+            strides=1,
+            padding=1,
+            groups=dim * 3,
+            bias=bias,
+            conv_only=True,
+        )
+
+        self.project_out = Convolution(
+            spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True
+        )
+
+        self._attention_fn = self._get_attention_fn()
+
+    def _get_attention_fn(self):
+        if self.flash_attention:
+            return self._flash_attention
+        return self._normal_attention
+
+    def _flash_attention(self, q, k, v):
+        """Flash attention implementation using scaled dot-product attention."""
+        scale = float(self.temperature.mean())
+        out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False)
+        return out
+
+    def _normal_attention(self, q, k, v):
+        """Attention matrix multiplication with learnt temperature scaling."""
+        attn = (q @ k.transpose(-2, -1)) * self.temperature
+        attn = attn.softmax(dim=-1)
+        return attn @ v
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass for MDTA attention.
+        1. Apply depth-wise convolutions to Q, K, V
+        2. Reshape Q, K, V for multi-head attention
+        3. Compute attention matrix using flash or normal attention
+        4. Reshape and project out attention output"""
+        spatial_dims = x.shape[2:]
+
+        # Project to Q, K, V and mix local context with the depth-wise convolution
+        qkv = self.qkv_dwconv(self.qkv(x))
+        q, k, v = qkv.chunk(3, dim=1)
+
+        # Select the rearrangement patterns matching the input dimensionality
+        if self.spatial_dims == 2:
+            qkv_to_multihead = "b (head c) h w -> b head c (h w)"
+            multihead_to_qkv = "b head c (h w) -> b (head c) h w"
+        else:  # dims == 3
+            qkv_to_multihead = "b (head c) d h w -> b head c (d h w)"
+            multihead_to_qkv = "b head c (d h w) -> b (head c) d h w"
+
+        # Reshape to multi-head form and normalize along the flattened spatial axis
+        q = rearrange(q, qkv_to_multihead, head=self.num_heads)
+        k = rearrange(k, qkv_to_multihead, head=self.num_heads)
+        v = rearrange(v, qkv_to_multihead, head=self.num_heads)
+
+        q = torch.nn.functional.normalize(q, dim=-1)
+        k = torch.nn.functional.normalize(k, dim=-1)
+
+        # Compute channel attention, reconstruct the spatial layout and project out
+        out = self._attention_fn(q, k, v)
+        out = rearrange(
+            out,
+            multihead_to_qkv,
+            head=self.num_heads,
+            **dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)),
+        )
+
+        return cast(torch.Tensor, self.project_out(out))
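Since `cablock.py` is a newly added module, a smoke test may be useful. This is a minimal sketch assuming `einops` is installed (it is pulled in via `optional_import` above); sizes are illustrative, and both blocks preserve the input shape:

```python
import torch
from monai.networks.blocks.cablock import CABlock, FeedForward

x = torch.randn(1, 32, 24, 24)  # BCHW input for the 2D case

# channel attention: heads attend over the channel axis, not the spatial axes
attn = CABlock(spatial_dims=2, dim=32, num_heads=4, bias=False)
print(attn(x).shape)  # torch.Size([1, 32, 24, 24])

# gated feed-forward: expand to int(dim * factor) * 2 channels, gate, project back
ffn = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=False)
print(ffn(x).shape)  # torch.Size([1, 32, 24, 24])
```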