monai-weekly 1.5.dev2447__py3-none-any.whl → 1.5.dev2448__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "b1e915c323a8065cfe9e92de3013476f2f67c1b2"
+__commit_id__ = "44e249d7d492d858199acfca1c948faa5aa33763"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-11-24T02:30:43+0000",
+ "date": "2024-12-01T02:35:43+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "2d6751b01bf78fddabe03b2c53645c6bc9808ed8",
- "version": "1.5.dev2447"
+ "full-revisionid": "d4ff1455cf46b35e4dcfb6f57d54b0738b39f738",
+ "version": "1.5.dev2448"
 }
 '''  # END VERSION_JSON
 
monai/bundle/__init__.py CHANGED
@@ -43,4 +43,4 @@ from .utils import (
     MACRO_KEY,
     load_bundle_config,
 )
-from .workflows import BundleWorkflow, ConfigWorkflow
+from .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow
monai/bundle/reference_resolver.py CHANGED
@@ -192,6 +192,16 @@ class ReferenceResolver:
         """
         return self._resolve_one_item(id=id, **kwargs)
 
+    def remove_resolved_content(self, id: str) -> Any | None:
+        """
+        Remove the resolved ``ConfigItem`` by id.
+
+        Args:
+            id: id name of the expected item.
+
+        """
+        return self.resolved_content.pop(id) if id in self.resolved_content else None
+
     @classmethod
     def normalize_id(cls, id: str | int) -> str:
         """
monai/bundle/workflows.py CHANGED
@@ -44,12 +44,18 @@ class BundleWorkflow(ABC):
         workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
-            default to `train` for train workflow.
+            default to `None` for only using meta properties.
         workflow: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
         logging_file: config file for `logging` module in the program. for more details:
             https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
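
Note: a minimal sketch of a custom properties file that fits this loading scheme, with top-level keys named after the workflow type plus an optional "meta" section. The property names below are illustrative, not taken from any bundle.

    import json

    properties = {
        "train": {"bundle_root": {"description": "root path of the bundle.", "required": True, "id": "bundle_root"}},
        "meta": {"version": {"description": "bundle version.", "required": True, "id": "_meta_::version"}},
    }
    with open("custom_properties.json", "w") as f:
        json.dump(properties, f)

    # A workflow created with workflow_type="train" and properties_path="custom_properties.json"
    # would then merge the "train" and "meta" sections, per the loading logic below.
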
@@ -97,29 +103,50 @@ class BundleWorkflow(ABC):
             meta_file = None
 
         workflow_type = workflow if workflow is not None else workflow_type
-        if workflow_type is None and properties_path is None:
-            self.properties = copy(MetaProperties)
-            self.workflow_type = None
-            self.meta_file = meta_file
-            return
+        if workflow_type is not None:
+            if workflow_type.lower() in self.supported_train_type:
+                workflow_type = "train"
+            elif workflow_type.lower() in self.supported_infer_type:
+                workflow_type = "infer"
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+
         if properties_path is not None:
             properties_path = Path(properties_path)
             if not properties_path.is_file():
                 raise ValueError(f"Property file {properties_path} does not exist.")
             with open(properties_path) as json_file:
-                self.properties = json.load(json_file)
-            self.workflow_type = None
-            self.meta_file = meta_file
-            return
-        if workflow_type.lower() in self.supported_train_type:  # type: ignore[union-attr]
-            self.properties = {**TrainProperties, **MetaProperties}
-            self.workflow_type = "train"
-        elif workflow_type.lower() in self.supported_infer_type:  # type: ignore[union-attr]
-            self.properties = {**InferProperties, **MetaProperties}
-            self.workflow_type = "infer"
+                try:
+                    properties = json.load(json_file)
+                    self.properties: dict = {}
+                    if workflow_type is not None and workflow_type in properties:
+                        self.properties = properties[workflow_type]
+                        if "meta" in properties:
+                            self.properties.update(properties["meta"])
+                    elif workflow_type is None:
+                        if "meta" in properties:
+                            self.properties = properties["meta"]
+                            logger.info(
+                                "No workflow type specified, default to load meta properties from property file."
+                            )
+                        else:
+                            logger.warning("No 'meta' key found in properties while workflow_type is None.")
+                except KeyError as e:
+                    raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"Error decoding JSON from property file {properties_path}") from e
         else:
-            raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+            if workflow_type == "train":
+                self.properties = {**TrainProperties, **MetaProperties}
+            elif workflow_type == "infer":
+                self.properties = {**InferProperties, **MetaProperties}
+            elif workflow_type is None:
+                self.properties = copy(MetaProperties)
+                logger.info("No workflow type and property file specified, default to 'meta' properties.")
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
 
+        self.workflow_type = workflow_type
         self.meta_file = meta_file
 
     @abstractmethod
@@ -226,6 +253,124 @@ class BundleWorkflow(ABC):
         return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)]
 
 
+class PythonicWorkflow(BundleWorkflow):
+    """
+    Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow.
+    It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc.
+    This also provides the interface to get / set public properties to interact with a bundle workflow through
+    defined `get_<property>` accessor methods or directly defining members of the object.
+    For how to set the properties, users can define the `_set_<property>` methods or directly set the members of the object.
+    The `initialize` method is called to set up the workflow before running. This method sets up internal state
+    and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized`
+    is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is
+    properly set up with the new property values.
+
+    Args:
+        workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for only using meta properties.
+        workflow: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for common workflow.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
+        config_file: path to the config file, typically used to store hyperparameters.
+        meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
+        logging_file: config file for `logging` module in the program. for more details:
+            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
+
+    """
+
+    supported_train_type: tuple = ("train", "training")
+    supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")
+
+    def __init__(
+        self,
+        workflow_type: str | None = None,
+        properties_path: PathLike | None = None,
+        config_file: str | Sequence[str] | None = None,
+        meta_file: str | Sequence[str] | None = None,
+        logging_file: str | None = None,
+        **override: Any,
+    ):
+        meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file
+        super().__init__(
+            workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file
+        )
+        self._props_vals: dict = {}
+        self._set_props_vals: dict = {}
+        self.parser = ConfigParser()
+        if config_file is not None:
+            self.parser.read_config(f=config_file)
+        if self.meta_file is not None:
+            self.parser.read_meta(f=self.meta_file)
+
+        # the rest key-values in the _args are to override config content
+        self.parser.update(pairs=override)
+        self._is_initialized: bool = False
+
+    def initialize(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Initialize the bundle workflow before running.
+        """
+        self._props_vals = {}
+        self._is_initialized = True
+
+    def _get_property(self, name: str, property: dict) -> Any:
+        """
+        With specified property name and information, get the expected property value.
+        If the property is already generated, return from the bucket directly.
+        If user explicitly set the property, return it directly.
+        Otherwise, generate the expected property as a class private property with prefix "_".
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+        """
+        if not self._is_initialized:
+            raise RuntimeError("Please execute 'initialize' before getting any properties.")
+        value = None
+        if name in self._set_props_vals:
+            value = self._set_props_vals[name]
+        elif name in self._props_vals:
+            value = self._props_vals[name]
+        elif name in self.parser.config[self.parser.meta_key]:  # type: ignore[index]
+            id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None)
+            value = self.parser[id]
+        else:
+            try:
+                value = getattr(self, f"get_{name}")()
+            except AttributeError as e:
+                if property[BundleProperty.REQUIRED]:
+                    raise ValueError(
+                        f"unsupported property '{name}' is required in the bundle properties,"
+                        f"need to implement a method 'get_{name}' to provide the property."
+                    ) from e
+        self._props_vals[name] = value
+        return value
+
+    def _set_property(self, name: str, property: dict, value: Any) -> Any:
+        """
+        With specified property name and information, set value for the expected property.
+        Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+            value: value to set for the property.
+
+        """
+        self._set_props_vals[name] = value
+        self._is_initialized = False
+
+
 class ConfigWorkflow(BundleWorkflow):
     """
     Specification for the config-based bundle workflow.
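
Note: a minimal sketch of subclassing the new PythonicWorkflow (the class name and accessor below are illustrative): required properties are served by `get_<property>` methods and cached until the next `initialize` call.

    import torch
    from monai.bundle import PythonicWorkflow

    class MyInferWorkflow(PythonicWorkflow):
        def get_device(self):  # serves the "device" property from InferProperties
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def run(self):
            pass  # the actual inference loop would go here

        def finalize(self):
            pass

    workflow = MyInferWorkflow(workflow_type="infer")
    workflow.initialize()
    print(workflow.device)  # resolved via get_device() and cached in _props_vals
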
@@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow):
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "train". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         override: id-value pairs to override or add the corresponding config content.
             e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
 
@@ -324,7 +475,6 @@ class ConfigWorkflow(BundleWorkflow):
             self.parser.read_config(f=config_file)
         if self.meta_file is not None:
             self.parser.read_meta(f=self.meta_file)
-
         # the rest key-values in the _args are to override config content
         self.parser.update(pairs=override)
         self.init_id = init_id
@@ -394,8 +544,23 @@ class ConfigWorkflow(BundleWorkflow):
             ret.extend(wrong_props)
         return ret
 
-    def _run_expr(self, id: str, **kwargs: dict) -> Any:
-        return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
+    def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
+        """
+        Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
+        allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
+        """
+        ret = []
+        if id in self.parser:
+            # suppose all the expressions are in a list, run and reset the expressions
+            if isinstance(self.parser[id], list):
+                for i in range(len(self.parser[id])):
+                    sub_id = f"{id}{ID_SEP_KEY}{i}"
+                    ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
+                    self.parser.ref_resolver.remove_resolved_content(sub_id)
+            else:
+                ret.append(self.parser.get_parsed_content(id, **kwargs))
+                self.parser.ref_resolver.remove_resolved_content(id)
+        return ret
 
     def _get_prop_id(self, name: str, property: dict) -> Any:
         prop_id = property[BundlePropertyConfig.ID]
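
Note: in practice this means a ConfigWorkflow can be run repeatedly within one process, since run() evaluates the "run" expressions through _run_expr and each parsed result is evicted from the resolver cache (via the new remove_resolved_content above) right after evaluation. A rough sketch, with an illustrative config path:

    from monai.bundle import ConfigWorkflow

    workflow = ConfigWorkflow(config_file="configs/inference.json", workflow_type="infer")
    workflow.initialize()
    workflow.run()  # first pass
    workflow.run()  # re-evaluates the expressions instead of returning cached results
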
monai/networks/blocks/selfattention.py CHANGED
@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ class SABlock(nn.Module):
         )
         self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ class SABlock(nn.Module):
 
         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@ class SABlock(nn.Module):
             att_mat = self.rel_positional_embedding(x, att_mat, q)
 
         if self.causal:
+            if attn_mask is not None:
+                raise ValueError("Causal attention does not support attention masks.")
             att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-        att_mat = att_mat.softmax(dim=-1)
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+            attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+            att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
+        att_mat = att_mat.softmax(dim=-1)
         if self.save_attn:
             # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
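
Note: a small sketch of the new mask argument (shapes follow the docstring above; entries equal to 0 are excluded from attention). The TransformerBlock change below threads the same attn_mask through to its SABlock.

    import torch
    from monai.networks.blocks.selfattention import SABlock

    block = SABlock(hidden_size=128, num_heads=4)
    x = torch.randn(2, 64, 128)     # B x 64 tokens x C
    mask = torch.ones(2, 64)
    mask[:, 32:] = 0                # hide the second half of the sequence
    out = block(x, attn_mask=mask)  # B x 64 x 128
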
monai/networks/blocks/transformerblock.py CHANGED
@@ -90,8 +90,10 @@ class TransformerBlock(nn.Module):
             use_flash_attention=use_flash_attention,
         )
 
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
monai/networks/nets/__init__.py CHANGED
@@ -53,6 +53,7 @@ from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
 from .generator import Generator
 from .highresnet import HighResBlock, HighResNet
 from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+from .masked_autoencoder_vit import MaskedAutoEncoderViT
 from .mednext import (
     MedNeXt,
     MedNext,
monai/networks/nets/masked_autoencoder_vit.py ADDED
@@ -0,0 +1,211 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
+from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
+from monai.networks.blocks.transformerblock import TransformerBlock
+from monai.networks.layers import trunc_normal_
+from monai.utils import ensure_tuple_rep
+from monai.utils.module import look_up_option
+
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
+
+__all__ = ["MaskedAutoEncoderViT"]
+
+
+class MaskedAutoEncoderViT(nn.Module):
+    """
+    Masked Autoencoder (ViT), based on: "Kaiming et al.,
+    Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
+    Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
+    the masked patches, resulting in improved training speed.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        img_size: Sequence[int] | int,
+        patch_size: Sequence[int] | int,
+        hidden_size: int = 768,
+        mlp_dim: int = 512,
+        num_layers: int = 12,
+        num_heads: int = 12,
+        masking_ratio: float = 0.75,
+        decoder_hidden_size: int = 384,
+        decoder_mlp_dim: int = 512,
+        decoder_num_layers: int = 4,
+        decoder_num_heads: int = 12,
+        proj_type: str = "conv",
+        pos_embed_type: str = "sincos",
+        decoder_pos_embed_type: str = "sincos",
+        dropout_rate: float = 0.0,
+        spatial_dims: int = 3,
+        qkv_bias: bool = False,
+        save_attn: bool = False,
+    ) -> None:
+        """
+        Args:
+            in_channels: dimension of input channels or the number of channels for input.
+            img_size: dimension of input image.
+            patch_size: dimension of patch size
+            hidden_size: dimension of hidden layer. Defaults to 768.
+            mlp_dim: dimension of feedforward layer. Defaults to 512.
+            num_layers: number of transformer blocks. Defaults to 12.
+            num_heads: number of attention heads. Defaults to 12.
+            masking_ratio: ratio of patches to be masked. Defaults to 0.75.
+            decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
+            decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
+            decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
+            decoder_num_heads: number of attention heads for decoder. Defaults to 12.
+            proj_type: position embedding layer type. Defaults to "conv".
+            pos_embed_type: position embedding layer type. Defaults to "sincos".
+            decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
+            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
+            spatial_dims: number of spatial dimensions. Defaults to 3.
+            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
+            save_attn: to make accessible the attention in self attention block. Defaults to False.
+        Examples::
+            # for single channel input with image size of (96,96,96), and sin-cos positional encoding
+            >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
+            pos_embed_type='sincos')
+            # for 3-channel with image size of (128,128,128) and a learnable positional encoding
+            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
+            # for 3-channel with image size of (224,224) and a masking ratio of 0.25
+            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
+            spatial_dims=2)
+        """
+
+        super().__init__()
+
+        if not (0 <= dropout_rate <= 1):
+            raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")
+
+        if hidden_size % num_heads != 0:
+            raise ValueError("hidden_size should be divisible by num_heads.")
+
+        if decoder_hidden_size % decoder_num_heads != 0:
+            raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")
+
+        self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
+        self.img_size = ensure_tuple_rep(img_size, spatial_dims)
+        self.spatial_dims = spatial_dims
+        for m, p in zip(self.img_size, self.patch_size):
+            if m % p != 0:
+                raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.")
+
+        self.decoder_hidden_size = decoder_hidden_size
+
+        if masking_ratio <= 0 or masking_ratio >= 1:
+            raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")
+
+        self.masking_ratio = masking_ratio
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
+
+        self.patch_embedding = PatchEmbeddingBlock(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=patch_size,
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            proj_type=proj_type,
+            pos_embed_type=pos_embed_type,
+            dropout_rate=dropout_rate,
+            spatial_dims=self.spatial_dims,
+        )
+        blocks = [
+            TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
+            for _ in range(num_layers)
+        ]
+        self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))
+
+        # decoder
+        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)
+
+        self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
+
+        self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
+        self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))
+
+        decoder_blocks = [
+            TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
+            for _ in range(decoder_num_layers)
+        ]
+        self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
+        self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        """
+        similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for mask and
+        classification tokens
+        """
+        if self.decoder_pos_embed_type == "none":
+            pass
+        elif self.decoder_pos_embed_type == "learnable":
+            trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
+        elif self.decoder_pos_embed_type == "sincos":
+            grid_size = []
+            for in_size, pa_size in zip(self.img_size, self.patch_size):
+                grid_size.append(in_size // pa_size)
+
+            self.decoder_pos_embedding = build_sincos_position_embedding(
+                grid_size, self.decoder_hidden_size, self.spatial_dims
+            )
+
+        else:
+            raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")
+
+        # initialize patch_embedding like nn.Linear (instead of nn.Conv2d)
+        trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
+        trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)
+
+    def _masking(self, x, masking_ratio: float | None = None):
+        batch_size, num_tokens, _ = x.shape
+        percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
+        selected_indices = torch.multinomial(
+            torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
+        )
+        x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices]  # gather the selected tokens
+        mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
+        mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0
+
+        return x_masked, selected_indices, mask
+
+    def forward(self, x, masking_ratio: float | None = None):
+        x = self.patch_embedding(x)
+        x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)
+
+        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        x = self.blocks(x)
+
+        # decoder
+        x = self.decoder_embed(x)
+
+        x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)
+        x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :]  # no cls token
+        x_ = x_ + self.decoder_pos_embedding
+        x = torch.cat([x[:, :1, :], x_], dim=1)
+        x = self.decoder_blocks(x)
+        x = self.decoder_pred(x)
+
+        x = x[:, 1:, :]
+        return x, mask
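
Note: a quick shape-check sketch for the new network (default hyperparameters, single-channel 3D input): the forward pass returns the per-patch reconstructions and the binary mask of hidden patches.

    import torch
    from monai.networks.nets import MaskedAutoEncoderViT

    net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
    pred, mask = net(torch.randn(2, 1, 96, 96, 96))
    print(pred.shape)  # torch.Size([2, 216, 4096]): 6*6*6 patches, 16*16*16 voxels each
    print(mask.shape)  # torch.Size([2, 216]); 1 marks a patch hidden from the encoder
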
monai/networks/nets/swin_unetr.py CHANGED
@@ -13,7 +13,6 @@ from __future__ import annotations
 
 import itertools
 from collections.abc import Sequence
-from typing import Final
 
 import numpy as np
 import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
     <https://arxiv.org/abs/2201.01266>"
     """
 
-    patch_size: Final[int] = 2
-
     @deprecated_arg(
         name="img_size",
         since="1.3",
@@ -65,18 +62,24 @@ class SwinUNETR(nn.Module):
         img_size: Sequence[int] | int,
         in_channels: int,
         out_channels: int,
+        patch_size: int = 2,
         depths: Sequence[int] = (2, 2, 2, 2),
         num_heads: Sequence[int] = (3, 6, 12, 24),
+        window_size: Sequence[int] | int = 7,
+        qkv_bias: bool = True,
+        mlp_ratio: float = 4.0,
         feature_size: int = 24,
         norm_name: tuple | str = "instance",
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         dropout_path_rate: float = 0.0,
         normalize: bool = True,
+        norm_layer: type[LayerNorm] = nn.LayerNorm,
+        patch_norm: bool = True,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
-        downsample="merging",
-        use_v2=False,
+        downsample: str | nn.Module = "merging",
+        use_v2: bool = False,
     ) -> None:
         """
         Args:
@@ -86,14 +89,20 @@ class SwinUNETR(nn.Module):
                 It will be removed in an upcoming version.
             in_channels: dimension of input channels.
             out_channels: dimension of output channels.
+            patch_size: size of the patch token.
             feature_size: dimension of network feature size.
             depths: number of layers in each stage.
             num_heads: number of attention heads.
+            window_size: local window size.
+            qkv_bias: add a learnable bias to query, key, value.
+            mlp_ratio: ratio of mlp hidden dim to embedding dim.
             norm_name: feature normalization type and arguments.
             drop_rate: dropout rate.
             attn_drop_rate: attention dropout rate.
             dropout_path_rate: drop path rate.
             normalize: normalize output intermediate features in each stage.
+            norm_layer: normalization layer.
+            patch_norm: whether to apply normalization to the patch embedding.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: number of spatial dims.
             downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@ class SwinUNETR(nn.Module):
 
         super().__init__()
 
-        img_size = ensure_tuple_rep(img_size, spatial_dims)
-        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
-        window_size = ensure_tuple_rep(7, spatial_dims)
-
         if spatial_dims not in (2, 3):
             raise ValueError("spatial dimension should be 2 or 3.")
 
+        self.patch_size = patch_size
+
+        img_size = ensure_tuple_rep(img_size, spatial_dims)
+        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
+        window_size = ensure_tuple_rep(window_size, spatial_dims)
+
         self._check_input_size(img_size)
 
         if not (0 <= drop_rate <= 1):
@@ -146,12 +157,13 @@ class SwinUNETR(nn.Module):
             patch_size=patch_sizes,
             depths=depths,
             num_heads=num_heads,
-            mlp_ratio=4.0,
-            qkv_bias=True,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
             drop_rate=drop_rate,
             attn_drop_rate=attn_drop_rate,
             drop_path_rate=dropout_path_rate,
-            norm_layer=nn.LayerNorm,
+            norm_layer=norm_layer,
+            patch_norm=patch_norm,
             use_checkpoint=use_checkpoint,
             spatial_dims=spatial_dims,
             downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
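
Note: a brief sketch of the newly exposed constructor knobs, shown with their previous hard-coded values (patch_size=2, window_size=7, qkv_bias=True, mlp_ratio=4.0), so default behaviour is unchanged:

    import torch
    from monai.networks.nets import SwinUNETR

    net = SwinUNETR(
        img_size=(96, 96, 96),  # deprecated since 1.3, still accepted
        in_channels=1,
        out_channels=2,
        patch_size=2,
        window_size=7,
        qkv_bias=True,
        mlp_ratio=4.0,
    )
    out = net(torch.randn(1, 1, 96, 96, 96))  # 1 x 2 x 96 x 96 x 96
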
monai/transforms/__init__.py CHANGED
@@ -531,6 +531,8 @@ from .utility.array import (
     RandIdentity,
     RandImageFilter,
     RandLambda,
+    RandTorchIO,
+    RandTorchVision,
     RemoveRepeatedChannel,
     RepeatChannel,
     SimulateDelay,
@@ -540,6 +542,7 @@ from .utility.array import (
     ToDevice,
     ToNumpy,
     ToPIL,
+    TorchIO,
     TorchVision,
     ToTensor,
     Transpose,
@@ -620,6 +623,9 @@ from .utility.dictionary import (
     RandLambdad,
     RandLambdaD,
     RandLambdaDict,
+    RandTorchIOd,
+    RandTorchIOD,
+    RandTorchIODict,
     RandTorchVisiond,
     RandTorchVisionD,
     RandTorchVisionDict,
@@ -653,6 +659,9 @@ from .utility.dictionary import (
     ToPILd,
     ToPILD,
     ToPILDict,
+    TorchIOd,
+    TorchIOD,
+    TorchIODict,
     TorchVisiond,
     TorchVisionD,
     TorchVisionDict,
monai/transforms/utility/array.py CHANGED
@@ -18,10 +18,10 @@ import logging
 import sys
 import time
 import warnings
-from collections.abc import Mapping, Sequence
+from collections.abc import Hashable, Mapping, Sequence
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable
+from typing import Any, Callable, Union
 
 import numpy as np
 import torch
@@ -99,11 +99,14 @@ __all__ = [
     "ConvertToMultiChannelBasedOnBratsClasses",
     "AddExtremePointsChannel",
     "TorchVision",
+    "TorchIO",
     "MapLabelValue",
     "IntensityStats",
     "ToDevice",
     "CuCIM",
     "RandCuCIM",
+    "RandTorchIO",
+    "RandTorchVision",
     "ToCupy",
     "ImageFilter",
     "RandImageFilter",
@@ -1136,12 +1139,44 @@ class AddExtremePointsChannel(Randomizable, Transform):
         return concatenate((img, points_image), axis=0)
 
 
-class TorchVision:
+class TorchVision(Transform):
     """
-    This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
-    As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
-    data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.
+    This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args.
+    Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
+    """
+
+    backend = [TransformBackends.TORCH]
+
+    def __init__(self, name: str, *args, **kwargs) -> None:
+        """
+        Args:
+            name: The transform name in TorchVision package.
+            args: parameters for the TorchVision transform.
+            kwargs: parameters for the TorchVision transform.
+
+        """
+        super().__init__()
+        self.name = name
+        transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
+        self.trans = transform(*args, **kwargs)
+
+    def __call__(self, img: NdarrayOrTensor):
+        """
+        Args:
+            img: PyTorch Tensor data for the TorchVision transform.
 
+        """
+        img_t, *_ = convert_data_type(img, torch.Tensor)
+
+        out = self.trans(img_t)
+        out, *_ = convert_to_dst_type(src=out, dst=img)
+        return out
+
+
+class RandTorchVision(Transform, RandomizableTrait):
+    """
+    This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args.
+    Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
     """
 
     backend = [TransformBackends.TORCH]
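
Note: a small sketch of the reworked wrapper: the input is converted to a tensor for the underlying torchvision transform and converted back afterwards, so numpy arrays now round-trip.

    import numpy as np
    from monai.transforms import TorchVision

    pad = TorchVision("Pad", padding=[1, 1])  # wraps torchvision.transforms.Pad
    out = pad(np.zeros((1, 4, 4), dtype=np.float32))
    print(type(out), out.shape)               # <class 'numpy.ndarray'> (1, 6, 6)
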
@@ -1172,6 +1207,68 @@ class TorchVision:
         return out
 
 
+class TorchIO(Transform):
+    """
+    This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args.
+    See https://torchio.readthedocs.io/transforms/transforms.html for more details.
+    """
+
+    backend = [TransformBackends.TORCH]
+
+    def __init__(self, name: str, *args, **kwargs) -> None:
+        """
+        Args:
+            name: The transform name in TorchIO package.
+            args: parameters for the TorchIO transform.
+            kwargs: parameters for the TorchIO transform.
+        """
+        super().__init__()
+        self.name = name
+        transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
+        self.trans = transform(*args, **kwargs)
+
+    def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
+        """
+        Args:
+            img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
+                or dict containing 4D tensors as values
+
+        """
+        return self.trans(img)
+
+
+class RandTorchIO(Transform, RandomizableTrait):
+    """
+    This is a wrapper for TorchIO randomized transforms based on the specified transform name and args.
+    See https://torchio.readthedocs.io/transforms/transforms.html for more details.
+    Use this wrapper for all TorchIO transform inheriting from RandomTransform:
+    https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform
+    """
+
+    backend = [TransformBackends.TORCH]
+
+    def __init__(self, name: str, *args, **kwargs) -> None:
+        """
+        Args:
+            name: The transform name in TorchIO package.
+            args: parameters for the TorchIO transform.
+            kwargs: parameters for the TorchIO transform.
+        """
+        super().__init__()
+        self.name = name
+        transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
+        self.trans = transform(*args, **kwargs)
+
+    def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
+        """
+        Args:
+            img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
+                or dict containing 4D tensors as values
+
+        """
+        return self.trans(img)
+
+
 class MapLabelValue:
     """
     Utility to map label values to another set of values.
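
Note: a minimal sketch of the new wrappers, assuming the optional torchio package is installed (TorchIO expects channel-first 4D inputs):

    import torch
    from monai.transforms import RandTorchIO, TorchIO

    flip = TorchIO(name="Flip", axes=(0,))            # deterministic transform
    noise = RandTorchIO(name="RandomNoise", std=0.1)  # randomized transform
    img = torch.rand(1, 24, 24, 24)                   # C x W x H x D
    out = noise(flip(img))
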
monai/transforms/utility/dictionary.py CHANGED
@@ -60,6 +60,7 @@ from monai.transforms.utility.array import (
     ToDevice,
     ToNumpy,
     ToPIL,
+    TorchIO,
     TorchVision,
     ToTensor,
     Transpose,
@@ -136,6 +137,9 @@ __all__ = [
    "RandLambdaD",
    "RandLambdaDict",
    "RandLambdad",
+    "RandTorchIOd",
+    "RandTorchIOD",
+    "RandTorchIODict",
    "RandTorchVisionD",
    "RandTorchVisionDict",
    "RandTorchVisiond",
@@ -172,6 +176,9 @@ __all__ = [
    "ToTensorD",
    "ToTensorDict",
    "ToTensord",
+    "TorchIOD",
+    "TorchIODict",
+    "TorchIOd",
    "TorchVisionD",
    "TorchVisionDict",
    "TorchVisiond",
@@ -1445,6 +1452,64 @@ class RandTorchVisiond(MapTransform, RandomizableTrait):
         return d
 
 
+class TorchIOd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms.
+    For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`.
+    """
+
+    backend = TorchIO.backend
+
+    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+                See also: :py:class:`monai.transforms.compose.MapTransform`
+            name: The transform name in TorchIO package.
+            allow_missing_keys: don't raise exception if key is missing.
+            args: parameters for the TorchIO transform.
+            kwargs: parameters for the TorchIO transform.
+
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.name = name
+        kwargs["include"] = self.keys
+
+        self.trans = TorchIO(name, *args, **kwargs)
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
+        return dict(self.trans(data))
+
+
+class RandTorchIOd(MapTransform, RandomizableTrait):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms.
+    For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`.
+    """
+
+    backend = TorchIO.backend
+
+    def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+                See also: :py:class:`monai.transforms.compose.MapTransform`
+            name: The transform name in TorchIO package.
+            allow_missing_keys: don't raise exception if key is missing.
+            args: parameters for the TorchIO transform.
+            kwargs: parameters for the TorchIO transform.
+
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.name = name
+        kwargs["include"] = self.keys
+
+        self.trans = TorchIO(name, *args, **kwargs)
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
+        return dict(self.trans(data))
+
+
 class MapLabelValued(MapTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
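
Note: and the dictionary-based counterpart, again a sketch assuming torchio is installed. The wrappers pass `keys` as TorchIO's `include` argument, so only those entries are transformed:

    import torch
    from monai.transforms import RandTorchIOd

    trans = RandTorchIOd(keys=["image"], name="RandomAffine", degrees=10)
    data = {"image": torch.rand(1, 24, 24, 24), "label": torch.rand(1, 24, 24, 24)}
    out = trans(data)  # only out["image"] is transformed, per include=["image"]
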
@@ -1871,8 +1936,10 @@ ConvertToMultiChannelBasedOnBratsClassesD = ConvertToMultiChannelBasedOnBratsCla
     ConvertToMultiChannelBasedOnBratsClassesd
 )
 AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
+TorchIOD = TorchIODict = TorchIOd
 TorchVisionD = TorchVisionDict = TorchVisiond
 RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
+RandTorchIOD = RandTorchIODict = RandTorchIOd
 RandLambdaD = RandLambdaDict = RandLambdad
 MapLabelValueD = MapLabelValueDict = MapLabelValued
 IntensityStatsD = IntensityStatsDict = IntensityStatsd
monai/utils/module.py CHANGED
@@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
         current_ver_string: if None, the current system GPU CUDA compute capability will be used.
 
     Returns:
-        True if the current system GPU CUDA compute capability is greater than the specified version.
+        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
     """
     if current_ver_string is None:
         cuda_available = torch.cuda.is_available()
@@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s
 
     ver, has_ver = optional_import("packaging.version", name="parse")
     if has_ver:
-        return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}")  # type: ignore
+        return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
     parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
     while len(parts) < 2:
         parts += ["0"]
     c_major, c_minor = parts[:2]
     c_mn = int(c_major), int(c_minor)
     mn = int(major), int(minor)
-    return c_mn >= mn
+    return c_mn > mn
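
Note: a quick check of the corrected boundary semantics, assuming the packaging module is available (the primary comparison path): a capability now counts as "after" itself.

    from monai.utils.module import compute_capabilities_after

    print(compute_capabilities_after(8, 6, current_ver_string="8.6"))  # True (previously False)
    print(compute_capabilities_after(9, 0, current_ver_string="8.6"))  # False
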
{monai_weekly-1.5.dev2447.dist-info → monai_weekly-1.5.dev2448.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: monai-weekly
-Version: 1.5.dev2447
+Version: 1.5.dev2448
 Summary: AI Toolkit for Healthcare Imaging
 Home-page: https://monai.io/
 Author: MONAI Consortium
@@ -40,6 +40,7 @@ Requires-Dist: pillow; extra == "all"
 Requires-Dist: tensorboard; extra == "all"
 Requires-Dist: gdown>=4.7.3; extra == "all"
 Requires-Dist: pytorch-ignite==0.4.11; extra == "all"
+Requires-Dist: torchio; extra == "all"
 Requires-Dist: torchvision; extra == "all"
 Requires-Dist: itk>=5.2; extra == "all"
 Requires-Dist: tqdm>=4.47.0; extra == "all"
@@ -87,6 +88,8 @@ Provides-Extra: gdown
 Requires-Dist: gdown>=4.7.3; extra == "gdown"
 Provides-Extra: ignite
 Requires-Dist: pytorch-ignite==0.4.11; extra == "ignite"
+Provides-Extra: torchio
+Requires-Dist: torchio; extra == "torchio"
 Provides-Extra: torchvision
 Requires-Dist: torchvision; extra == "torchvision"
 Provides-Extra: itk
{monai_weekly-1.5.dev2447.dist-info → monai_weekly-1.5.dev2448.dist-info}/RECORD CHANGED
@@ -1,5 +1,5 @@
-monai/__init__.py,sha256=AviKC_GbiSGP7NQn_MnvULzw3Rx0WHgMPuIFqdHRBvQ,4095
-monai/_version.py,sha256=pEg9cY4Jj39QKd_EAHXV9gmHAxWKvwrVyxvHR_B8kOo,503
+monai/__init__.py,sha256=iAf8W1W9ATQdrTIOSFo-6azszcYoz1kkY4MnCgX4Y4U,4095
+monai/_version.py,sha256=C3XqwkXQaNqY0gKbUFW8ztNtyWAc5vmVP0oukF2Q3H8,503
 monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
 monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -108,15 +108,15 @@ monai/auto3dseg/analyzer.py,sha256=7l8QT36lG68b8rK23CC2omz6PO1fxmDwOljxXMn5clQ,4
 monai/auto3dseg/operations.py,sha256=1sNDWnz5Zs2-scpb1wotxar7yGYQ-VPI-_b2KnZqW9g,5110
 monai/auto3dseg/seg_summarizer.py,sha256=T5Kwvc6eKet-vlzvBQgCLHbxHto-P5tiN_7uIk5uVfs,8717
 monai/auto3dseg/utils.py,sha256=zEicEO_--6-1kzT5HlmhAAd575gnl2AFmW8O3FnIznE,18674
-monai/bundle/__init__.py,sha256=xvYgiAzq9fiyMkCRo0vwn41ZSzj0udyvF0jmySnqBRI,1443
+monai/bundle/__init__.py,sha256=7Jn72qUeIX4H2NyTE4gCg4I_1SAO2lSjg-59RZQZfaM,1461
 monai/bundle/__main__.py,sha256=RiAn6raPUvPMfXvd03irAhB3nkIAgG1lf8GE34PG4Js,952
 monai/bundle/config_item.py,sha256=rMjXSGkjJZdi04BwSHwCcIwzIb_TflmC3xDhC3SVJRs,16151
 monai/bundle/config_parser.py,sha256=cGyEn-cqNk0rEEZ1Qiv6UydmIDvtWZcMVljyfVm5i50,23025
 monai/bundle/properties.py,sha256=iN3K4FVmN9ny1Hw9p5j7_ULcCdSD8PmrR7qXxbNz49k,11582
-monai/bundle/reference_resolver.py,sha256=5YTzVEoQDJSv-PF79abwYggXCZcFxaOa3veFVElme-M,16463
+monai/bundle/reference_resolver.py,sha256=GXCMK4iogxxE6VocsmAbUrcXosmC5arnjeG9zYhHKpg,16748
 monai/bundle/scripts.py,sha256=ZxdkNI1D1LpAJpufnJSexZt13EpihSnzFzdo5DbH3NU,89316
 monai/bundle/utils.py,sha256=t-22uFvLn7Yy-dr1v1U33peNOxgAmU4TJiGAbsBrUKs,10108
-monai/bundle/workflows.py,sha256=a9X_yqVz_NPRj0N2ByXRDGXBWEiijzYEKv2qH14C324,24682
+monai/bundle/workflows.py,sha256=CuhmFq1AWsN3ATiYJCSakPOxrOdGutl6vkpo9sxe8gU,34369
 monai/config/__init__.py,sha256=CN28CfTdsp301gv8YXfVvkbztCfbAqrLKrJi_C8oP9s,1048
 monai/config/deviceconfig.py,sha256=f3Xa0OL9kNqdsbZ0PfUEvm6NZivAPh454_VCE8BmsWE,10582
 monai/config/type_definitions.py,sha256=a8_YmLkVOeldchAS6cM3KiG9n9YixkXHoyYo1XoskMI,3512
@@ -268,12 +268,12 @@ monai/networks/blocks/pos_embed_utils.py,sha256=vFEQqxZ6UAmjcy_icFDL9EwjRHYXuIbW
 monai/networks/blocks/regunet_block.py,sha256=1FLIwVBtk66II6xQ7Q4LMY8DP0rMmeftN7HuaEgnf3A,8825
 monai/networks/blocks/rel_pos_embedding.py,sha256=wuTJsk_NHSDX-3V0X9ctF99WIh2-SHLDbQxzrG7tz_4,2208
 monai/networks/blocks/segresnet_block.py,sha256=dREFa0CWuSWlSOm53fT7vZz6UC2J_7JAEaeHB9rYjAk,3339
-monai/networks/blocks/selfattention.py,sha256=sVVVYLm4ByOBbEbrKYW2kA3JRgB2kveZqDMOfEzHuOs,9141
+monai/networks/blocks/selfattention.py,sha256=fZGtQwtSvU5aoQ4DWnUbR4DWUA-oEa6L6x3BkHkCUVI,9844
 monai/networks/blocks/spade_norm.py,sha256=Kq2ImmCQBaFURMnOTj08aphgGkF3ghDm19kXpPRq91c,3654
 monai/networks/blocks/spatialattention.py,sha256=HhoOnp0YfygOZne8jZjxQezRXIwQg1kfs-Cdo0ruxhw,3442
 monai/networks/blocks/squeeze_and_excitation.py,sha256=y2kXgoSFxywu-KCGYbI_d-NCCAEbuKAIY5gSqO_T7TI,12752
 monai/networks/blocks/text_embedding.py,sha256=HIlCTQCSyOEXnqo1l9TOC05duCoeWd9Kb4Oc0gvLZKw,3814
-monai/networks/blocks/transformerblock.py,sha256=UgJH4S94a5GaU2j-9HnmYkCT247vgxV76yO9d_6Tu1k,3880
+monai/networks/blocks/transformerblock.py,sha256=dGqVoLoQuRjIO1mi5FpTNUZ0nrgvOVqksfQK6oZdhZc,3957
 monai/networks/blocks/unetr_block.py,sha256=d_rqE76OFfd3QRcHuor5Zei2pOrupoleBWu3eYUup0c,9049
 monai/networks/blocks/upsample.py,sha256=CeqqKx31gNw1CT3xz6UpU0fOjgW-7ZWxCRAOH4qAcxs,14024
 monai/networks/blocks/warp.py,sha256=XVFZKZR0kBhEtU5-xQsaqL06a-pAI7JJVupQCD2X4e8,7255
@@ -289,7 +289,7 @@ monai/networks/layers/spatial_transforms.py,sha256=fz2t7-ibijNLqTYpAn4ZgdXtzBSIy
 monai/networks/layers/utils.py,sha256=k_2xVO8BTEMMVJtemUyKBWw4_5xtqd6OOTOG8qld8To,4916
 monai/networks/layers/vector_quantizer.py,sha256=0PCcaH5_uaxFORHgEetQKazq74jgOVmvQJ3h4Ywat6Y,10058
 monai/networks/layers/weight_init.py,sha256=ehwI5F7jm_lmDkK4qVL7ocIzCEPx5UPgLaURcsfMNwk,2253
-monai/networks/nets/__init__.py,sha256=sEmOdnrwy-eCb6-HEPf9ySFMyEmF0GcdXzERLwM7szA,4152
+monai/networks/nets/__init__.py,sha256=QS_r_mjmymo3YX6DnWftREug1zVRUV56b2xjj5rvWDU,4209
 monai/networks/nets/ahnet.py,sha256=RT-loCa5Z_3I2DWB8lmRkhxGXSsnMVBCEDpwo68-YB4,21570
 monai/networks/nets/attentionunet.py,sha256=lqsrzpy0sRuuFjAtKUUJ0hT3lGF9skpepWXLG0JBo-k,9427
 monai/networks/nets/autoencoder.py,sha256=QuLdDfDwhefIqA2n8XfmFyi5T8enP6O4PETdBKmFMKc,12586
@@ -310,6 +310,7 @@ monai/networks/nets/fullyconnectednet.py,sha256=j5uo68qnYSxgH_sEMRh7s3QGNKFaJAIx
 monai/networks/nets/generator.py,sha256=q20EAl9N7Q56t78JiZaUEkPhYWyD02oqO0yekJCd9x0,6581
 monai/networks/nets/highresnet.py,sha256=1Mx8lR5K4sRXGWjspDAHaKq0WrX9Q7qz8CcBCKZxIXk,8883
 monai/networks/nets/hovernet.py,sha256=gQDeDGqCwjJACTPmQLAx9nPRBO_D65F-scx15w3Ho_Q,28645
+monai/networks/nets/masked_autoencoder_vit.py,sha256=U2DmyKOP-GqFfzbpyMwCoGfcBvMHYeua5G2ZpwqzKpw,9610
 monai/networks/nets/mednext.py,sha256=svsIk0dH7MdNI8Fr7eP2YM8j1IBJ2paF7m_2VWpLOZ4,13258
 monai/networks/nets/milmodel.py,sha256=aUDgYJG0kS3p4nBW_dF7b4cWwuC31w3KIzmUzXA08HE,9813
 monai/networks/nets/netadapter.py,sha256=JtcME9pcg8ud4jHKZKM9fE-8leP2PQXgUIfKBdB0wcA,6102
@@ -324,7 +325,7 @@ monai/networks/nets/senet.py,sha256=gulqPMYmSABbMbN39NElGzSU1TKGviJas7EPTBaZ60A,
 monai/networks/nets/spade_autoencoderkl.py,sha256=-b2Sbl4jPpwo3ukTgsTcON26cSTB35K9sy1S9DKlZz0,19566
 monai/networks/nets/spade_diffusion_model_unet.py,sha256=zYsXhkHNpHWWyal5ljAMxOICJ1loYQQMAOuzWzdLBCM,39007
 monai/networks/nets/spade_network.py,sha256=GguYucjIRyT_rZa9DrvUmv00FtqXHZtY1VfJM9Rygns,16479
-monai/networks/nets/swin_unetr.py,sha256=69GHMvtBTpJvWGvYsYYenSdWogw4y77My2Bm016mimA,44891
+monai/networks/nets/swin_unetr.py,sha256=myQIg5jFUvDX5O_3KjZILkXk2a6KIVd08pYM8mSWytU,45522
 monai/networks/nets/torchvision_fc.py,sha256=3g5PD7C1MSkQ8xndhnVd0b3aN8zfshT8uiFS0OHyQaY,6309
 monai/networks/nets/transchex.py,sha256=uA_RfTDfPhwA1ecAPZ9EDnMyJKn2tUMLEWdyB_rU2v0,15726
 monai/networks/nets/transformer.py,sha256=-nzl20Z5xdtn7xChOd_cRbbPVoPIFGVfTQw3fIEGMuE,6395
@@ -347,7 +348,7 @@ monai/optimizers/lr_finder.py,sha256=tbVi6qd-LLI6pENM9cDUv-Hh1HqziO3Wb9aI6JoaPng
 monai/optimizers/lr_scheduler.py,sha256=YPY5MWgCTmExuIOBsVJrgfErkCT1ELBekcH0XeRP6Kk,4082
 monai/optimizers/novograd.py,sha256=dgjyM-WGqrEHsSKNdI3Lw1wJ2YNG3oKCYotfPsDBE80,5677
 monai/optimizers/utils.py,sha256=GVsJsZWO2aAP9IzwhXgca_9gUNHFClup6qG4ZFs42z4,4133
-monai/transforms/__init__.py,sha256=lyIf64v-I2soIjfK2RxOWS7_CIc-x6bRJHLI6UZ8yDs,16591
+monai/transforms/__init__.py,sha256=-LmAa_W5fJxm5I_btvAONNebWe2exa7IWwcvYrNxzCc,16744
 monai/transforms/adaptors.py,sha256=LpYChldlOur-VFgu_nBIBze0J841-NWgf0UHvvHRNPU,8796
 monai/transforms/compose.py,sha256=zQa_hf8gIater3Bo_XW1IVYgX7aFa_Co6-BZPwoeaQw,37663
 monai/transforms/inverse.py,sha256=Wg8UnMJru41G3eHGipUemAWziHGU-qdd-Flfi3eOpeo,18746
@@ -394,8 +395,8 @@ monai/transforms/spatial/array.py,sha256=5EKivdPYCP4i4qYUlkK1RpYQFzaU_baYyzgubid
 monai/transforms/spatial/dictionary.py,sha256=t0SvEDSVNFUEw2fK66OVF20sqSzCNxil17HmvsMFBt8,133752
 monai/transforms/spatial/functional.py,sha256=IwS0witCqbGkyuxzu_R4Ztp90S0pg9hY1irG7feXqig,33886
 monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
-monai/transforms/utility/array.py,sha256=Ju5WvLDmujqh2yPbi9iX-qbStbWa5iqUMrbuOwy0x6w,78188
-monai/transforms/utility/dictionary.py,sha256=N6E230-g2zupG63oCsAXWgkdfZmF---TZbvk7p5FQU8,78079
+monai/transforms/utility/array.py,sha256=OaEmhjDoWg9buQHXbgwgN713qqR05trZJ0F7qAm1re0,81658
+monai/transforms/utility/dictionary.py,sha256=iOFdTSekvkAsBbbfHeffcRsOKRtNcnt3N1cVuUarZ1s,80549
 monai/utils/__init__.py,sha256=2_AIpb1wqGMkmgoZ3r43muFTEsnMTCkPu3LtckipYHg,3793
 monai/utils/component_store.py,sha256=Fe9jbHgwwBBAeJAw0nI02Ae13v17wlwF6N9uUue8tJg,4525
 monai/utils/decorators.py,sha256=qhhdmJMjMfZIUM6x_VGUGF7kaq2cBUAam8WymAU_mhw,3156
@@ -404,7 +405,7 @@ monai/utils/dist.py,sha256=7brB42CvdS8Jvr8Y7hfqov1uk6NNnYea9dYfgMYy0BY,8578
 monai/utils/enums.py,sha256=orCV7SGDajYtl3DhTTjbLDbayr6WxkMSw_bZ6yeGGTY,19513
 monai/utils/jupyter_utils.py,sha256=kQqfLTLAre3TLzXTt091X_XeWy5K0QKAcTuYlJ8BOag,15650
 monai/utils/misc.py,sha256=R-sCS5u7SA8hX6e7x6WSc8FgLcNpqKFRRDMWxUd2wCo,31759
-monai/utils/module.py,sha256=2G9mgrUhytkIADHWPAH4xWKXgIhknBYzj_RCKZdYHJA,26123
+monai/utils/module.py,sha256=RWDq_64WcfGtGGiNaywR70eyE9RN4bbaMXLLW7BEUj0,26135
 monai/utils/nvtx.py,sha256=i9JBxR1uhW1ZCgLPLlTx8b907QlXkFzJyTBLMlFjhtU,6876
 monai/utils/ordering.py,sha256=0nlA5b5QpVCHbtiCbTC-YsqjTmjm0bub0IeJhGFBOes,8270
 monai/utils/profiling.py,sha256=V2_cSHgrcmVF48_G3nUi2-O6fnXsS89nSlb8jj58YLo,15937
@@ -418,8 +419,8 @@ monai/visualize/img2tensorboard.py,sha256=NnMcyfIFqX-jD7TBO3Rn02zt5uug79d_7pIIaV
 monai/visualize/occlusion_sensitivity.py,sha256=OQHEJLyIhB8zWqQsfKaX-1kvCjWFVYtLfS4dFC0nKFI,18160
 monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
 monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
-monai_weekly-1.5.dev2447.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-monai_weekly-1.5.dev2447.dist-info/METADATA,sha256=TUTxgV4URp2TYBBRw9DaqzDVnkijKHdAy9FXGbTIKc8,11187
-monai_weekly-1.5.dev2447.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-monai_weekly-1.5.dev2447.dist-info/top_level.txt,sha256=UaNwRzLGORdus41Ip446s3bBfViLkdkDsXDo34J2P44,6
-monai_weekly-1.5.dev2447.dist-info/RECORD,,
+monai_weekly-1.5.dev2448.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+monai_weekly-1.5.dev2448.dist-info/METADATA,sha256=5LrqePqyfsUBBT1Y_VCaJlgKAP0ZCcNibWIHTp6YOOc,11293
+monai_weekly-1.5.dev2448.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+monai_weekly-1.5.dev2448.dist-info/top_level.txt,sha256=UaNwRzLGORdus41Ip446s3bBfViLkdkDsXDo34J2P44,6
+monai_weekly-1.5.dev2448.dist-info/RECORD,,