monai-weekly 1.5.dev2507__py3-none-any.whl → 1.5.dev2508__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/handlers/__init__.py +1 -0
- monai/handlers/average_precision.py +53 -0
- monai/inferers/inferer.py +10 -7
- monai/metrics/__init__.py +1 -0
- monai/metrics/average_precision.py +187 -0
- monai/transforms/utility/array.py +2 -12
- monai/transforms/utils_pytorch_numpy_unification.py +2 -4
- monai/utils/enums.py +3 -2
- monai/utils/module.py +6 -6
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA +20 -16
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/RECORD +23 -19
- tests/bundle/test_bundle_trt_export.py +2 -2
- tests/handlers/test_handler_average_precision.py +79 -0
- tests/inferers/test_controlnet_inferers.py +80 -2
- tests/inferers/test_latent_diffusion_inferer.py +61 -1
- tests/metrics/test_compute_average_precision.py +162 -0
- tests/networks/test_convert_to_onnx.py +1 -1
- tests/transforms/test_gibbs_noise.py +3 -5
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL +0 -0
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
-    "date": "2025-02-
+    "date": "2025-02-23T02:28:09+0000",
     "dirty": false,
    "error": null,
-    "full-revisionid": "
-    "version": "1.5.
+    "full-revisionid": "e55b5cbfbbba1800a968a9c06b2deaaa5c9bec54",
+    "version": "1.5.dev2508"
 }
 ''' # END VERSION_JSON
 
monai/handlers/__init__.py
CHANGED
monai/handlers/average_precision.py
ADDED
@@ -0,0 +1,53 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+from monai.handlers.ignite_metric import IgniteMetricHandler
+from monai.metrics import AveragePrecisionMetric
+from monai.utils import Average
+
+
+class AveragePrecision(IgniteMetricHandler):
+    """
+    Computes Average Precision (AP),
+    accumulating predictions and the ground-truth during an epoch and applying `compute_average_precision`.
+
+    Args:
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification. Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output`, then
+            construct the `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
+            lists of `channel-first` Tensors. The form `(y_pred, y)` is required by `update()`.
+            `engine.state` and `output_transform` inherit from the ignite concept:
+            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
+
+    Note:
+        Average Precision expects y to be comprised of 0's and 1's.
+        y_pred must either be probability estimates or confidence values.
+
+    """
+
+    def __init__(self, average: Average | str = Average.MACRO, output_transform: Callable = lambda x: x) -> None:
+        metric_fn = AveragePrecisionMetric(average=Average(average))
+        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=False)
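For orientation, here is how the new handler can be exercised directly, mirroring the unit test added later in this diff (the logits and labels are illustrative):

    import torch
    from monai.handlers import AveragePrecision
    from monai.transforms import Activations, AsDiscrete

    ap_metric = AveragePrecision()
    act = Activations(softmax=True)        # logits -> probabilities
    to_onehot = AsDiscrete(to_onehot=2)    # integer labels -> one-hot, since AP expects 0/1 targets

    y_pred = [act(p) for p in (torch.tensor([0.1, 0.9]), torch.tensor([0.3, 1.4]))]
    y = [to_onehot(t) for t in (torch.tensor([0.0]), torch.tensor([1.0]))]
    ap_metric.update([y_pred, y])          # called by ignite once per iteration
    print(ap_metric.compute())             # aggregated AP over everything seen so far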
monai/inferers/inferer.py
CHANGED
@@ -1202,15 +1202,16 @@ class LatentDiffusionInferer(DiffusionInferer):
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
             decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
         image = decode(latent / self.scale_factor)
-
         if save_intermediates:
             intermediates = []
             for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
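In both inferers the per-step latents are now resized only when they are actually collected. A minimal sketch of what the added guard changes (names are illustrative, not the inferer's API):

    def resize_latents(latent, latent_intermediates, resizer, save_intermediates):
        # the final latent is always produced, so it can always be resized
        latent = resizer(latent)
        # intermediates exist only when save_intermediates=True; resizing them
        # unconditionally would touch a value that was never populated
        if save_intermediates:
            latent_intermediates = [resizer(step) for step in latent_intermediates]
        return latent, latent_intermediates

    latent, steps = resize_latents([1.0], None, lambda x: x, save_intermediates=False)  # no longer problematic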
monai/metrics/__init__.py
CHANGED
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
+from .average_precision import AveragePrecisionMetric, compute_average_precision
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore
monai/metrics/average_precision.py
ADDED
@@ -0,0 +1,187 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
+
+if TYPE_CHECKING:
+    import numpy.typing as npt
+
+import torch
+
+from monai.utils import Average, look_up_option
+
+from .metric import CumulativeIterationMetric
+
+
+class AveragePrecisionMetric(CumulativeIterationMetric):
+    """
+    Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
+    imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.
+    It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
+    threshold, with the increase in recall from the previous threshold used as the weight:
+
+    .. math::
+        :label: ap
+
+        \\text{AP} = \\sum_n (R_n - R_{n-1}) P_n
+
+    where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.
+
+    Referring to: `sklearn.metrics.average_precision_score
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
+
+    The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.
+
+    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
+
+    Args:
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification.
+            Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+    """
+
+    def __init__(self, average: Average | str = Average.MACRO) -> None:
+        super().__init__()
+        self.average = average
+
+    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:  # type: ignore[override]
+        return y_pred, y
+
+    def aggregate(self, average: Average | str | None = None) -> np.ndarray | float | npt.ArrayLike:
+        """
+        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration;
+        this function reads the buffers and computes the Average Precision.
+
+        Args:
+            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+                Type of averaging performed if not binary classification. Defaults to `self.average`.
+
+        """
+        y_pred, y = self.get_buffer()
+        # compute final value and do metric reduction
+        if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
+            raise ValueError("y_pred and y must be PyTorch Tensor.")
+
+        return compute_average_precision(y_pred=y_pred, y=y, average=average or self.average)
+
+
+def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
+    if not (y.ndimension() == y_pred.ndimension() == 1 and len(y) == len(y_pred)):
+        raise AssertionError("y and y_pred must be 1 dimension data with same length.")
+    y_unique = y.unique()
+    if len(y_unique) == 1:
+        warnings.warn(f"y values can not be all {y_unique.item()}, skip AP computation and return `Nan`.")
+        return float("nan")
+    if not y_unique.equal(torch.tensor([0, 1], dtype=y.dtype, device=y.device)):
+        warnings.warn(f"y values must be 0 or 1, but in {y_unique.tolist()}, skip AP computation and return `Nan`.")
+        return float("nan")
+
+    n = len(y)
+    indices = y_pred.argsort(descending=True)
+    y = y[indices].cpu().numpy()  # type: ignore[assignment]
+    y_pred = y_pred[indices].cpu().numpy()  # type: ignore[assignment]
+    npos = ap = tmp_pos = 0.0
+
+    for i in range(n):
+        y_i = cast(float, y[i])
+        if i + 1 < n and y_pred[i] == y_pred[i + 1]:
+            tmp_pos += y_i
+        else:
+            tmp_pos += y_i
+            npos += tmp_pos
+            ap += tmp_pos * npos / (i + 1)
+            tmp_pos = 0
+
+    return ap / npos
+
+
+def compute_average_precision(
+    y_pred: torch.Tensor, y: torch.Tensor, average: Average | str = Average.MACRO
+) -> np.ndarray | float | npt.ArrayLike:
+    """Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
+    imbalanced. It summarizes a Precision-Recall curve according to equation :eq:`ap`.
+    Referring to: `sklearn.metrics.average_precision_score
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.
+
+    Args:
+        y_pred: input data to compute, typical classification model output.
+            the first dim must be batch, if multi-classes, it must be in One-Hot format.
+            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
+        y: ground truth to compute AP metric, the first dim must be batch.
+            if multi-classes, it must be in One-Hot format.
+            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
+        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
+            Type of averaging performed if not binary classification.
+            Defaults to ``"macro"``.
+
+            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
+                This does not take label imbalance into account.
+            - ``"weighted"``: calculate metrics for each label, and find their average,
+                weighted by support (the number of true instances for each label).
+            - ``"micro"``: calculate metrics globally by considering each element of the label
+                indicator matrix as a label.
+            - ``"none"``: the scores for each class are returned.
+
+    Raises:
+        ValueError: When ``y_pred`` dimension is not one of [1, 2].
+        ValueError: When ``y`` dimension is not one of [1, 2].
+        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].
+
+    Note:
+        Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either probability estimates or confidence values.
+
+    """
+    y_pred_ndim = y_pred.ndimension()
+    y_ndim = y.ndimension()
+    if y_pred_ndim not in (1, 2):
+        raise ValueError(
+            f"Predictions should be of shape (batch_size, num_classes) or (batch_size, ), got {y_pred.shape}."
+        )
+    if y_ndim not in (1, 2):
+        raise ValueError(f"Targets should be of shape (batch_size, num_classes) or (batch_size, ), got {y.shape}.")
+    if y_pred_ndim == 2 and y_pred.shape[1] == 1:
+        y_pred = y_pred.squeeze(dim=-1)
+        y_pred_ndim = 1
+    if y_ndim == 2 and y.shape[1] == 1:
+        y = y.squeeze(dim=-1)
+
+    if y_pred_ndim == 1:
+        return _calculate(y_pred, y)
+
+    if y.shape != y_pred.shape:
+        raise ValueError(f"data shapes of y_pred and y do not match, got {y_pred.shape} and {y.shape}.")
+
+    average = look_up_option(average, Average)
+    if average == Average.MICRO:
+        return _calculate(y_pred.flatten(), y.flatten())
+    y, y_pred = y.transpose(0, 1), y_pred.transpose(0, 1)
+    ap_values = [_calculate(y_pred_, y_) for y_pred_, y_ in zip(y_pred, y)]
+    if average == Average.NONE:
+        return ap_values
+    if average == Average.MACRO:
+        return np.mean(ap_values)
+    if average == Average.WEIGHTED:
+        weights = [sum(y_) for y_ in y]
+        return np.average(ap_values, weights=weights)  # type: ignore[no-any-return]
+    raise ValueError(f'Unsupported average: {average}, available options are ["macro", "weighted", "micro", "none"].')
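For reference, a minimal usage sketch of the new API (the scores and labels below are made up; the calls follow the signatures added above and the unit tests later in this diff):

    import torch
    from monai.metrics import AveragePrecisionMetric, compute_average_precision

    # functional form: 1-D confidence scores with binary 0/1 targets
    scores = torch.tensor([0.1, 0.9, 0.35, 0.8])
    labels = torch.tensor([0, 1, 0, 1])
    print(compute_average_precision(y_pred=scores, y=labels))  # 1.0: every positive outranks every negative

    # cumulative form: accumulate batch-first tensors per iteration, then aggregate
    metric = AveragePrecisionMetric(average="macro")
    metric(y_pred=torch.tensor([[0.1, 0.9], [0.8, 0.2]]), y=torch.tensor([[0, 1], [1, 0]]))
    print(metric.aggregate())
    metric.reset()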
monai/transforms/utility/array.py
CHANGED
@@ -66,7 +66,6 @@ from monai.utils import (
     optional_import,
 )
 from monai.utils.enums import TransformBackends
-from monai.utils.misc import is_module_ver_at_least
 from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
 
 PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
@@ -939,19 +938,10 @@ class LabelToMask(Transform):
             data = img[[*select_labels]]
         else:
             where: Callable = np.where if isinstance(img, np.ndarray) else torch.where  # type: ignore
-            if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
-                data = where(in1d(img, select_labels), True, False).reshape(img.shape)
-            # pre pytorch 1.8.0, need to use 1/0 instead of True/False
-            else:
-                data = where(
-                    in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
-                ).reshape(img.shape)
+            data = where(in1d(img, select_labels), True, False).reshape(img.shape)
 
         if merge_channels or self.merge_channels:
-            if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
-                return data.any(0)[None]
-            # pre pytorch 1.8.0 compatibility
-            return data.to(torch.uint8).any(0)[None].to(bool)  # type: ignore
+            return data.any(0)[None]
 
         return data
 
monai/transforms/utils_pytorch_numpy_unification.py
CHANGED
@@ -18,7 +18,6 @@ import numpy as np
 import torch
 
 from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
-from monai.utils.misc import is_module_ver_at_least
 from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
 
 __all__ = [
@@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
     Element-wise floor division between two arrays/tensors.
     """
     if isinstance(a, torch.Tensor):
-        if is_module_ver_at_least(torch, (1, 8, 0)):
-            return torch.div(a, b, rounding_mode="floor")
         return torch.floor_divide(a, b)
-    return np.floor_divide(a, b)
+    else:
+        return np.floor_divide(a, b)
 
 
 def unravel_index(idx, shape) -> NdarrayOrTensor:
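As context for the simplification: on current PyTorch releases (1.13 and later), torch.floor_divide performs true floor division, matching NumPy, so the rounding_mode workaround for older versions is no longer needed. A quick illustrative check:

    import numpy as np
    import torch

    # both round toward negative infinity on current PyTorch
    print(torch.floor_divide(torch.tensor([-7, 7]), 2))  # tensor([-4,  3])
    print(np.floor_divide(np.array([-7, 7]), 2))         # [-4  3]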
monai/utils/enums.py
CHANGED
@@ -213,7 +213,8 @@ class GridSamplePadMode(StrEnum):
 
 class Average(StrEnum):
     """
-    See also: :py:class:`monai.metrics.rocauc.compute_roc_auc`
+    See also: :py:class:`monai.metrics.rocauc.compute_roc_auc` or
+    :py:class:`monai.metrics.average_precision.compute_average_precision`
     """
 
     MACRO = "macro"
@@ -335,7 +336,7 @@ class CommonKeys(StrEnum):
     `LABEL` is the training or evaluation label of segmentation or classification task.
     `PRED` is the prediction data of model output.
     `LOSS` is the loss value of current iteration.
-    `
+    `METADATA` is some useful information during training or evaluation, like loss value, etc.
 
     """
 
monai/utils/module.py
CHANGED
@@ -540,11 +540,11 @@ def version_leq(lhs: str, rhs: str) -> bool:
     """
 
     lhs, rhs = str(lhs), str(rhs)
-    pkging, has_ver = optional_import("packaging.
+    pkging, has_ver = optional_import("packaging.version")
     if has_ver:
         try:
-            return cast(bool, pkging.
-        except pkging.
+            return cast(bool, pkging.Version(lhs) <= pkging.Version(rhs))
+        except pkging.InvalidVersion:
             return True
 
     lhs_, rhs_ = parse_version_strs(lhs, rhs)
@@ -567,12 +567,12 @@ def version_geq(lhs: str, rhs: str) -> bool:
 
     """
     lhs, rhs = str(lhs), str(rhs)
-    pkging, has_ver = optional_import("packaging.
+    pkging, has_ver = optional_import("packaging.version")
 
     if has_ver:
         try:
-            return cast(bool, pkging.
-        except pkging.
+            return cast(bool, pkging.Version(lhs) >= pkging.Version(rhs))
+        except pkging.InvalidVersion:
             return True
 
     lhs_, rhs_ = parse_version_strs(lhs, rhs)
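The repaired helpers lean on packaging.version, which compares PEP 440 version strings semantically rather than lexically; for example (illustrative):

    from packaging.version import InvalidVersion, Version

    assert Version("1.5.dev2508") > Version("1.5.dev2507")
    assert Version("1.10") > Version("1.9")      # plain string comparison would get this wrong
    try:
        Version("first-release")
    except InvalidVersion:
        pass  # version_leq/version_geq return True for unparsable versions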
{monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: monai-weekly
-Version: 1.5.
+Version: 1.5.dev2508
 Summary: AI Toolkit for Healthcare Imaging
 Home-page: https://monai.io/
 Author: MONAI Consortium
@@ -176,12 +176,13 @@ Requires-Dist: pyamg>=5.0.0; extra == "pyamg"
 
 MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
 Its ambitions are as follows:
+
 - Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
 - Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
 - Providing researchers with the optimized and standardized way to create and evaluate deep learning models.
 
-
 ## Features
+
 > _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
 
 - flexible pre-processing for multi-dimensional medical imaging data;
@@ -190,7 +191,6 @@ Its ambitions are as follows:
 - customizable design for varying user expertise;
 - multi-GPU multi-node data parallelism support.
 
-
 ## Installation
 
 To install [the current release](https://pypi.org/project/monai/), you can simply run:
@@ -211,30 +211,34 @@ Technical documentation is available at [docs.monai.io](https://docs.monai.io).
 
 ## Citation
 
-If you have used MONAI in your research, please cite us! The citation can be exported from: https://arxiv.org/abs/2211.02701
+If you have used MONAI in your research, please cite us! The citation can be exported from: <https://arxiv.org/abs/2211.02701>.
 
 ## Model Zoo
+
 [The MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo) is a place for researchers and data scientists to share the latest and great models from the community.
 Utilizing [the MONAI Bundle format](https://docs.monai.io/en/latest/bundle_intro.html) makes it easy to [get started](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo) building workflows with MONAI.
 
 ## Contributing
+
 For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md).
 
 ## Community
+
 Join the conversation on Twitter/X [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9).
 
 Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions).
 
 ## Links
-
--
-- API documentation (
--
--
--
--
--
--
--
--
--
+
+- Website: <https://monai.io/>
+- API documentation (milestone): <https://docs.monai.io/>
+- API documentation (latest dev): <https://docs.monai.io/en/latest/>
+- Code: <https://github.com/Project-MONAI/MONAI>
+- Project tracker: <https://github.com/Project-MONAI/MONAI/projects>
+- Issue tracker: <https://github.com/Project-MONAI/MONAI/issues>
+- Wiki: <https://github.com/Project-MONAI/MONAI/wiki>
+- Test status: <https://github.com/Project-MONAI/MONAI/actions>
+- PyPI package: <https://pypi.org/project/monai/>
+- conda-forge: <https://anaconda.org/conda-forge/monai>
+- Weekly previews: <https://pypi.org/project/monai-weekly/>
+- Docker Hub: <https://hub.docker.com/r/projectmonai/monai>
{monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
-monai/__init__.py,sha256=
-monai/_version.py,sha256=
+monai/__init__.py,sha256=jHqt9Fx6mJlpL9TD8eihfJTg6IGs40j8bCpjE3PFrVI,4095
+monai/_version.py,sha256=sQZ38u2mKWN9p59gP2DeDhflJxmQX4ckQZtIE_MCnbg,503
 monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
 monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -160,7 +160,8 @@ monai/fl/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,57
 monai/fl/utils/constants.py,sha256=OjMAE17niYqQh7nz45SC6CXvkMa4-XZsIuoHUHqP7W0,1784
 monai/fl/utils/exchange_object.py,sha256=q41trOwBdog_g3k_Eh2EFnLufHJ1mj7nGyQ-ShuW5Mo,3527
 monai/fl/utils/filters.py,sha256=InXplYes52JJqtsNbePAPPAYS8am_uRO7UkBHyYyJCo,1633
-monai/handlers/__init__.py,sha256=
+monai/handlers/__init__.py,sha256=m6SDdtXAZ4ONLCCYrSgONuPaJOz7lewOAzOvZ3J9r14,2442
+monai/handlers/average_precision.py,sha256=FkIUP2mKqGvybnc_HxuuOdqPeq06wnZP_vwb8K-IhUg,2753
 monai/handlers/checkpoint_loader.py,sha256=Y0qNBq5b-GJ-XOJNjuslegCpIGPZYOdNs3PxzNYCCm8,7432
 monai/handlers/checkpoint_saver.py,sha256=z_w5HtNSeRM3QwHQIgQKqVodSYNy8dhL8KTBUzHuF0g,16047
 monai/handlers/classification_saver.py,sha256=CNzdU9GrKj8KEC42jaBy2rEgpd3mqgz-YZg4dr61Jyg,7605
@@ -194,7 +195,7 @@ monai/handlers/trt_handler.py,sha256=uWFdgC8QKRkcNwWfKIbQMdK6-MX_1ON0mKabeIn1ltI
 monai/handlers/utils.py,sha256=Ib1u-PLrtIkiLqTfREnrCWpN4af1btdNzkyMZuuuYyU,10239
 monai/handlers/validation_handler.py,sha256=NZO21c6zzXbmAgJZHkkdoZQSQIHwuxh94QD3PLUldGU,3674
 monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,1107
-monai/inferers/inferer.py,sha256=
+monai/inferers/inferer.py,sha256=UNZpsb97qpl9c7ylNV32_jk52nsX77BqYySOl0XxDQw,92802
 monai/inferers/merger.py,sha256=dZm-FVyXPlFb59q4DG52mbtPm8Iy4cNFWv3un0Z8k0M,16262
 monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
 monai/inferers/utils.py,sha256=dvZBCAjaPa8xXcJuXRzNQ-fBzteauzkKbxE5YZdGBGY,20374
@@ -220,8 +221,9 @@ monai/losses/sure_loss.py,sha256=PDDNNeZm8SLPRCDUPbc8o4--ribHnY4nbo8y55nRo0w,817
 monai/losses/tversky.py,sha256=uLuqCvsac8OabTJzKQEzAfAvlwrflYCh0s76rgbcVJ0,6955
 monai/losses/unified_focal_loss.py,sha256=rCj8IpueYH_UMrOUXU0tjbXIN4Uix3bGnRZQtRvl7Sg,10224
 monai/losses/utils.py,sha256=wrpKcEO0XhbFOHz_jJRqeAeIgpMiMxmepnRf31_DNRU,2786
-monai/metrics/__init__.py,sha256=
+monai/metrics/__init__.py,sha256=rIRTn5dsXPzGoRv7tZ2ipZ7IiHlNJ4TrZOG_aDDhw28,2255
 monai/metrics/active_learning_metrics.py,sha256=uKID2O4mnY-9P2ZzyT4sqJd2NfgzjSpNKpAwulWCozU,8211
+monai/metrics/average_precision.py,sha256=rQYfPAmE78np8E4UoDPk-DSVRtEVC2hAcj5w9Q6ZIqk,8454
 monai/metrics/confusion_matrix.py,sha256=Spb20jYPnbgGZfPKDQI36ePznPf1xujxhboNnW8HxdQ,15064
 monai/metrics/cumulative_average.py,sha256=8GGjHmiBboBikprg1380SsNn7RgzFIrHGWBYDBv6ebE,5636
 monai/metrics/f_beta_score.py,sha256=urI0J_tvl0qQ5-l2fgWV_jChbgpzLmgpRq125B3yxpw,3984
@@ -360,7 +362,7 @@ monai/transforms/transform.py,sha256=0eC_Gw7T2jBb589-3EHLh-8gJD687k2OVmrnMxaKs3o
 monai/transforms/utils.py,sha256=t4TMksfSzozyNqP-HJK-ZydvmImLFzxhks0yJnZTOYM,106430
 monai/transforms/utils_create_transform_ims.py,sha256=QEJVHsCZX7ZxsBArk6NjgCzSZuuokf8l1uFqiUZBBys,31155
 monai/transforms/utils_morphological_ops.py,sha256=tt0lRLLxmlnn9roUuPEBtqah6t7BH8ittxyDFuskkUI,6767
-monai/transforms/utils_pytorch_numpy_unification.py,sha256=
+monai/transforms/utils_pytorch_numpy_unification.py,sha256=pM6-x-TAGVcQohSYirfTqiy2SQnPixcKKHTmTqtBbg0,18706
 monai/transforms/croppad/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 monai/transforms/croppad/array.py,sha256=WeSAs4JNtNafFaIMLPi3-9NuuyCiTm19cq2oEOonKWQ,74632
 monai/transforms/croppad/batch.py,sha256=5ukcYk3VCDpk62AL5Q_jTqpXmSNTlw0UCUhDeAB4aV0,6138
@@ -396,17 +398,17 @@ monai/transforms/spatial/array.py,sha256=5EKivdPYCP4i4qYUlkK1RpYQFzaU_baYyzgubid
 monai/transforms/spatial/dictionary.py,sha256=t0SvEDSVNFUEw2fK66OVF20sqSzCNxil17HmvsMFBt8,133752
 monai/transforms/spatial/functional.py,sha256=IwS0witCqbGkyuxzu_R4Ztp90S0pg9hY1irG7feXqig,33886
 monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
-monai/transforms/utility/array.py,sha256=
+monai/transforms/utility/array.py,sha256=Du3QA6m0io7mR51gUgaMwHBFNStdFmRxhaYmBCVy7BY,81215
 monai/transforms/utility/dictionary.py,sha256=iOFdTSekvkAsBbbfHeffcRsOKRtNcnt3N1cVuUarZ1s,80549
 monai/utils/__init__.py,sha256=2_AIpb1wqGMkmgoZ3r43muFTEsnMTCkPu3LtckipYHg,3793
 monai/utils/component_store.py,sha256=Fe9jbHgwwBBAeJAw0nI02Ae13v17wlwF6N9uUue8tJg,4525
 monai/utils/decorators.py,sha256=qhhdmJMjMfZIUM6x_VGUGF7kaq2cBUAam8WymAU_mhw,3156
 monai/utils/deprecate_utils.py,sha256=gKeEV4MsI51qeQ5gci2me_C-0e-tDwa3VZzd3XPQqLk,14759
 monai/utils/dist.py,sha256=7brB42CvdS8Jvr8Y7hfqov1uk6NNnYea9dYfgMYy0BY,8578
-monai/utils/enums.py,sha256=
+monai/utils/enums.py,sha256=jXtLaNDxG3BRBgLG2t13_S_G4iVWYHZO_GztykAtmXg,19594
 monai/utils/jupyter_utils.py,sha256=BYtj80LWQAYg5RWPj5g4j2AMCzLECvAcnZdXns0Ruw8,15651
 monai/utils/misc.py,sha256=R-sCS5u7SA8hX6e7x6WSc8FgLcNpqKFRRDMWxUd2wCo,31759
-monai/utils/module.py,sha256=
+monai/utils/module.py,sha256=R37PpCNCcHQvjjZFbNjNyzWb3FURaKLxQucjhzQk0eU,26087
 monai/utils/nvtx.py,sha256=i9JBxR1uhW1ZCgLPLlTx8b907QlXkFzJyTBLMlFjhtU,6876
 monai/utils/ordering.py,sha256=0nlA5b5QpVCHbtiCbTC-YsqjTmjm0bub0IeJhGFBOes,8270
 monai/utils/profiling.py,sha256=V2_cSHgrcmVF48_G3nUi2-O6fnXsS89nSlb8jj58YLo,15937
@@ -504,7 +506,7 @@ tests/bundle/test_bundle_ckpt_export.py,sha256=VnpigCoBAAc2lo0rWOpVMg0IYGB6vbHXL
 tests/bundle/test_bundle_download.py,sha256=4wpnCXNYTwTHWNjuSZqnXpVzadxNRabmFaFM3LZ_TJU,20072
 tests/bundle/test_bundle_get_data.py,sha256=lQh321mev_7fsLXRg0Tq5uEjuQILethDHRKzB6VV0o4,3667
 tests/bundle/test_bundle_push_to_hf_hub.py,sha256=Zjl6xDwRKgkS6jvO5dzMBaTLEd4EXyMXp0_wzDNSY3g,1740
-tests/bundle/test_bundle_trt_export.py,sha256=
+tests/bundle/test_bundle_trt_export.py,sha256=png-2SGjBSt46LXSz-PLprOXwJ0WkC_3YLR3Ibk_WBc,6344
 tests/bundle/test_bundle_utils.py,sha256=GTTS_5tEvV5qLad-aHeZXHDQLZcsDwi56Ldn5FnK2RE,4573
 tests/bundle/test_bundle_verify_metadata.py,sha256=OmcERLA5ht91cUDK9yYKXhpk-96yZcj4EBwZBk7zW3w,2660
 tests/bundle/test_bundle_verify_net.py,sha256=guCsyjb5op216AUUUQo97YY3p1-XcQEWINouxNX6F84,3383
@@ -529,6 +531,7 @@ tests/fl/monai_algo/test_fl_monai_algo_dist.py,sha256=Tq560TGvTmafEa5sDGax_chRlD
 tests/fl/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/fl/utils/test_fl_exchange_object.py,sha256=rddodowFMAdNT9wquI0NHg0CSm5Xvk_v9Si-eJqyiow,2571
 tests/handlers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
+tests/handlers/test_handler_average_precision.py,sha256=0cmgjzWxlfdZsUJB1NnSXfx3dmmDI6CbvIqggtc5rTY,2814
 tests/handlers/test_handler_checkpoint_loader.py,sha256=1dA4WYp-L6KxtzZIqUs--lNM4O-Anw2-s29QSdIOReU,8443
 tests/handlers/test_handler_checkpoint_saver.py,sha256=K3bxelElfETpQSXRovWZlxZZmkjY3hm_cJo8kjYCJ3I,6256
 tests/handlers/test_handler_classification_saver.py,sha256=vesCfTcAPkDAR7oAB_8kyeQrXpkrPQmdME9YBwPV7EE,2355
@@ -567,9 +570,9 @@ tests/handlers/test_trt_compile.py,sha256=p8Gr2CJmBo6gG8w7bGlAO--nDHtQvy9Ld3jtua
 tests/handlers/test_write_metrics_reports.py,sha256=oKGYR1plj1hSAu-ntbxkw_TD4O5hKPwVH_BS3MdHIbs,3027
 tests/inferers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/inferers/test_avg_merger.py,sha256=lMR2PcNGFD6sfF6CjJTkahrAiMA5m5LUs5A11P6h8n0,5952
-tests/inferers/test_controlnet_inferers.py,sha256=
+tests/inferers/test_controlnet_inferers.py,sha256=SGluRyDlgwUJ8nm3BEWgXN3eb81fUGOaRXbLglC_ejc,49676
 tests/inferers/test_diffusion_inferer.py,sha256=1O2V_bEmifOZ4RvpbZgYUCooiJ97T73avaBuMJPpBs0,9992
-tests/inferers/test_latent_diffusion_inferer.py,sha256=
+tests/inferers/test_latent_diffusion_inferer.py,sha256=atJjmfVznUq8z9EjohFIMyA0Q1XT1Ly0Zepf_1xPz5I,32274
 tests/inferers/test_patch_inferer.py,sha256=LkYXWVn71vWinP-OJsIvq3FPH3jr36T7nKRIH5PzaqY,9878
 tests/inferers/test_saliency_inferer.py,sha256=7miHRbA4yb_WGcxql6za9uXXoZlql_7y23f7IzsyIps,1949
 tests/inferers/test_slice_inferer.py,sha256=kzaJjjTnf2rAiR75l8A_J-Kie4NaLj2bogi-aJ5L5mk,1897
@@ -645,6 +648,7 @@ tests/losses/image_dissimilarity/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4o
 tests/losses/image_dissimilarity/test_global_mutual_information_loss.py,sha256=9xEX5BCEQ1s004QgcwYaAFwKTmlZjuVG8cIbK7Giwts,5692
 tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py,sha256=Gs3zHnGWNZ50liU_tya4Z_6tCRKIWCtG59imAxXdKPI,6070
 tests/metrics/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
+tests/metrics/test_compute_average_precision.py,sha256=o5gYko4Ow87Ix1n_z6_HmfuTKmkZM__fDZQpjKNJNrA,4743
 tests/metrics/test_compute_confusion_matrix.py,sha256=dwiqMnp7T6KJLJ7qv6J5g_RDDrB6UiLAe-pgmVNSz7I,10669
 tests/metrics/test_compute_f_beta.py,sha256=xbCipeICoAXWZLgDFeDAa1KjDQxDTMVArNbtUYiCG3c,3286
 tests/metrics/test_compute_fid_metric.py,sha256=B9OZECl3CT1JKzG-2C_YaPFjgfvlFoS9vI1j8vBzWZg,1328
@@ -670,7 +674,7 @@ tests/metrics/test_surface_dice.py,sha256=CGCQt-ydMzaT2q1fFnzpKb6E-TPydym4vE_kdp
 tests/metrics/test_surface_distance.py,sha256=gkW0dai3vHjXubLNBilqFnV5Up-abSMgQ53v0iCTVeE,6237
 tests/networks/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/networks/test_bundle_onnx_export.py,sha256=_lEnAJhq7D2IOuVEdgBVsA8vySgs34FkfMrvNsCLfUg,2853
-tests/networks/test_convert_to_onnx.py,sha256=
+tests/networks/test_convert_to_onnx.py,sha256=h1Sjb0SZmiwwbx0_PrzeFDOE3-JRSp18qDS6G_PdD6g,3673
 tests/networks/test_convert_to_torchscript.py,sha256=NhrJMCfQtC0sftrhDjL28omS7VKzg_niRK0KtY5Mr_A,1636
 tests/networks/test_convert_to_trt.py,sha256=5TkuUvCPgW5mAvYUysRRrSjtSbDoDDAoJb2kJtuXOVk,2656
 tests/networks/test_save_state.py,sha256=_glX4irpJVqk2jnOJaVqYxsOQNX3oCauxlEXe2ly8Cg,2354
@@ -881,7 +885,7 @@ tests/transforms/test_generate_label_classes_crop_centers.py,sha256=E5DtL2s1sio1
 tests/transforms/test_generate_pos_neg_label_crop_centers.py,sha256=DdCbdYaTHL40crC5o440cpEt0xNLXzT-rVphaBH11HM,2516
 tests/transforms/test_generate_spatial_bounding_box.py,sha256=JxHt4BHmtGYIqyzGhWgkCB5_oJU2ro_737upVxWBPvI,3510
 tests/transforms/test_get_extreme_points.py,sha256=881LZMTms1tXRDtODIheZbKDXMVQ69ff78IvukoabGc,1700
-tests/transforms/test_gibbs_noise.py,sha256=
+tests/transforms/test_gibbs_noise.py,sha256=9TgOYhGz1P6-VJUXszuV9NgqhjF5FKCVcQuG_7o3jUI,2658
 tests/transforms/test_gibbs_noised.py,sha256=o9ZQVAyuHATbV9JHkeTy_pDLz5Mqg5ctMQawMmP71RQ,3228
 tests/transforms/test_grid_distortion.py,sha256=8dTQjWQ2_euNKN00xxZXqZk-cFSsKfpVpkNm-1-WytA,4472
 tests/transforms/test_grid_distortiond.py,sha256=bSLhB_LGQKXo5VqP9RCyJDSyiZi2er2W2Qdw7qDep9s,3492
@@ -1174,8 +1178,8 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
 tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
 tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
+monai_weekly-1.5.dev2508.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+monai_weekly-1.5.dev2508.dist-info/METADATA,sha256=y-KfkVBP9_LhTnQo37SKpjDYJYsdujQuCCQiZpKdSv8,11909
+monai_weekly-1.5.dev2508.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+monai_weekly-1.5.dev2508.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
+monai_weekly-1.5.dev2508.dist-info/RECORD,,
tests/bundle/test_bundle_trt_export.py
CHANGED
@@ -70,7 +70,7 @@ class TestTRTExport(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
     @unittest.skipUnless(has_torchtrt and has_tensorrt, "Torch-TensorRT is required for conversion!")
     def test_trt_export(self, convert_precision, input_shape, dynamic_batch):
-        tests_dir = Path(__file__).resolve().
+        tests_dir = Path(__file__).resolve().parents[1]
         meta_file = os.path.join(tests_dir, "testing_data", "metadata.json")
         config_file = os.path.join(tests_dir, "testing_data", "inference.json")
         with tempfile.TemporaryDirectory() as tempdir:
@@ -108,7 +108,7 @@ class TestTRTExport(unittest.TestCase):
         has_onnx and has_torchtrt and has_tensorrt, "Onnx and TensorRT are required for onnx-trt conversion!"
     )
     def test_onnx_trt_export(self, convert_precision, input_shape, dynamic_batch):
-        tests_dir = Path(__file__).resolve().
+        tests_dir = Path(__file__).resolve().parents[1]
         meta_file = os.path.join(tests_dir, "testing_data", "metadata.json")
         config_file = os.path.join(tests_dir, "testing_data", "inference.json")
         with tempfile.TemporaryDirectory() as tempdir:
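The fix points tests_dir at the tests package root rather than the test file's own folder; `parents[1]` is simply the grandparent directory (the paths below are illustrative of the layout implied by the `testing_data` joins above):

    from pathlib import Path

    p = Path("/repo/tests/bundle/test_bundle_trt_export.py")
    print(p.parents[0])  # /repo/tests/bundle
    print(p.parents[1])  # /repo/tests  <- where testing_data/ is expected to live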
tests/handlers/test_handler_average_precision.py
ADDED
@@ -0,0 +1,79 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from monai.handlers import AveragePrecision
+from monai.transforms import Activations, AsDiscrete
+from tests.test_utils import DistCall, DistTestCase
+
+
+class TestHandlerAveragePrecision(unittest.TestCase):
+
+    def test_compute(self):
+        ap_metric = AveragePrecision()
+        act = Activations(softmax=True)
+        to_onehot = AsDiscrete(to_onehot=2)
+
+        y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])]
+        y = [torch.Tensor([0]), torch.Tensor([1])]
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+        ap_metric.update([y_pred, y])
+
+        y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])]
+        y = [torch.Tensor([0]), torch.Tensor([1])]
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+
+        ap_metric.update([y_pred, y])
+
+        ap = ap_metric.compute()
+        np.testing.assert_allclose(0.8333333, ap)
+
+
+class DistributedAveragePrecision(DistTestCase):
+
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        ap_metric = AveragePrecision()
+        act = Activations(softmax=True)
+        to_onehot = AsDiscrete(to_onehot=2)
+
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device)]
+
+        if dist.get_rank() == 1:
+            y_pred = [
+                torch.tensor([0.2, 0.1], device=device),
+                torch.tensor([0.1, 0.5], device=device),
+                torch.tensor([0.3, 0.4], device=device),
+            ]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)]
+
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+        ap_metric.update([y_pred, y])
+
+        result = ap_metric.compute()
+        np.testing.assert_allclose(0.7778, result, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/inferers/test_controlnet_inferers.py
CHANGED
@@ -722,7 +722,7 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES)
    @skipUnless(has_einops, "Requires einops")
-    def
+    def test_pred_shape(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1165,7 +1165,7 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def
+    def test_shape_different_latents(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1242,6 +1242,84 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self,
+        ae_model_type,
+        autoencoder_params,
+        dm_model_type,
+        stage_2_params,
+        controlnet_params,
+        input_shape,
+        latent_shape,
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+        controlnet = ControlNet(**controlnet_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        controlnet.to(device)
+        stage_1.eval()
+        stage_2.eval()
+        controlnet.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        mask = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = ControlNetLatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                controlnet=controlnet,
+                cn_cond=mask,
+                input_noise=noise,
+                seg=input_seg,
+                save_intermediates=True,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=False,
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(
tests/inferers/test_latent_diffusion_inferer.py
CHANGED
@@ -714,7 +714,7 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def
+    def test_shape_different_latents(
         self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
     ):
         stage_1 = None
@@ -772,6 +772,66 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = LatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                save_intermediates=True,
+                seg=input_seg,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(
tests/metrics/test_compute_average_precision.py
ADDED
@@ -0,0 +1,162 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.data import decollate_batch
+from monai.metrics import AveragePrecisionMetric, compute_average_precision
+from monai.transforms import Activations, AsDiscrete, Compose, ToTensor
+
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+TEST_CASE_1 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),
+    torch.tensor([[0], [0], [1], [1]], device=_device),
+    True,
+    2,
+    "macro",
+    0.41667,
+]
+
+TEST_CASE_2 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),
+    torch.tensor([[1], [1], [0], [0]], device=_device),
+    True,
+    2,
+    "micro",
+    0.85417,
+]
+
+TEST_CASE_3 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),
+    torch.tensor([[0], [1], [0], [1]], device=_device),
+    True,
+    2,
+    "macro",
+    0.83333,
+]
+
+TEST_CASE_4 = [
+    torch.tensor([[0.5], [0.5], [0.2], [8.3]]),
+    torch.tensor([[0], [1], [0], [1]]),
+    False,
+    None,
+    "macro",
+    0.83333,
+]
+
+TEST_CASE_5 = [torch.tensor([[0.5], [0.5], [0.2], [8.3]]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333]
+
+TEST_CASE_6 = [torch.tensor([0.5, 0.5, 0.2, 8.3]), torch.tensor([0, 1, 0, 1]), False, None, "macro", 0.83333]
+
+TEST_CASE_7 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[0], [1], [0], [1]]),
+    True,
+    2,
+    "none",
+    [0.83333, 0.83333],
+]
+
+TEST_CASE_8 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
+    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
+    True,
+    None,
+    "weighted",
+    0.66667,
+]
+
+TEST_CASE_9 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5], [0.1, 0.5]]),
+    torch.tensor([[1, 0], [0, 1], [0, 0], [1, 1], [0, 1]]),
+    True,
+    None,
+    "micro",
+    0.71111,
+]
+
+TEST_CASE_10 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[0], [0], [0], [0]]),
+    True,
+    2,
+    "macro",
+    float("nan"),
+]
+
+TEST_CASE_11 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[1], [1], [1], [1]]),
+    True,
+    2,
+    "macro",
+    float("nan"),
+]
+
+TEST_CASE_12 = [
+    torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
+    torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]]),
+    True,
+    None,
+    "macro",
+    float("nan"),
+]
+
+ALL_TESTS = [
+    TEST_CASE_1,
+    TEST_CASE_2,
+    TEST_CASE_3,
+    TEST_CASE_4,
+    TEST_CASE_5,
+    TEST_CASE_6,
+    TEST_CASE_7,
+    TEST_CASE_8,
+    TEST_CASE_9,
+    TEST_CASE_10,
+    TEST_CASE_11,
+    TEST_CASE_12,
+]
+
+
+class TestComputeAveragePrecision(unittest.TestCase):
+
+    @parameterized.expand(ALL_TESTS)
+    def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
+        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])
+        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])
+        y_pred = torch.stack([y_pred_trans(i) for i in decollate_batch(y_pred)], dim=0)
+        y = torch.stack([y_trans(i) for i in decollate_batch(y)], dim=0)
+        result = compute_average_precision(y_pred=y_pred, y=y, average=average)
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+    @parameterized.expand(ALL_TESTS)
+    def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value):
+        y_pred_trans = Compose([ToTensor(), Activations(softmax=softmax)])
+        y_trans = Compose([ToTensor(), AsDiscrete(to_onehot=to_onehot)])
+        y_pred = [y_pred_trans(i) for i in decollate_batch(y_pred)]
+        y = [y_trans(i) for i in decollate_batch(y)]
+        metric = AveragePrecisionMetric(average=average)
+        metric(y_pred=y_pred, y=y)
+        result = metric.aggregate()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+        result = metric.aggregate(average=average)  # test optional argument
+        metric.reset()
+        np.testing.assert_allclose(expected_value, result, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/networks/test_convert_to_onnx.py
CHANGED
@@ -64,7 +64,7 @@ class TestConvertToOnnx(unittest.TestCase):
             rtol=rtol,
             atol=atol,
         )
-
+        self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
 
     @parameterized.expand(TESTS_ORT)
     @SkipIfBeforePyTorchVersion((1, 12))
tests/transforms/test_gibbs_noise.py
CHANGED
@@ -21,14 +21,12 @@ from monai.data.synthetic import create_test_image_2d, create_test_image_3d
 from monai.transforms import GibbsNoise
 from monai.utils.misc import set_determinism
 from monai.utils.module import optional_import
-from tests.test_utils import TEST_NDARRAYS, assert_allclose
+from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product
 
 _, has_torch_fft = optional_import("torch.fft", name="fftshift")
 
-TEST_CASES = []
-for shape in ((128, 64), (64, 48, 80)):
-    for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]:
-        TEST_CASES.append((shape, input_type))
+params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]}
+TEST_CASES = list(dict_product(format="list", **params))
 
 
 class TestGibbsNoise(unittest.TestCase):
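dict_product comes from tests.test_utils and expands a dict of option sequences into one test case per combination. Roughly equivalent behavior can be sketched with itertools (an illustrative stand-in, not the actual helper):

    from itertools import product

    def dict_product_sketch(**params):
        # yield one [value, ...] list per combination of the parameter values,
        # in the key order given (comparable to dict_product(format="list", **params))
        keys = list(params)
        for values in product(*(params[k] for k in keys)):
            yield list(values)

    # 2 shapes x N input types -> 2*N (shape, input_type) cases
    cases = list(dict_product_sketch(shape=((128, 64), (64, 48, 80)), input_type=[tuple, list]))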
{monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/LICENSE
File without changes
{monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/WHEEL
File without changes
{monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2508.dist-info}/top_level.txt
File without changes