monai-weekly 1.4.dev2434__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +44 -2
- monai/_version.py +3 -3
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/config_parser.py +5 -3
- monai/bundle/scripts.py +2 -2
- monai/bundle/utils.py +35 -1
- monai/handlers/__init__.py +1 -0
- monai/handlers/trt_handler.py +61 -0
- monai/inferers/utils.py +1 -0
- monai/metrics/generalized_dice.py +77 -48
- monai/networks/__init__.py +2 -0
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/swin_unetr.py +4 -4
- monai/networks/nets/vista3d.py +53 -11
- monai/networks/trt_compiler.py +569 -0
- monai/networks/utils.py +225 -41
- monai/transforms/__init__.py +24 -2
- monai/transforms/io/array.py +58 -2
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/array.py +44 -0
- monai/transforms/spatial/dictionary.py +61 -0
- monai/transforms/spatial/functional.py +70 -1
- monai/transforms/utility/array.py +153 -4
- monai/transforms/utility/dictionary.py +105 -3
- monai/transforms/utils.py +83 -10
- monai/utils/__init__.py +1 -0
- monai/utils/enums.py +1 -0
- monai/utils/type_conversion.py +8 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/METADATA +4 -1
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/RECORD +36 -31
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/top_level.txt +0 -0
monai/bundle/scripts.py
CHANGED
@@ -32,7 +32,7 @@ from monai._version import get_versions
|
|
32
32
|
from monai.apps.utils import _basename, download_url, extractall, get_logger
|
33
33
|
from monai.bundle.config_item import ConfigComponent
|
34
34
|
from monai.bundle.config_parser import ConfigParser
|
35
|
-
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
|
35
|
+
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
|
36
36
|
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
|
37
37
|
from monai.config import IgniteInfo, PathLike
|
38
38
|
from monai.data import load_net_with_metadata, save_net_with_metadata
|
@@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
|
|
105
105
|
if isinstance(v, dict) and isinstance(args_.get(k), dict):
|
106
106
|
args_[k] = update_kwargs(args_[k], ignore_none, **v)
|
107
107
|
else:
|
108
|
-
args_
|
108
|
+
merge_kv(args_, k, v)
|
109
109
|
return args_
|
110
110
|
|
111
111
|
|
monai/bundle/utils.py
CHANGED
@@ -13,6 +13,7 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
import json
|
15
15
|
import os
|
16
|
+
import warnings
|
16
17
|
import zipfile
|
17
18
|
from typing import Any
|
18
19
|
|
@@ -21,12 +22,21 @@ from monai.utils import optional_import
|
|
21
22
|
|
22
23
|
yaml, _ = optional_import("yaml")
|
23
24
|
|
24
|
-
__all__ = [
|
25
|
+
__all__ = [
|
26
|
+
"ID_REF_KEY",
|
27
|
+
"ID_SEP_KEY",
|
28
|
+
"EXPR_KEY",
|
29
|
+
"MACRO_KEY",
|
30
|
+
"MERGE_KEY",
|
31
|
+
"DEFAULT_MLFLOW_SETTINGS",
|
32
|
+
"DEFAULT_EXP_MGMT_SETTINGS",
|
33
|
+
]
|
25
34
|
|
26
35
|
ID_REF_KEY = "@" # start of a reference to a ConfigItem
|
27
36
|
ID_SEP_KEY = "::" # separator for the ID of a ConfigItem
|
28
37
|
EXPR_KEY = "$" # start of a ConfigExpression
|
29
38
|
MACRO_KEY = "%" # start of a macro of a config
|
39
|
+
MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs.
|
30
40
|
|
31
41
|
_conf_values = get_config_values()
|
32
42
|
|
@@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
|
|
233
243
|
parser.read_config(f=cdata)
|
234
244
|
|
235
245
|
return parser
|
246
|
+
|
247
|
+
|
248
|
+
def merge_kv(args: dict | Any, k: str, v: Any) -> None:
    """
    Update the `args` dict-like object with the key/value pair `k` and `v`.

    A plain key overwrites ``args[k]``. A key prefixed with ``MERGE_KEY`` requests a merge into the entry stored
    under the un-prefixed key instead of an override:

    - ``dict`` values are merged in place with ``dict.update``;
    - ``list`` values are concatenated in place with ``list.extend``;
    - if the un-prefixed key is not present yet, a warning is emitted and ``v`` is stored as-is.

    Args:
        args: dict-like target object; modified in place.
        k: key to set. A ``MERGE_KEY`` prefix selects merge semantics.
        v: value to set or merge.

    Raises:
        ValueError: when merging and the existing value and `v` are not both ``dict`` or both ``list``.
    """
    if not k.startswith(MERGE_KEY):
        args[k] = v
        return

    target_key = k[len(MERGE_KEY) :]
    if target_key not in args:
        warnings.warn(f"Can't merge entry ['{k}'], '{target_key}' is not in target dict - copying instead.")
        args[target_key] = v
        return

    current = args[target_key]
    if isinstance(v, dict) and isinstance(current, dict):
        current.update(v)
    elif isinstance(v, list) and isinstance(current, list):
        current.extend(v)
    else:
        # A merge is only well-defined when both sides share the same container type;
        # report both types so the user can tell which side is wrong.
        raise ValueError(
            f"config must be dict or list for key `{k}`, but got {type(v)}: {v} "
            f"(existing value for `{target_key}` is {type(current)})."
        )
|
monai/handlers/__init__.py
CHANGED
@@ -40,5 +40,6 @@ from .smartcache_handler import SmartCacheHandler
|
|
40
40
|
from .stats_handler import StatsHandler
|
41
41
|
from .surface_distance import SurfaceDistance
|
42
42
|
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
|
43
|
+
from .trt_handler import TrtHandler
|
43
44
|
from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
|
44
45
|
from .validation_handler import ValidationHandler
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
from typing import TYPE_CHECKING
|
15
|
+
|
16
|
+
from monai.config import IgniteInfo
|
17
|
+
from monai.networks import trt_compile
|
18
|
+
from monai.utils import min_version, optional_import
|
19
|
+
|
20
|
+
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from ignite.engine import Engine
|
23
|
+
else:
|
24
|
+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
|
25
|
+
|
26
|
+
|
27
|
+
class TrtHandler:
    """
    TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.

    Once attached, it calls ``trt_compile()`` with the stored arguments when the engine fires ``Events.STARTED``.

    Usage example::
        handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
        handler.attach(engine)
        engine.run()
    """

    def __init__(self, model, base_path, args=None, submodule=None):
        """
        Args:
            model: the model to accelerate; passed as the first argument of ``trt_compile()``.
            base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
            args: passed to trt_compile(). See trt_compile() for details.
            submodule: Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'
        """
        self.model = model
        self.base_path = base_path
        self.args = args
        self.submodule = submodule
        # Populated by `attach()`. Kept as None until then so `__call__` still works (falling back to the
        # engine's logger) when the handler is registered via `engine.add_event_handler` directly.
        self.logger = None

    def attach(self, engine: Engine) -> None:
        """
        Register this handler on ``Events.STARTED`` of `engine` and adopt the engine's logger.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        self.logger = engine.logger
        engine.add_event_handler(Events.STARTED, self)

    def __call__(self, engine: Engine) -> None:
        """
        Run ``trt_compile()`` on the stored model.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        logger = self.logger if self.logger is not None else engine.logger
        trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=logger)
|
monai/inferers/utils.py
CHANGED
@@ -300,6 +300,7 @@ def sliding_window_inference(
|
|
300
300
|
|
301
301
|
# remove padding if image_size smaller than roi_size
|
302
302
|
if any(pad_size):
|
303
|
+
kwargs.update({"pad_size": pad_size})
|
303
304
|
for ss, output_i in enumerate(output_image_list):
|
304
305
|
zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
|
305
306
|
final_slicing: list[slice] = []
|
@@ -14,34 +14,47 @@ from __future__ import annotations
|
|
14
14
|
import torch
|
15
15
|
|
16
16
|
from monai.metrics.utils import do_metric_reduction, ignore_background
|
17
|
-
from monai.utils import MetricReduction, Weight, look_up_option
|
17
|
+
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option
|
18
18
|
|
19
19
|
from .metric import CumulativeIterationMetric
|
20
20
|
|
21
21
|
|
22
22
|
class GeneralizedDiceScore(CumulativeIterationMetric):
|
23
|
-
"""
|
23
|
+
"""
|
24
|
+
Compute the Generalized Dice Score metric between tensors.
|
24
25
|
|
26
|
+
This metric is the complement of the Generalized Dice Loss defined in:
|
25
27
|
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
|
26
|
-
|
28
|
+
loss function for highly unbalanced segmentations. DLMIA 2017.
|
27
29
|
|
28
|
-
The inputs `y_pred` and `y` are expected to be one-hot, binarized
|
29
|
-
or batch-first tensors, i.e., CHW[D] or BCHW[D].
|
30
|
+
The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
|
30
31
|
|
31
32
|
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
|
32
33
|
|
33
34
|
Args:
|
34
|
-
include_background
|
35
|
+
include_background: Whether to include the background class (assumed to be in channel 0) in the
|
35
36
|
score computation. Defaults to True.
|
36
|
-
reduction
|
37
|
-
{``"none"``, ``"
|
38
|
-
|
37
|
+
reduction: Define mode of reduction to the metrics. Available reduction modes:
|
38
|
+
{``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
|
39
|
+
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
|
40
|
+
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
|
39
41
|
ground truth volume into a weight factor. Defaults to ``"square"``.
|
40
42
|
|
41
43
|
Raises:
|
42
|
-
ValueError:
|
44
|
+
ValueError: When the `reduction` is not one of MetricReduction enum.
|
43
45
|
"""
|
44
46
|
|
47
|
+
@deprecated_arg_default(
|
48
|
+
"reduction",
|
49
|
+
old_default=MetricReduction.MEAN_BATCH,
|
50
|
+
new_default=MetricReduction.MEAN,
|
51
|
+
since="1.4.0",
|
52
|
+
replaced="1.5.0",
|
53
|
+
msg_suffix=(
|
54
|
+
"Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
|
55
|
+
"If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
|
56
|
+
),
|
57
|
+
)
|
45
58
|
def __init__(
|
46
59
|
self,
|
47
60
|
include_background: bool = True,
|
@@ -50,79 +63,90 @@ class GeneralizedDiceScore(CumulativeIterationMetric):
|
|
50
63
|
) -> None:
|
51
64
|
super().__init__()
|
52
65
|
self.include_background = include_background
|
53
|
-
|
54
|
-
"none",
|
55
|
-
"mean_batch",
|
56
|
-
"sum_batch",
|
57
|
-
MetricReduction.NONE,
|
58
|
-
MetricReduction.MEAN_BATCH,
|
59
|
-
MetricReduction.SUM_BATCH,
|
60
|
-
]
|
61
|
-
self.reduction = reduction
|
62
|
-
if self.reduction not in reduction_options:
|
63
|
-
raise ValueError(f"reduction must be one of {reduction_options}")
|
66
|
+
self.reduction = look_up_option(reduction, MetricReduction)
|
64
67
|
self.weight_type = look_up_option(weight_type, Weight)
|
68
|
+
self.sum_over_classes = self.reduction in {
|
69
|
+
MetricReduction.SUM,
|
70
|
+
MetricReduction.MEAN,
|
71
|
+
MetricReduction.MEAN_CHANNEL,
|
72
|
+
MetricReduction.SUM_CHANNEL,
|
73
|
+
}
|
65
74
|
|
66
75
|
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
    """
    Computes the Generalized Dice Score and returns a tensor with its per image values.

    Args:
        y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
            where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
        y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.

    Returns:
        torch.Tensor: Per-image Generalized Dice Score, not yet reduced over the batch. It holds one value per
        class (shape [batch_size, num_classes]), or a single class-summed value per image (shape [batch_size, 1])
        when ``self.sum_over_classes`` is True, i.e. when the configured reduction is ``"mean"``, ``"sum"``,
        ``"mean_channel"`` or ``"sum_channel"``. The final reduction is applied later by `aggregate`.

    Raises:
        ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
    """
    return compute_generalized_dice(
        y_pred=y_pred,
        y=y,
        include_background=self.include_background,
        weight_type=self.weight_type,
        sum_over_classes=self.sum_over_classes,
    )
|
80
97
|
|
98
|
+
@deprecated_arg(
    "reduction",
    since="1.3.3",
    removed="1.7.0",
    msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute",
)
def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
    """
    Execute reduction logic for the output of `compute_generalized_dice`.

    Args:
        reduction: deprecated and ignored. The reduction configured at construction time (``self.reduction``)
            is always applied, because `_compute_tensor` already depends on it (it decides whether the
            per-class values are summed before buffering).

    Returns:
        torch.Tensor: Aggregated metric value.

    Raises:
        ValueError: If the data to aggregate is not a PyTorch Tensor.
    """
    data = self.get_buffer()
    if not isinstance(data, torch.Tensor):
        raise ValueError("The data to aggregate must be a PyTorch Tensor.")

    # Always reduce with the init-time setting; the `reduction` argument is intentionally unused.
    f, _ = do_metric_reduction(data, self.reduction)
    return f
|
104
122
|
|
105
123
|
|
106
124
|
def compute_generalized_dice(
|
107
|
-
y_pred: torch.Tensor,
|
125
|
+
y_pred: torch.Tensor,
|
126
|
+
y: torch.Tensor,
|
127
|
+
include_background: bool = True,
|
128
|
+
weight_type: Weight | str = Weight.SQUARE,
|
129
|
+
sum_over_classes: bool = False,
|
108
130
|
) -> torch.Tensor:
|
109
|
-
"""
|
131
|
+
"""
|
132
|
+
Computes the Generalized Dice Score and returns a tensor with its per image values.
|
110
133
|
|
111
134
|
Args:
|
112
|
-
y_pred (torch.Tensor):
|
135
|
+
y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
|
113
136
|
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
|
114
137
|
remaining are the spatial dimensions.
|
115
|
-
y (torch.Tensor):
|
116
|
-
include_background
|
138
|
+
y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
|
139
|
+
include_background: Whether to include score computation on the first channel of the
|
117
140
|
predicted output. Defaults to True.
|
118
141
|
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
|
119
142
|
transform ground truth volume into a weight factor. Defaults to ``"square"``.
|
143
|
+
sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation.
|
120
144
|
|
121
145
|
Returns:
|
122
|
-
torch.Tensor:
|
146
|
+
torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
|
123
147
|
|
124
148
|
Raises:
|
125
|
-
ValueError:
|
149
|
+
ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
|
126
150
|
or `y_pred` and `y` don't have the same shape.
|
127
151
|
"""
|
128
152
|
# Ensure tensors have at least 3 dimensions and have the same shape
|
@@ -158,16 +182,21 @@ def compute_generalized_dice(
|
|
158
182
|
b[infs] = 0
|
159
183
|
b[infs] = torch.max(b)
|
160
184
|
|
161
|
-
# Compute the weighted numerator and denominator, summing along the class axis
|
162
|
-
|
163
|
-
|
185
|
+
# Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
|
186
|
+
if sum_over_classes:
|
187
|
+
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
|
188
|
+
denom = (denominator * w).sum(dim=1, keepdim=True)
|
189
|
+
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
|
190
|
+
else:
|
191
|
+
numer = 2.0 * (intersection * w)
|
192
|
+
denom = denominator * w
|
193
|
+
y_pred_o = y_pred_o
|
164
194
|
|
165
195
|
# Compute the score
|
166
196
|
generalized_dice_score = numer / denom
|
167
197
|
|
168
198
|
# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
|
169
199
|
# Where denom == 0 but the prediction volume is not 0, score is 0
|
170
|
-
y_pred_o = y_pred_o.sum(dim=-1)
|
171
200
|
denom_zeros = denom == 0
|
172
201
|
generalized_dice_score[denom_zeros] = torch.where(
|
173
202
|
(y_pred_o == 0)[denom_zeros],
|
monai/networks/__init__.py
CHANGED
@@ -51,6 +51,8 @@ class BilateralFilter(torch.autograd.Function):
|
|
51
51
|
ctx.cs = color_sigma
|
52
52
|
ctx.fa = fast_approx
|
53
53
|
output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
|
54
|
+
if torch.cuda.is_available():
|
55
|
+
torch.cuda.synchronize()
|
54
56
|
return output_data
|
55
57
|
|
56
58
|
@staticmethod
|
@@ -139,7 +141,8 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):
|
|
139
141
|
do_dsig_y,
|
140
142
|
do_dsig_z,
|
141
143
|
)
|
142
|
-
|
144
|
+
if torch.cuda.is_available():
|
145
|
+
torch.cuda.synchronize()
|
143
146
|
return output_tensor
|
144
147
|
|
145
148
|
@staticmethod
|
@@ -301,7 +304,8 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):
|
|
301
304
|
do_dsig_z,
|
302
305
|
guidance_img,
|
303
306
|
)
|
304
|
-
|
307
|
+
if torch.cuda.is_available():
|
308
|
+
torch.cuda.synchronize()
|
305
309
|
return output_tensor
|
306
310
|
|
307
311
|
@staticmethod
|
@@ -320,7 +320,7 @@ class SwinUNETR(nn.Module):
|
|
320
320
|
)
|
321
321
|
|
322
322
|
def forward(self, x_in):
|
323
|
-
if not torch.jit.is_scripting():
|
323
|
+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
324
324
|
self._check_input_size(x_in.shape[2:])
|
325
325
|
hidden_states_out = self.swinViT(x_in, self.normalize)
|
326
326
|
enc0 = self.encoder1(x_in)
|
@@ -1046,14 +1046,14 @@ class SwinTransformer(nn.Module):
|
|
1046
1046
|
|
1047
1047
|
def proj_out(self, x, normalize=False):
|
1048
1048
|
if normalize:
|
1049
|
-
x_shape = x.
|
1049
|
+
x_shape = x.shape
|
1050
|
+
# Force trace() to generate a constant by casting to int
|
1051
|
+
ch = int(x_shape[1])
|
1050
1052
|
if len(x_shape) == 5:
|
1051
|
-
n, ch, d, h, w = x_shape
|
1052
1053
|
x = rearrange(x, "n c d h w -> n d h w c")
|
1053
1054
|
x = F.layer_norm(x, [ch])
|
1054
1055
|
x = rearrange(x, "n d h w c -> n c d h w")
|
1055
1056
|
elif len(x_shape) == 4:
|
1056
|
-
n, ch, h, w = x_shape
|
1057
1057
|
x = rearrange(x, "n c h w -> n h w c")
|
1058
1058
|
x = F.layer_norm(x, [ch])
|
1059
1059
|
x = rearrange(x, "n h w c -> n c h w")
|
monai/networks/nets/vista3d.py
CHANGED
@@ -23,7 +23,7 @@ import monai
|
|
23
23
|
from monai.networks.blocks import MLPBlock, UnetrBasicBlock
|
24
24
|
from monai.networks.nets import SegResNetDS2
|
25
25
|
from monai.transforms.utils import convert_points_to_disc
|
26
|
-
from monai.transforms.utils import
|
26
|
+
from monai.transforms.utils import keep_merge_components_with_points as lcc
|
27
27
|
from monai.transforms.utils import sample_points_from_label
|
28
28
|
from monai.utils import optional_import, unsqueeze_left, unsqueeze_right
|
29
29
|
|
@@ -78,6 +78,35 @@ class VISTA3D(nn.Module):
|
|
78
78
|
self.NINF_VALUE = -9999
|
79
79
|
self.PINF_VALUE = 9999
|
80
80
|
|
81
|
+
def update_slidingwindow_padding(
    self,
    pad_size: list | None,
    labels: torch.Tensor | None,
    prev_mask: torch.Tensor | None,
    point_coords: torch.Tensor | None,
):
    """
    Re-apply the sliding window inferer's image padding to the prompt-related inputs.

    The sliding window inferer pads the image itself, but labels, the previous mask and the click
    coordinates are not padded by it, so they are aligned here.

    Args:
        pad_size: padding size passed from sliding window inferer (``F.pad`` ordering), or None for no padding.
        labels: image label ground truth, zero-padded when given.
        prev_mask: previous segmentation mask, zero-padded when given.
        point_coords: point click coordinates, shifted by the leading pad of each spatial axis when given.

    Returns:
        The tuple ``(labels, prev_mask, point_coords)``; entries that were None stay None. With no
        ``pad_size`` the inputs are returned unchanged.
    """
    if pad_size is not None:

        def _zero_pad(tensor):
            return None if tensor is None else F.pad(tensor, pad=pad_size, mode="constant", value=0)

        labels = _zero_pad(labels)
        prev_mask = _zero_pad(prev_mask)

        if point_coords is not None:
            # F.pad lists the last axis first as (before, after) pairs, so [-2], [-4], [-6] are the
            # "before" pads of the three trailing axes in forward axis order.
            # Assumes pad_size covers exactly those three spatial axes — TODO confirm.
            leading_pads = [pad_size[-2], pad_size[-4], pad_size[-6]]
            offset = torch.tensor(leading_pads, device=point_coords.device)
            point_coords = point_coords + offset

    return labels, prev_mask, point_coords
|
109
|
+
|
81
110
|
def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
|
82
111
|
"""Get number of foreground classes based on class and point prompt."""
|
83
112
|
if class_vector is None:
|
@@ -307,16 +336,17 @@ class VISTA3D(nn.Module):
|
|
307
336
|
def forward(
|
308
337
|
self,
|
309
338
|
input_images: torch.Tensor,
|
339
|
+
patch_coords: list[Sequence[slice]] | None = None,
|
310
340
|
point_coords: torch.Tensor | None = None,
|
311
341
|
point_labels: torch.Tensor | None = None,
|
312
342
|
class_vector: torch.Tensor | None = None,
|
313
343
|
prompt_class: torch.Tensor | None = None,
|
314
|
-
patch_coords: Sequence[slice] | None = None,
|
315
344
|
labels: torch.Tensor | None = None,
|
316
345
|
label_set: Sequence[int] | None = None,
|
317
346
|
prev_mask: torch.Tensor | None = None,
|
318
347
|
radius: int | None = None,
|
319
348
|
val_point_sampler: Callable | None = None,
|
349
|
+
transpose: bool = False,
|
320
350
|
**kwargs,
|
321
351
|
):
|
322
352
|
"""
|
@@ -329,13 +359,17 @@ class VISTA3D(nn.Module):
|
|
329
359
|
point_coords: [B, N, 3]
|
330
360
|
point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
|
331
361
|
2/3 means negative/postive ponits for special supported class like tumor.
|
332
|
-
class_vector: [B, 1], the global class index
|
362
|
+
class_vector: [B, 1], the global class index.
|
333
363
|
prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
|
334
364
|
the points are for zero-shot or supported class. When class_vector and point_coords are both
|
335
365
|
provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
|
336
366
|
will be considered novel class.
|
337
|
-
patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window
|
338
|
-
This value is passed from sliding_window_inferer.
|
367
|
+
patch_coords: a list of sequence of the python slice objects representing the patch coordinates during sliding window
|
368
|
+
inference. This value is passed from sliding_window_inferer.
|
369
|
+
This is an indicator for training phase or validation phase.
|
370
|
+
Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will inlcude
|
371
|
+
coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
|
372
|
+
functions using patch_coords will by default use patch_coords[0].
|
339
373
|
labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
|
340
374
|
label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
|
341
375
|
this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
|
@@ -346,8 +380,12 @@ class VISTA3D(nn.Module):
|
|
346
380
|
radius: single float value controling the gaussian blur when combining point and auto results.
|
347
381
|
The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
|
348
382
|
val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
|
349
|
-
|
383
|
+
transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from
|
384
|
+
sliding window inferer/point inferer.
|
350
385
|
"""
|
386
|
+
labels, prev_mask, point_coords = self.update_slidingwindow_padding(
|
387
|
+
kwargs.get("pad_size", None), labels, prev_mask, point_coords
|
388
|
+
)
|
351
389
|
image_size = input_images.shape[-3:]
|
352
390
|
device = input_images.device
|
353
391
|
if point_coords is None and class_vector is None:
|
@@ -361,14 +399,14 @@ class VISTA3D(nn.Module):
|
|
361
399
|
if val_point_sampler is None:
|
362
400
|
# TODO: think about how to refactor this part.
|
363
401
|
val_point_sampler = self.sample_points_patch_val
|
364
|
-
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
|
402
|
+
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
|
365
403
|
if prompt_class[0].item() == 0: # type: ignore
|
366
404
|
point_labels[0] = -1 # type: ignore
|
367
405
|
labels, prev_mask = None, None
|
368
406
|
elif point_coords is not None:
|
369
407
|
# If not performing patch-based point only validation, use user provided click points for inference.
|
370
408
|
# the point clicks is in original image space, convert it to current patch-coordinate space.
|
371
|
-
point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore
|
409
|
+
point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels) # type: ignore
|
372
410
|
|
373
411
|
if point_coords is not None and point_labels is not None:
|
374
412
|
# remove points that used for padding purposes (point_label = -1)
|
@@ -387,7 +425,10 @@ class VISTA3D(nn.Module):
|
|
387
425
|
point_coords, point_labels = None, None
|
388
426
|
|
389
427
|
if point_coords is None and class_vector is None:
|
390
|
-
|
428
|
+
logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
|
429
|
+
if transpose:
|
430
|
+
logits = logits.transpose(1, 0)
|
431
|
+
return logits
|
391
432
|
|
392
433
|
if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
|
393
434
|
out, out_auto = self.image_embeddings, None
|
@@ -418,15 +459,16 @@ class VISTA3D(nn.Module):
|
|
418
459
|
logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
|
419
460
|
if prev_mask is not None and patch_coords is not None:
|
420
461
|
logits = self.connected_components_combine(
|
421
|
-
prev_mask[patch_coords].transpose(1, 0).to(logits.device),
|
462
|
+
prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
|
422
463
|
logits[mapping_index],
|
423
464
|
point_coords, # type: ignore
|
424
465
|
point_labels, # type: ignore
|
425
466
|
mapping_index,
|
426
467
|
)
|
427
|
-
|
428
468
|
if kwargs.get("keep_cache", False) and class_vector is None:
|
429
469
|
self.image_embeddings = out.detach()
|
470
|
+
if transpose:
|
471
|
+
logits = logits.transpose(1, 0)
|
430
472
|
return logits
|
431
473
|
|
432
474
|
|