monai-weekly 1.4.dev2435__py3-none-any.whl → 1.4.dev2437__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +44 -2
- monai/_version.py +3 -3
- monai/apps/vista3d/inferer.py +1 -1
- monai/bundle/config_parser.py +5 -3
- monai/bundle/scripts.py +3 -2
- monai/bundle/utils.py +35 -1
- monai/handlers/__init__.py +1 -0
- monai/handlers/trt_handler.py +61 -0
- monai/metrics/generalized_dice.py +77 -48
- monai/networks/__init__.py +2 -0
- monai/networks/nets/swin_unetr.py +4 -4
- monai/networks/nets/vista3d.py +10 -6
- monai/networks/trt_compiler.py +569 -0
- monai/networks/utils.py +224 -40
- monai/transforms/__init__.py +12 -0
- monai/transforms/spatial/array.py +44 -0
- monai/transforms/spatial/dictionary.py +61 -0
- monai/transforms/spatial/functional.py +70 -1
- monai/transforms/utility/array.py +153 -4
- monai/transforms/utility/dictionary.py +101 -3
- monai/transforms/utils.py +31 -4
- monai/utils/__init__.py +1 -0
- monai/utils/type_conversion.py +8 -0
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2437.dist-info}/METADATA +3 -1
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2437.dist-info}/RECORD +28 -26
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2437.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2437.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2437.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
@@ -13,9 +13,51 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
import os
|
15
15
|
import sys
|
16
|
-
|
16
|
+
import logging
|
17
|
+
import warnings
|
17
18
|
from ._version import get_versions
|
18
19
|
|
20
|
+
|
21
|
+
old_showwarning = warnings.showwarning
|
22
|
+
|
23
|
+
|
24
|
+
def custom_warning_handler(message, category, filename, lineno, file=None, line=None):
    """
    Replacement for ``warnings.showwarning`` that drops warnings raised from selected files.

    A warning is suppressed when the path of the file that issued it contains one of the
    fragments in ``ignore_files``; every other warning is forwarded unchanged to the
    previously installed ``showwarning`` hook (``old_showwarning``).

    Args:
        message: the warning instance or message.
        category: the warning category class.
        filename: path of the file that issued the warning.
        lineno: line number at which the warning was issued.
        file: optional stream, forwarded as-is to the previous hook.
        line: optional source line, forwarded as-is to the previous hook.
    """
    ignore_files = ("ignite/handlers/checkpoint", "modelopt/torch/quantization/tensor_quant")
    # The fragments use "/" separators; normalise Windows-style paths so that they match too.
    normalized = filename.replace("\\", "/")
    if any(ignore in normalized for ignore in ignore_files):
        return
    old_showwarning(message, category, filename, lineno, file, line)
|
29
|
+
|
30
|
+
|
31
|
+
class DeprecatedTypesWarningFilter(logging.Filter):
    """
    Logging filter that rejects records whose message contains a known noisy fragment.

    A record is dropped when its formatted message contains any entry of
    ``_MESSAGE_BODIES_TO_IGNORE``; all other records pass through.
    """

    # Built once at class creation instead of on every call to ``filter``.
    _MESSAGE_BODIES_TO_IGNORE = (
        "np.bool8",
        "np.object0",
        "np.int0",
        "np.uint0",
        "np.void0",
        "np.str0",
        "np.bytes0",
        "@validator",
        "@root_validator",
        "class-based `config`",
        "pkg_resources",
        "Implicitly cleaning up",
    )

    def filter(self, record):
        """
        Decide whether ``record`` should be emitted.

        Args:
            record: the ``logging.LogRecord`` under consideration.

        Returns:
            ``False`` if the formatted message contains one of the ignored fragments, ``True`` otherwise.
        """
        # Format the message a single time rather than once per candidate fragment.
        text = record.getMessage()
        return not any(fragment in text for fragment in self._MESSAGE_BODIES_TO_IGNORE)
|
51
|
+
|
52
|
+
|
53
|
+
# workaround for https://github.com/Project-MONAI/MONAI/issues/8060
|
54
|
+
# TODO: remove this workaround after upstream fixed the warning
|
55
|
+
# Set the custom warning handler to filter warning
|
56
|
+
warnings.showwarning = custom_warning_handler
|
57
|
+
# Get the logger for warnings and add the filter to the logger
|
58
|
+
logging.getLogger("py.warnings").addFilter(DeprecatedTypesWarningFilter())
|
59
|
+
|
60
|
+
|
19
61
|
PY_REQUIRED_MAJOR = 3
|
20
62
|
PY_REQUIRED_MINOR = 9
|
21
63
|
|
@@ -93,4 +135,4 @@ except BaseException:
|
|
93
135
|
|
94
136
|
if MONAIEnvVars.debug():
|
95
137
|
raise
|
96
|
-
__commit_id__ = "
|
138
|
+
__commit_id__ = "64eee8cb9cfad9ef5bd3eaf597fef0fbe85144b4"
|
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2024-09-
|
11
|
+
"date": "2024-09-15T02:27:58+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.4.
|
14
|
+
"full-revisionid": "1d72a6bbc4db84d507147fb422f9f54a939640b5",
|
15
|
+
"version": "1.4.dev2437"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
monai/apps/vista3d/inferer.py
CHANGED
monai/bundle/config_parser.py
CHANGED
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any
|
|
20
20
|
|
21
21
|
from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
|
22
22
|
from monai.bundle.reference_resolver import ReferenceResolver
|
23
|
-
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
|
23
|
+
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv
|
24
24
|
from monai.config import PathLike
|
25
25
|
from monai.utils import ensure_tuple, look_up_option, optional_import
|
26
26
|
from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates
|
@@ -423,8 +423,10 @@ class ConfigParser:
|
|
423
423
|
if isinstance(files, str) and not Path(files).is_file() and "," in files:
|
424
424
|
files = files.split(",")
|
425
425
|
for i in ensure_tuple(files):
|
426
|
-
|
427
|
-
|
426
|
+
config_dict = cls.load_config_file(i, **kwargs)
|
427
|
+
for k, v in config_dict.items():
|
428
|
+
merge_kv(parser, k, v)
|
429
|
+
|
428
430
|
return parser.get() # type: ignore
|
429
431
|
|
430
432
|
@classmethod
|
monai/bundle/scripts.py
CHANGED
@@ -32,7 +32,7 @@ from monai._version import get_versions
|
|
32
32
|
from monai.apps.utils import _basename, download_url, extractall, get_logger
|
33
33
|
from monai.bundle.config_item import ConfigComponent
|
34
34
|
from monai.bundle.config_parser import ConfigParser
|
35
|
-
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
|
35
|
+
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
|
36
36
|
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
|
37
37
|
from monai.config import IgniteInfo, PathLike
|
38
38
|
from monai.data import load_net_with_metadata, save_net_with_metadata
|
@@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
|
|
105
105
|
if isinstance(v, dict) and isinstance(args_.get(k), dict):
|
106
106
|
args_[k] = update_kwargs(args_[k], ignore_none, **v)
|
107
107
|
else:
|
108
|
-
args_
|
108
|
+
merge_kv(args_, k, v)
|
109
109
|
return args_
|
110
110
|
|
111
111
|
|
@@ -255,6 +255,7 @@ def _download_from_ngc_private(
|
|
255
255
|
else:
|
256
256
|
raise ValueError("NGC API requires requests package. Please install it.")
|
257
257
|
|
258
|
+
os.makedirs(download_path, exist_ok=True)
|
258
259
|
zip_path = download_path / f"{filename}_v{version}.zip"
|
259
260
|
with open(zip_path, "wb") as f:
|
260
261
|
f.write(response.content)
|
monai/bundle/utils.py
CHANGED
@@ -13,6 +13,7 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
import json
|
15
15
|
import os
|
16
|
+
import warnings
|
16
17
|
import zipfile
|
17
18
|
from typing import Any
|
18
19
|
|
@@ -21,12 +22,21 @@ from monai.utils import optional_import
|
|
21
22
|
|
22
23
|
yaml, _ = optional_import("yaml")
|
23
24
|
|
24
|
-
__all__ = [
|
25
|
+
__all__ = [
|
26
|
+
"ID_REF_KEY",
|
27
|
+
"ID_SEP_KEY",
|
28
|
+
"EXPR_KEY",
|
29
|
+
"MACRO_KEY",
|
30
|
+
"MERGE_KEY",
|
31
|
+
"DEFAULT_MLFLOW_SETTINGS",
|
32
|
+
"DEFAULT_EXP_MGMT_SETTINGS",
|
33
|
+
]
|
25
34
|
|
26
35
|
ID_REF_KEY = "@" # start of a reference to a ConfigItem
|
27
36
|
ID_SEP_KEY = "::" # separator for the ID of a ConfigItem
|
28
37
|
EXPR_KEY = "$" # start of a ConfigExpression
|
29
38
|
MACRO_KEY = "%" # start of a macro of a config
|
39
|
+
MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs.
|
30
40
|
|
31
41
|
_conf_values = get_config_values()
|
32
42
|
|
@@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
|
|
233
243
|
parser.read_config(f=cdata)
|
234
244
|
|
235
245
|
return parser
|
246
|
+
|
247
|
+
|
248
|
+
def merge_kv(args: dict | Any, k: str, v: Any) -> None:
    """
    Update the `args` dict-like object with the key/value pair `k` and `v`.

    A key without the ``MERGE_KEY`` prefix simply overrides ``args[k]``. A key carrying the
    prefix is merged into the entry named by the remainder of the key: `dict` values are
    merged with ``update`` and `list` values are concatenated with ``extend``, both in place.
    If the target entry does not exist yet, a warning is issued and `v` is copied instead.

    Args:
        args: dict-like target; it must support ``in``, item read and item assignment.
        k: key to set, optionally prefixed with ``MERGE_KEY`` to request a merge.
        v: value to set or to merge into the existing entry.

    Raises:
        ValueError: when a merge is requested and the existing entry and `v` are not
            both `dict` or both `list`.
    """
    if not k.startswith(MERGE_KEY):
        args[k] = v
        return

    # Strip the merge prefix to obtain the key of the entry to merge into.
    target_key = k[len(MERGE_KEY) :]
    if target_key not in args:
        warnings.warn(f"Can't merge entry ['{k}'], '{target_key}' is not in target dict - copying instead.")
        args[target_key] = v
        return

    # Both values associated with a merge-prefixed key must be of `dict` or `list` type.
    if isinstance(v, dict) and isinstance(args[target_key], dict):
        args[target_key].update(v)
    elif isinstance(v, list) and isinstance(args[target_key], list):
        args[target_key].extend(v)
    else:
        raise ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")
|
monai/handlers/__init__.py
CHANGED
@@ -40,5 +40,6 @@ from .smartcache_handler import SmartCacheHandler
|
|
40
40
|
from .stats_handler import StatsHandler
|
41
41
|
from .surface_distance import SurfaceDistance
|
42
42
|
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
|
43
|
+
from .trt_handler import TrtHandler
|
43
44
|
from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
|
44
45
|
from .validation_handler import ValidationHandler
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
from typing import TYPE_CHECKING
|
15
|
+
|
16
|
+
from monai.config import IgniteInfo
|
17
|
+
from monai.networks import trt_compile
|
18
|
+
from monai.utils import min_version, optional_import
|
19
|
+
|
20
|
+
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from ignite.engine import Engine
|
23
|
+
else:
|
24
|
+
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
|
25
|
+
|
26
|
+
|
27
|
+
class TrtHandler:
    """
    TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.
    Usage example::
        handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
        handler.attach(engine)
        engine.run()
    """

    def __init__(self, model, base_path, args=None, submodule=None):
        """
        Args:
            model: the model forwarded to trt_compile() when the handler is called.
            base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
            args: passed to trt_compile(). See trt_compile() for details.
            submodule : Hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'
        """
        self.model = model
        self.base_path = base_path
        self.args = args
        self.submodule = submodule
        # Set by attach(); defined here so that __call__ never hits a missing attribute.
        self.logger = None

    def attach(self, engine: Engine) -> None:
        """
        Register this handler on the engine's ``Events.STARTED`` event and adopt the engine's logger.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        self.logger = engine.logger
        engine.add_event_handler(Events.STARTED, self)

    def __call__(self, engine: Engine) -> None:
        """
        Call trt_compile() on the stored model with the stored path, args and submodule.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        # Fall back to the engine's logger when the handler was registered without attach().
        logger = self.logger if self.logger is not None else engine.logger
        trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=logger)
|
@@ -14,34 +14,47 @@ from __future__ import annotations
|
|
14
14
|
import torch
|
15
15
|
|
16
16
|
from monai.metrics.utils import do_metric_reduction, ignore_background
|
17
|
-
from monai.utils import MetricReduction, Weight, look_up_option
|
17
|
+
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option
|
18
18
|
|
19
19
|
from .metric import CumulativeIterationMetric
|
20
20
|
|
21
21
|
|
22
22
|
class GeneralizedDiceScore(CumulativeIterationMetric):
|
23
|
-
"""
|
23
|
+
"""
|
24
|
+
Compute the Generalized Dice Score metric between tensors.
|
24
25
|
|
26
|
+
This metric is the complement of the Generalized Dice Loss defined in:
|
25
27
|
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
|
26
|
-
|
28
|
+
loss function for highly unbalanced segmentations. DLMIA 2017.
|
27
29
|
|
28
|
-
The inputs `y_pred` and `y` are expected to be one-hot, binarized
|
29
|
-
or batch-first tensors, i.e., CHW[D] or BCHW[D].
|
30
|
+
The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
|
30
31
|
|
31
32
|
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
|
32
33
|
|
33
34
|
Args:
|
34
|
-
include_background
|
35
|
+
include_background: Whether to include the background class (assumed to be in channel 0) in the
|
35
36
|
score computation. Defaults to True.
|
36
|
-
reduction
|
37
|
-
{``"none"``, ``"
|
38
|
-
|
37
|
+
reduction: Define mode of reduction to the metrics. Available reduction modes:
|
38
|
+
{``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
|
39
|
+
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
|
40
|
+
weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
|
39
41
|
ground truth volume into a weight factor. Defaults to ``"square"``.
|
40
42
|
|
41
43
|
Raises:
|
42
|
-
ValueError:
|
44
|
+
ValueError: When the `reduction` is not one of MetricReduction enum.
|
43
45
|
"""
|
44
46
|
|
47
|
+
@deprecated_arg_default(
|
48
|
+
"reduction",
|
49
|
+
old_default=MetricReduction.MEAN_BATCH,
|
50
|
+
new_default=MetricReduction.MEAN,
|
51
|
+
since="1.4.0",
|
52
|
+
replaced="1.5.0",
|
53
|
+
msg_suffix=(
|
54
|
+
"Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, "
|
55
|
+
"If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
|
56
|
+
),
|
57
|
+
)
|
45
58
|
def __init__(
|
46
59
|
self,
|
47
60
|
include_background: bool = True,
|
@@ -50,79 +63,90 @@ class GeneralizedDiceScore(CumulativeIterationMetric):
|
|
50
63
|
) -> None:
|
51
64
|
super().__init__()
|
52
65
|
self.include_background = include_background
|
53
|
-
|
54
|
-
"none",
|
55
|
-
"mean_batch",
|
56
|
-
"sum_batch",
|
57
|
-
MetricReduction.NONE,
|
58
|
-
MetricReduction.MEAN_BATCH,
|
59
|
-
MetricReduction.SUM_BATCH,
|
60
|
-
]
|
61
|
-
self.reduction = reduction
|
62
|
-
if self.reduction not in reduction_options:
|
63
|
-
raise ValueError(f"reduction must be one of {reduction_options}")
|
66
|
+
self.reduction = look_up_option(reduction, MetricReduction)
|
64
67
|
self.weight_type = look_up_option(weight_type, Weight)
|
68
|
+
self.sum_over_classes = self.reduction in {
|
69
|
+
MetricReduction.SUM,
|
70
|
+
MetricReduction.MEAN,
|
71
|
+
MetricReduction.MEAN_CHANNEL,
|
72
|
+
MetricReduction.SUM_CHANNEL,
|
73
|
+
}
|
65
74
|
|
66
75
|
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
|
67
|
-
"""
|
76
|
+
"""
|
77
|
+
Computes the Generalized Dice Score and returns a tensor with its per image values.
|
68
78
|
|
69
79
|
Args:
|
70
|
-
y_pred (torch.Tensor):
|
80
|
+
y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
|
71
81
|
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
|
72
|
-
y (torch.Tensor):
|
82
|
+
y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
torch.Tensor: Generalized Dice Score averaged across batch and class
|
73
86
|
|
74
87
|
Raises:
|
75
|
-
ValueError:
|
88
|
+
ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
|
76
89
|
"""
|
77
90
|
return compute_generalized_dice(
|
78
|
-
y_pred=y_pred,
|
91
|
+
y_pred=y_pred,
|
92
|
+
y=y,
|
93
|
+
include_background=self.include_background,
|
94
|
+
weight_type=self.weight_type,
|
95
|
+
sum_over_classes=self.sum_over_classes,
|
79
96
|
)
|
80
97
|
|
98
|
+
@deprecated_arg(
|
99
|
+
"reduction",
|
100
|
+
since="1.3.3",
|
101
|
+
removed="1.7.0",
|
102
|
+
msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute",
|
103
|
+
)
|
81
104
|
def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
|
82
105
|
"""
|
83
106
|
Execute reduction logic for the output of `compute_generalized_dice`.
|
84
107
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
108
|
+
Returns:
|
109
|
+
torch.Tensor: Aggregated metric value.
|
110
|
+
|
111
|
+
Raises:
|
112
|
+
ValueError: If the data to aggregate is not a PyTorch Tensor.
|
89
113
|
"""
|
90
114
|
data = self.get_buffer()
|
91
115
|
if not isinstance(data, torch.Tensor):
|
92
116
|
raise ValueError("The data to aggregate must be a PyTorch Tensor.")
|
93
117
|
|
94
|
-
# Validate reduction argument if specified
|
95
|
-
if reduction is not None:
|
96
|
-
reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
|
97
|
-
if reduction not in reduction_options:
|
98
|
-
raise ValueError(f"reduction must be one of {reduction_options}")
|
99
|
-
|
100
118
|
# Do metric reduction and return
|
101
|
-
f, _ = do_metric_reduction(data,
|
119
|
+
f, _ = do_metric_reduction(data, self.reduction)
|
102
120
|
|
103
121
|
return f
|
104
122
|
|
105
123
|
|
106
124
|
def compute_generalized_dice(
|
107
|
-
y_pred: torch.Tensor,
|
125
|
+
y_pred: torch.Tensor,
|
126
|
+
y: torch.Tensor,
|
127
|
+
include_background: bool = True,
|
128
|
+
weight_type: Weight | str = Weight.SQUARE,
|
129
|
+
sum_over_classes: bool = False,
|
108
130
|
) -> torch.Tensor:
|
109
|
-
"""
|
131
|
+
"""
|
132
|
+
Computes the Generalized Dice Score and returns a tensor with its per image values.
|
110
133
|
|
111
134
|
Args:
|
112
|
-
y_pred (torch.Tensor):
|
135
|
+
y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
|
113
136
|
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
|
114
137
|
remaining are the spatial dimensions.
|
115
|
-
y (torch.Tensor):
|
116
|
-
include_background
|
138
|
+
y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
|
139
|
+
include_background: Whether to include score computation on the first channel of the
|
117
140
|
predicted output. Defaults to True.
|
118
141
|
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
|
119
142
|
transform ground truth volume into a weight factor. Defaults to ``"square"``.
|
143
|
+
sum_over_classes (bool): Whether to sum the numerator and denominator across all classes before the final computation.
|
120
144
|
|
121
145
|
Returns:
|
122
|
-
torch.Tensor:
|
146
|
+
torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
|
123
147
|
|
124
148
|
Raises:
|
125
|
-
ValueError:
|
149
|
+
ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
|
126
150
|
or `y_pred` and `y` don't have the same shape.
|
127
151
|
"""
|
128
152
|
# Ensure tensors have at least 3 dimensions and have the same shape
|
@@ -158,16 +182,21 @@ def compute_generalized_dice(
|
|
158
182
|
b[infs] = 0
|
159
183
|
b[infs] = torch.max(b)
|
160
184
|
|
161
|
-
# Compute the weighted numerator and denominator, summing along the class axis
|
162
|
-
|
163
|
-
|
185
|
+
# Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
|
186
|
+
if sum_over_classes:
|
187
|
+
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
|
188
|
+
denom = (denominator * w).sum(dim=1, keepdim=True)
|
189
|
+
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
|
190
|
+
else:
|
191
|
+
numer = 2.0 * (intersection * w)
|
192
|
+
denom = denominator * w
|
193
|
+
y_pred_o = y_pred_o
|
164
194
|
|
165
195
|
# Compute the score
|
166
196
|
generalized_dice_score = numer / denom
|
167
197
|
|
168
198
|
# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
|
169
199
|
# Where denom == 0 but the prediction volume is not 0, score is 0
|
170
|
-
y_pred_o = y_pred_o.sum(dim=-1)
|
171
200
|
denom_zeros = denom == 0
|
172
201
|
generalized_dice_score[denom_zeros] = torch.where(
|
173
202
|
(y_pred_o == 0)[denom_zeros],
|
monai/networks/__init__.py
CHANGED
@@ -320,7 +320,7 @@ class SwinUNETR(nn.Module):
|
|
320
320
|
)
|
321
321
|
|
322
322
|
def forward(self, x_in):
|
323
|
-
if not torch.jit.is_scripting():
|
323
|
+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
324
324
|
self._check_input_size(x_in.shape[2:])
|
325
325
|
hidden_states_out = self.swinViT(x_in, self.normalize)
|
326
326
|
enc0 = self.encoder1(x_in)
|
@@ -1046,14 +1046,14 @@ class SwinTransformer(nn.Module):
|
|
1046
1046
|
|
1047
1047
|
def proj_out(self, x, normalize=False):
|
1048
1048
|
if normalize:
|
1049
|
-
x_shape = x.
|
1049
|
+
x_shape = x.shape
|
1050
|
+
# Force trace() to generate a constant by casting to int
|
1051
|
+
ch = int(x_shape[1])
|
1050
1052
|
if len(x_shape) == 5:
|
1051
|
-
n, ch, d, h, w = x_shape
|
1052
1053
|
x = rearrange(x, "n c d h w -> n d h w c")
|
1053
1054
|
x = F.layer_norm(x, [ch])
|
1054
1055
|
x = rearrange(x, "n d h w c -> n c d h w")
|
1055
1056
|
elif len(x_shape) == 4:
|
1056
|
-
n, ch, h, w = x_shape
|
1057
1057
|
x = rearrange(x, "n c h w -> n h w c")
|
1058
1058
|
x = F.layer_norm(x, [ch])
|
1059
1059
|
x = rearrange(x, "n h w c -> n c h w")
|
monai/networks/nets/vista3d.py
CHANGED
@@ -336,7 +336,7 @@ class VISTA3D(nn.Module):
|
|
336
336
|
def forward(
|
337
337
|
self,
|
338
338
|
input_images: torch.Tensor,
|
339
|
-
patch_coords: Sequence[slice] | None = None,
|
339
|
+
patch_coords: list[Sequence[slice]] | None = None,
|
340
340
|
point_coords: torch.Tensor | None = None,
|
341
341
|
point_labels: torch.Tensor | None = None,
|
342
342
|
class_vector: torch.Tensor | None = None,
|
@@ -364,8 +364,12 @@ class VISTA3D(nn.Module):
|
|
364
364
|
the points are for zero-shot or supported class. When class_vector and point_coords are both
|
365
365
|
provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
|
366
366
|
will be considered novel class.
|
367
|
-
patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window
|
368
|
-
This value is passed from sliding_window_inferer.
|
367
|
+
patch_coords: a list of sequences of Python slice objects representing the patch coordinates during sliding window
|
368
|
+
inference. This value is passed from sliding_window_inferer.
|
369
|
+
This is an indicator for training phase or validation phase.
|
370
|
+
Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will include
|
371
|
+
coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
|
372
|
+
functions using patch_coords will by default use patch_coords[0].
|
369
373
|
labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
|
370
374
|
label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
|
371
375
|
this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
|
@@ -395,14 +399,14 @@ class VISTA3D(nn.Module):
|
|
395
399
|
if val_point_sampler is None:
|
396
400
|
# TODO: think about how to refactor this part.
|
397
401
|
val_point_sampler = self.sample_points_patch_val
|
398
|
-
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
|
402
|
+
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
|
399
403
|
if prompt_class[0].item() == 0: # type: ignore
|
400
404
|
point_labels[0] = -1 # type: ignore
|
401
405
|
labels, prev_mask = None, None
|
402
406
|
elif point_coords is not None:
|
403
407
|
# If not performing patch-based point only validation, use user provided click points for inference.
|
404
408
|
# the point clicks is in original image space, convert it to current patch-coordinate space.
|
405
|
-
point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore
|
409
|
+
point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels) # type: ignore
|
406
410
|
|
407
411
|
if point_coords is not None and point_labels is not None:
|
408
412
|
# remove points that used for padding purposes (point_label = -1)
|
@@ -455,7 +459,7 @@ class VISTA3D(nn.Module):
|
|
455
459
|
logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
|
456
460
|
if prev_mask is not None and patch_coords is not None:
|
457
461
|
logits = self.connected_components_combine(
|
458
|
-
prev_mask[patch_coords].transpose(1, 0).to(logits.device),
|
462
|
+
prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
|
459
463
|
logits[mapping_index],
|
460
464
|
point_coords, # type: ignore
|
461
465
|
point_labels, # type: ignore
|