monai-weekly 1.5.dev2447__py3-none-any.whl → 1.5.dev2449__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -53,6 +53,7 @@ from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
  from .generator import Generator
  from .highresnet import HighResBlock, HighResNet
  from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+ from .masked_autoencoder_vit import MaskedAutoEncoderViT
  from .mednext import (
      MedNeXt,
      MedNext,
@@ -0,0 +1,211 @@
+ # Copyright (c) MONAI Consortium
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from collections.abc import Sequence
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
+ from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
+ from monai.networks.blocks.transformerblock import TransformerBlock
+ from monai.networks.layers import trunc_normal_
+ from monai.utils import ensure_tuple_rep
+ from monai.utils.module import look_up_option
+
+ SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
+
+ __all__ = ["MaskedAutoEncoderViT"]
+
+
+ class MaskedAutoEncoderViT(nn.Module):
+     """
+     Masked Autoencoder (ViT), based on: "He et al.,
+     Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
+     Only a subset of the patches passes through the encoder, which speeds up training.
+     The decoder then tries to reconstruct the masked patches.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         img_size: Sequence[int] | int,
+         patch_size: Sequence[int] | int,
+         hidden_size: int = 768,
+         mlp_dim: int = 512,
+         num_layers: int = 12,
+         num_heads: int = 12,
+         masking_ratio: float = 0.75,
+         decoder_hidden_size: int = 384,
+         decoder_mlp_dim: int = 512,
+         decoder_num_layers: int = 4,
+         decoder_num_heads: int = 12,
+         proj_type: str = "conv",
+         pos_embed_type: str = "sincos",
+         decoder_pos_embed_type: str = "sincos",
+         dropout_rate: float = 0.0,
+         spatial_dims: int = 3,
+         qkv_bias: bool = False,
+         save_attn: bool = False,
+     ) -> None:
+         """
+         Args:
+             in_channels: number of input channels.
+             img_size: dimension of input image.
+             patch_size: dimension of patch size.
+             hidden_size: dimension of hidden layer. Defaults to 768.
+             mlp_dim: dimension of feedforward layer. Defaults to 512.
+             num_layers: number of transformer blocks. Defaults to 12.
+             num_heads: number of attention heads. Defaults to 12.
+             masking_ratio: ratio of patches to be masked. Defaults to 0.75.
+             decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
+             decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
+             decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
+             decoder_num_heads: number of attention heads for decoder. Defaults to 12.
+             proj_type: patch embedding layer type. Defaults to "conv".
+             pos_embed_type: position embedding layer type. Defaults to "sincos".
+             decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
+             dropout_rate: fraction of the input units to drop. Defaults to 0.0.
+             spatial_dims: number of spatial dimensions. Defaults to 3.
+             qkv_bias: apply bias to the qkv linear layer in the self attention block. Defaults to False.
+             save_attn: make the attention in the self attention block accessible. Defaults to False.
+         Examples::
+             # for single channel input with image size of (96,96,96), and sin-cos positional encoding
+             >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
+             pos_embed_type='sincos')
+             # for 3-channel with image size of (128,128,128) and a learnable positional encoding
+             >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
+             # for 3-channel with image size of (224,224) and a masking ratio of 0.25
+             >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
+             spatial_dims=2)
+         """
+
+         super().__init__()
+
+         if not (0 <= dropout_rate <= 1):
+             raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")
+
+         if hidden_size % num_heads != 0:
+             raise ValueError("hidden_size should be divisible by num_heads.")
+
+         if decoder_hidden_size % decoder_num_heads != 0:
+             raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")
+
+         self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
+         self.img_size = ensure_tuple_rep(img_size, spatial_dims)
+         self.spatial_dims = spatial_dims
+         for m, p in zip(self.img_size, self.patch_size):
+             if m % p != 0:
+                 raise ValueError(f"img_size={img_size} should be divisible by patch_size={patch_size}.")
+
+         self.decoder_hidden_size = decoder_hidden_size
+
+         if masking_ratio <= 0 or masking_ratio >= 1:
+             raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")
+
+         self.masking_ratio = masking_ratio
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
+
+         self.patch_embedding = PatchEmbeddingBlock(
+             in_channels=in_channels,
+             img_size=img_size,
+             patch_size=patch_size,
+             hidden_size=hidden_size,
+             num_heads=num_heads,
+             proj_type=proj_type,
+             pos_embed_type=pos_embed_type,
+             dropout_rate=dropout_rate,
+             spatial_dims=self.spatial_dims,
+         )
+         blocks = [
+             TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
+             for _ in range(num_layers)
+         ]
+         self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))
+
+         # decoder
+         self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)
+
+         self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
+
+         self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
+         self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))
+
+         decoder_blocks = [
+             TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
+             for _ in range(decoder_num_layers)
+         ]
+         self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
+         self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)
+
+         self._init_weights()
+
+     def _init_weights(self):
+         """
+         Initialize the decoder positional encoding and the mask and classification tokens,
+         similar to monai/networks/blocks/patchembedding.py.
+         """
+         if self.decoder_pos_embed_type == "none":
+             pass
+         elif self.decoder_pos_embed_type == "learnable":
+             trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
+         elif self.decoder_pos_embed_type == "sincos":
+             grid_size = []
+             for in_size, pa_size in zip(self.img_size, self.patch_size):
+                 grid_size.append(in_size // pa_size)
+
+             self.decoder_pos_embedding = build_sincos_position_embedding(
+                 grid_size, self.decoder_hidden_size, self.spatial_dims
+             )
+
+         else:
+             raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")
+
+         # initialize the mask and classification tokens
+         trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
+         trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)
+
+     def _masking(self, x, masking_ratio: float | None = None):
+         batch_size, num_tokens, _ = x.shape
+         percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
+         selected_indices = torch.multinomial(
+             torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
+         )
+         x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices]  # gather the selected tokens
+         mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
+         mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0
+
+         return x_masked, selected_indices, mask
+
+     def forward(self, x, masking_ratio: float | None = None):
+         x = self.patch_embedding(x)
+         x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)
+
+         cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         x = self.blocks(x)
+
+         # decoder
+         x = self.decoder_embed(x)
+
+         x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)
+         x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :]  # no cls token
+         x_ = x_ + self.decoder_pos_embedding
+         x = torch.cat([x[:, :1, :], x_], dim=1)
+         x = self.decoder_blocks(x)
+         x = self.decoder_pred(x)
+
+         x = x[:, 1:, :]
+         return x, mask
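For orientation, a minimal usage sketch of the new network (not part of the package itself); the shapes follow from the code above with default settings:

```python
import torch

from monai.networks.nets import MaskedAutoEncoderViT

net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))

x = torch.randn(2, 1, 96, 96, 96)  # (batch, channels, *spatial)
pred, mask = net(x)

# 96/16 = 6 patches per axis -> 6**3 = 216 patches, each predicted as 16**3 voxels * 1 channel
assert pred.shape == (2, 216, 4096)
# mask is 1 for masked patches and 0 for the patches the encoder actually saw
assert mask.shape == (2, 216)
```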
@@ -13,7 +13,6 @@ from __future__ import annotations

  import itertools
  from collections.abc import Sequence
- from typing import Final

  import numpy as np
  import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
      <https://arxiv.org/abs/2201.01266>"
      """

-     patch_size: Final[int] = 2
-
      @deprecated_arg(
          name="img_size",
          since="1.3",
@@ -65,18 +62,24 @@ class SwinUNETR(nn.Module):
          img_size: Sequence[int] | int,
          in_channels: int,
          out_channels: int,
+         patch_size: int = 2,
          depths: Sequence[int] = (2, 2, 2, 2),
          num_heads: Sequence[int] = (3, 6, 12, 24),
+         window_size: Sequence[int] | int = 7,
+         qkv_bias: bool = True,
+         mlp_ratio: float = 4.0,
          feature_size: int = 24,
          norm_name: tuple | str = "instance",
          drop_rate: float = 0.0,
          attn_drop_rate: float = 0.0,
          dropout_path_rate: float = 0.0,
          normalize: bool = True,
+         norm_layer: type[LayerNorm] = nn.LayerNorm,
+         patch_norm: bool = False,
          use_checkpoint: bool = False,
          spatial_dims: int = 3,
-         downsample="merging",
-         use_v2=False,
+         downsample: str | nn.Module = "merging",
+         use_v2: bool = False,
      ) -> None:
          """
          Args:
@@ -86,14 +89,20 @@ class SwinUNETR(nn.Module):
                  It will be removed in an upcoming version.
              in_channels: dimension of input channels.
              out_channels: dimension of output channels.
+             patch_size: size of the patch token.
              feature_size: dimension of network feature size.
              depths: number of layers in each stage.
              num_heads: number of attention heads.
+             window_size: local window size.
+             qkv_bias: add a learnable bias to query, key, value.
+             mlp_ratio: ratio of mlp hidden dim to embedding dim.
              norm_name: feature normalization type and arguments.
              drop_rate: dropout rate.
              attn_drop_rate: attention dropout rate.
              dropout_path_rate: drop path rate.
              normalize: normalize output intermediate features in each stage.
+             norm_layer: normalization layer.
+             patch_norm: whether to apply normalization to the patch embedding. Default is False.
              use_checkpoint: use gradient checkpointing for reduced memory usage.
              spatial_dims: number of spatial dims.
              downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@ class SwinUNETR(nn.Module):

          super().__init__()

-         img_size = ensure_tuple_rep(img_size, spatial_dims)
-         patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
-         window_size = ensure_tuple_rep(7, spatial_dims)
-
          if spatial_dims not in (2, 3):
              raise ValueError("spatial dimension should be 2 or 3.")

+         self.patch_size = patch_size
+
+         img_size = ensure_tuple_rep(img_size, spatial_dims)
+         patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
+         window_size = ensure_tuple_rep(window_size, spatial_dims)
+
          self._check_input_size(img_size)

          if not (0 <= drop_rate <= 1):
@@ -146,12 +157,13 @@ class SwinUNETR(nn.Module):
              patch_size=patch_sizes,
              depths=depths,
              num_heads=num_heads,
-             mlp_ratio=4.0,
-             qkv_bias=True,
+             mlp_ratio=mlp_ratio,
+             qkv_bias=qkv_bias,
              drop_rate=drop_rate,
              attn_drop_rate=attn_drop_rate,
              drop_path_rate=dropout_path_rate,
-             norm_layer=nn.LayerNorm,
+             norm_layer=norm_layer,
+             patch_norm=patch_norm,
              use_checkpoint=use_checkpoint,
              spatial_dims=spatial_dims,
              downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
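A brief sketch of how the newly exposed constructor arguments might be used (illustrative values, not part of the diff); the values below simply mirror the previously hard-coded defaults:

```python
import torch

from monai.networks.nets import SwinUNETR

net = SwinUNETR(
    img_size=(96, 96, 96),  # deprecated since 1.3 but still present in this signature
    in_channels=1,
    out_channels=2,
    patch_size=2,           # was a fixed class attribute (Final[int] = 2)
    window_size=7,          # was hard-coded via ensure_tuple_rep(7, spatial_dims)
    qkv_bias=True,
    mlp_ratio=4.0,
    patch_norm=False,
    spatial_dims=3,
)
logits = net(torch.randn(1, 1, 96, 96, 96))  # expected output shape: (1, 2, 96, 96, 96)
```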
@@ -531,6 +531,8 @@ from .utility.array import (
      RandIdentity,
      RandImageFilter,
      RandLambda,
+     RandTorchIO,
+     RandTorchVision,
      RemoveRepeatedChannel,
      RepeatChannel,
      SimulateDelay,
@@ -540,6 +542,7 @@ from .utility.array import (
      ToDevice,
      ToNumpy,
      ToPIL,
+     TorchIO,
      TorchVision,
      ToTensor,
      Transpose,
@@ -620,6 +623,9 @@ from .utility.dictionary import (
      RandLambdad,
      RandLambdaD,
      RandLambdaDict,
+     RandTorchIOd,
+     RandTorchIOD,
+     RandTorchIODict,
      RandTorchVisiond,
      RandTorchVisionD,
      RandTorchVisionDict,
@@ -653,6 +659,9 @@ from .utility.dictionary import (
      ToPILd,
      ToPILD,
      ToPILDict,
+     TorchIOd,
+     TorchIOD,
+     TorchIODict,
      TorchVisiond,
      TorchVisionD,
      TorchVisionDict,
@@ -18,10 +18,10 @@ import logging
  import sys
  import time
  import warnings
- from collections.abc import Mapping, Sequence
+ from collections.abc import Hashable, Mapping, Sequence
  from copy import deepcopy
  from functools import partial
- from typing import Any, Callable
+ from typing import Any, Callable, Union

  import numpy as np
  import torch
@@ -99,11 +99,14 @@ __all__ = [
      "ConvertToMultiChannelBasedOnBratsClasses",
      "AddExtremePointsChannel",
      "TorchVision",
+     "TorchIO",
      "MapLabelValue",
      "IntensityStats",
      "ToDevice",
      "CuCIM",
      "RandCuCIM",
+     "RandTorchIO",
+     "RandTorchVision",
      "ToCupy",
      "ImageFilter",
      "RandImageFilter",
@@ -1051,12 +1054,11 @@ class ClassesToIndices(Transform, MultiSampleTrait):

  class ConvertToMultiChannelBasedOnBratsClasses(Transform):
      """
-     Convert labels to multi channels based on brats18 classes:
-     label 1 is the necrotic and non-enhancing tumor core
-     label 2 is the peritumoral edema
-     label 4 is the GD-enhancing tumor
-     The possible classes are TC (Tumor core), WT (Whole tumor)
-     and ET (Enhancing tumor).
+     Convert labels to multi channels based on `brats18 <https://www.med.upenn.edu/sbia/brats2018/data.html>`_ classes,
+     which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor):
+     label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,
+     label 2 is the peritumoral edema, which is counted only under WT subregion,
+     label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
      """

      backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
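The revised docstring spells out the label-to-subregion mapping; a small sketch of what it implies for a toy label patch (plain tensor operations, not the transform's own implementation):

```python
import torch

label = torch.tensor([[0, 1], [2, 4]])           # toy 2x2 patch with BraTS-style labels

tc = (label == 1) | (label == 4)                 # Tumor core: labels 1 and 4
wt = (label == 1) | (label == 2) | (label == 4)  # Whole tumor: labels 1, 2 and 4
et = label == 4                                  # Enhancing tumor: label 4 only

multi_channel = torch.stack([tc, wt, et]).float()  # (3, 2, 2): one channel per subregion
```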
@@ -1136,12 +1138,44 @@ class AddExtremePointsChannel(Randomizable, Transform):
          return concatenate((img, points_image), axis=0)


- class TorchVision:
+ class TorchVision(Transform):
      """
-     This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
-     As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
-     data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.
+     This is a wrapper transform for non-randomized PyTorch TorchVision transforms, based on the specified transform name and args.
+     Data is converted to a torch.Tensor before applying the transform and then converted back to the original data type.
+     """
+
+     backend = [TransformBackends.TORCH]
+
+     def __init__(self, name: str, *args, **kwargs) -> None:
+         """
+         Args:
+             name: The transform name in TorchVision package.
+             args: parameters for the TorchVision transform.
+             kwargs: parameters for the TorchVision transform.
+
+         """
+         super().__init__()
+         self.name = name
+         transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
+         self.trans = transform(*args, **kwargs)
+
+     def __call__(self, img: NdarrayOrTensor):
+         """
+         Args:
+             img: PyTorch Tensor data for the TorchVision transform.

+         """
+         img_t, *_ = convert_data_type(img, torch.Tensor)
+
+         out = self.trans(img_t)
+         out, *_ = convert_to_dst_type(src=out, dst=img)
+         return out
+
+
+ class RandTorchVision(Transform, RandomizableTrait):
+     """
+     This is a wrapper transform for randomized PyTorch TorchVision transforms, based on the specified transform name and args.
+     Data is converted to a torch.Tensor before applying the transform and then converted back to the original data type.
      """

      backend = [TransformBackends.TORCH]
@@ -1172,6 +1206,68 @@ class TorchVision:
          return out


+ class TorchIO(Transform):
+     """
+     This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args.
+     See https://torchio.readthedocs.io/transforms/transforms.html for more details.
+     """
+
+     backend = [TransformBackends.TORCH]
+
+     def __init__(self, name: str, *args, **kwargs) -> None:
+         """
+         Args:
+             name: The transform name in TorchIO package.
+             args: parameters for the TorchIO transform.
+             kwargs: parameters for the TorchIO transform.
+         """
+         super().__init__()
+         self.name = name
+         transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
+         self.trans = transform(*args, **kwargs)
+
+     def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
+         """
+         Args:
+             img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
+                 or dict containing 4D tensors as values
+
+         """
+         return self.trans(img)
+
+
+ class RandTorchIO(Transform, RandomizableTrait):
+     """
+     This is a wrapper for TorchIO randomized transforms based on the specified transform name and args.
+     See https://torchio.readthedocs.io/transforms/transforms.html for more details.
+     Use this wrapper for all TorchIO transforms inheriting from RandomTransform:
+     https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform
+     """
+
+     backend = [TransformBackends.TORCH]
+
+     def __init__(self, name: str, *args, **kwargs) -> None:
+         """
+         Args:
+             name: The transform name in TorchIO package.
+             args: parameters for the TorchIO transform.
+             kwargs: parameters for the TorchIO transform.
+         """
+         super().__init__()
+         self.name = name
+         transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
+         self.trans = transform(*args, **kwargs)
+
+     def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
+         """
+         Args:
+             img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
+                 or dict containing 4D tensors as values
+
+         """
+         return self.trans(img)
+
+
  class MapLabelValue:
      """
      Utility to map label values to another set of values.
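A short usage sketch for the new wrappers (not part of the diff), assuming the optional `torchio` dependency is installed; the transform names and arguments below belong to TorchIO itself and are only illustrative:

```python
import torch

from monai.transforms import RandTorchIO, TorchIO

img = torch.rand(1, 32, 32, 32)  # TorchIO expects channel-first 4D tensors

rescale = TorchIO(name="RescaleIntensity", out_min_max=(0.0, 1.0))  # non-randomized wrapper
add_noise = RandTorchIO(name="RandomNoise", std=(0.0, 0.1))         # randomized wrapper

out = add_noise(rescale(img))
```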
@@ -60,6 +60,7 @@ from monai.transforms.utility.array import (
      ToDevice,
      ToNumpy,
      ToPIL,
+     TorchIO,
      TorchVision,
      ToTensor,
      Transpose,
@@ -136,6 +137,9 @@ __all__ = [
      "RandLambdaD",
      "RandLambdaDict",
      "RandLambdad",
+     "RandTorchIOd",
+     "RandTorchIOD",
+     "RandTorchIODict",
      "RandTorchVisionD",
      "RandTorchVisionDict",
      "RandTorchVisiond",
@@ -172,6 +176,9 @@ __all__ = [
      "ToTensorD",
      "ToTensorDict",
      "ToTensord",
+     "TorchIOD",
+     "TorchIODict",
+     "TorchIOd",
      "TorchVisionD",
      "TorchVisionDict",
      "TorchVisiond",
@@ -1445,6 +1452,64 @@ class RandTorchVisiond(MapTransform, RandomizableTrait):
          return d


+ class TorchIOd(MapTransform):
+     """
+     Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms.
+     For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`.
+     """
+
+     backend = TorchIO.backend
+
+     def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
+         """
+         Args:
+             keys: keys of the corresponding items to be transformed.
+                 See also: :py:class:`monai.transforms.compose.MapTransform`
+             name: The transform name in TorchIO package.
+             allow_missing_keys: don't raise exception if key is missing.
+             args: parameters for the TorchIO transform.
+             kwargs: parameters for the TorchIO transform.
+
+         """
+         super().__init__(keys, allow_missing_keys)
+         self.name = name
+         kwargs["include"] = self.keys
+
+         self.trans = TorchIO(name, *args, **kwargs)
+
+     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
+         return dict(self.trans(data))
+
+
+ class RandTorchIOd(MapTransform, RandomizableTrait):
+     """
+     Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms.
+     For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`.
+     """
+
+     backend = TorchIO.backend
+
+     def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
+         """
+         Args:
+             keys: keys of the corresponding items to be transformed.
+                 See also: :py:class:`monai.transforms.compose.MapTransform`
+             name: The transform name in TorchIO package.
+             allow_missing_keys: don't raise exception if key is missing.
+             args: parameters for the TorchIO transform.
+             kwargs: parameters for the TorchIO transform.
+
+         """
+         super().__init__(keys, allow_missing_keys)
+         self.name = name
+         kwargs["include"] = self.keys
+
+         self.trans = TorchIO(name, *args, **kwargs)
+
+     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
+         return dict(self.trans(data))
+
+
  class MapLabelValued(MapTransform):
      """
      Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
@@ -1871,8 +1936,10 @@ ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsCla
      ConvertToMultiChannelBasedOnBratsClassesd
  )
  AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
+ TorchIOD = TorchIODict = TorchIOd
  TorchVisionD = TorchVisionDict = TorchVisiond
  RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
+ RandTorchIOD = RandTorchIODict = RandTorchIOd
  RandLambdaD = RandLambdaDict = RandLambdad
  MapLabelValueD = MapLabelValueDict = MapLabelValued
  IntensityStatsD = IntensityStatsDict = IntensityStatsd
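A corresponding sketch for the dictionary-based wrappers (not part of the diff); the wrapper forwards `keys` to TorchIO's `include` argument, so only the listed entries are transformed, and the same sampled parameters apply to all of them. The transform name and arguments below are TorchIO's own, chosen for illustration:

```python
import torch

from monai.transforms import RandTorchIOd

data = {
    "image": torch.rand(1, 32, 32, 32),
    "label": torch.randint(0, 2, (1, 32, 32, 32)),
}

flip = RandTorchIOd(keys=["image", "label"], name="RandomFlip", axes=(0, 1, 2))
out = flip(data)  # dict with the same keys, both entries flipped consistently
```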
monai/utils/module.py CHANGED
@@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
          current_ver_string: if None, the current system GPU CUDA compute capability will be used.

      Returns:
-         True if the current system GPU CUDA compute capability is greater than the specified version.
+         True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
      """
      if current_ver_string is None:
          cuda_available = torch.cuda.is_available()
@@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s

      ver, has_ver = optional_import("packaging.version", name="parse")
      if has_ver:
-         return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}")  # type: ignore
+         return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
      parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
      while len(parts) < 2:
          parts += ["0"]
      c_major, c_minor = parts[:2]
      c_mn = int(c_major), int(c_minor)
      mn = int(major), int(minor)
-     return c_mn >= mn
+     return c_mn > mn
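A small illustration of the updated comparison, assuming the `packaging` module is available so the first branch above is taken; an exact capability match now satisfies the check:

```python
from monai.utils.module import compute_capabilities_after

compute_capabilities_after(8, 6, "8.6")  # True: capability 8.6 is >= the requested 8.6
compute_capabilities_after(9, 0, "8.6")  # False: capability 8.6 is below 9.0
```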
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: monai-weekly
- Version: 1.5.dev2447
+ Version: 1.5.dev2449
  Summary: AI Toolkit for Healthcare Imaging
  Home-page: https://monai.io/
  Author: MONAI Consortium
@@ -40,6 +40,7 @@ Requires-Dist: pillow; extra == "all"
  Requires-Dist: tensorboard; extra == "all"
  Requires-Dist: gdown>=4.7.3; extra == "all"
  Requires-Dist: pytorch-ignite==0.4.11; extra == "all"
+ Requires-Dist: torchio; extra == "all"
  Requires-Dist: torchvision; extra == "all"
  Requires-Dist: itk>=5.2; extra == "all"
  Requires-Dist: tqdm>=4.47.0; extra == "all"
@@ -87,6 +88,8 @@ Provides-Extra: gdown
  Requires-Dist: gdown>=4.7.3; extra == "gdown"
  Provides-Extra: ignite
  Requires-Dist: pytorch-ignite==0.4.11; extra == "ignite"
+ Provides-Extra: torchio
+ Requires-Dist: torchio; extra == "torchio"
  Provides-Extra: torchvision
  Requires-Dist: torchvision; extra == "torchvision"
  Provides-Extra: itk