monai-weekly 1.5.dev2447__py3-none-any.whl → 1.5.dev2449__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/bundle/__init__.py +1 -1
- monai/bundle/reference_resolver.py +10 -0
- monai/bundle/workflows.py +187 -22
- monai/losses/dice.py +33 -22
- monai/losses/tversky.py +11 -8
- monai/losses/utils.py +68 -0
- monai/networks/blocks/pos_embed_utils.py +2 -2
- monai/networks/blocks/selfattention.py +18 -4
- monai/networks/blocks/transformerblock.py +4 -2
- monai/networks/nets/__init__.py +1 -0
- monai/networks/nets/masked_autoencoder_vit.py +211 -0
- monai/networks/nets/swin_unetr.py +24 -12
- monai/transforms/__init__.py +9 -0
- monai/transforms/utility/array.py +108 -12
- monai/transforms/utility/dictionary.py +67 -0
- monai/utils/module.py +3 -3
- {monai_weekly-1.5.dev2447.dist-info → monai_weekly-1.5.dev2449.dist-info}/METADATA +4 -1
- {monai_weekly-1.5.dev2447.dist-info → monai_weekly-1.5.dev2449.dist-info}/RECORD +23 -21
- {monai_weekly-1.5.dev2447.dist-info → monai_weekly-1.5.dev2449.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2447.dist-info → monai_weekly-1.5.dev2449.dist-info}/WHEEL +0 -0
- {monai_weekly-1.5.dev2447.dist-info → monai_weekly-1.5.dev2449.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json

 version_json = '''
 {
-    "date": "2024-
+    "date": "2024-12-08T02:32:52+0000",
     "dirty": false,
     "error": null,
-    "full-revisionid": "
-    "version": "1.5.dev2447"
+    "full-revisionid": "8cad248c8b374702245989507da1dd8430ef863f",
+    "version": "1.5.dev2449"
 }
 ''' # END VERSION_JSON
monai/bundle/__init__.py
CHANGED
monai/bundle/reference_resolver.py
CHANGED
@@ -192,6 +192,16 @@ class ReferenceResolver:
         """
         return self._resolve_one_item(id=id, **kwargs)

+    def remove_resolved_content(self, id: str) -> Any | None:
+        """
+        Remove the resolved ``ConfigItem`` by id.
+
+        Args:
+            id: id name of the expected item.
+
+        """
+        return self.resolved_content.pop(id) if id in self.resolved_content else None
+
     @classmethod
     def normalize_id(cls, id: str | int) -> str:
         """
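`remove_resolved_content` gives callers a way to drop a cached resolution so the same id can be parsed again (this is what the new `_run_expr` in `workflows.py` below relies on). A small sketch with illustrative config keys:

```python
from monai.bundle import ConfigParser

parser = ConfigParser({"x": 2, "expr": "$@x + 1"})  # "x" and "expr" are illustrative ids
print(parser.get_parsed_content("expr"))  # 3; the result is now cached by the resolver

# drop the cached value so the next get_parsed_content evaluates the expression again
parser.ref_resolver.remove_resolved_content("expr")
```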
monai/bundle/workflows.py
CHANGED
@@ -44,12 +44,18 @@ class BundleWorkflow(ABC):
         workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
-            default to `None` for common workflow.
+            default to `None` for only using meta properties.
         workflow: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
         logging_file: config file for `logging` module in the program. for more details:
             https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
@@ -97,29 +103,50 @@
                         meta_file = None

         workflow_type = workflow if workflow is not None else workflow_type
-        if workflow_type is None and properties_path is None:
-            self.properties = copy(MetaProperties)
-            self.meta_file = meta_file
-            self.workflow_type = None
-            return
+        if workflow_type is not None:
+            if workflow_type.lower() in self.supported_train_type:
+                workflow_type = "train"
+            elif workflow_type.lower() in self.supported_infer_type:
+                workflow_type = "infer"
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+
         if properties_path is not None:
             properties_path = Path(properties_path)
             if not properties_path.is_file():
                 raise ValueError(f"Property file {properties_path} does not exist.")
             with open(properties_path) as json_file:
-                self.properties = json.load(json_file)
-            self.meta_file = meta_file
-            self.workflow_type = None
-            return
-        if workflow_type.lower() in self.supported_train_type:
-            self.properties = {**TrainProperties, **MetaProperties}
-            self.workflow_type = "train"
-        elif workflow_type.lower() in self.supported_infer_type:
-            self.properties = {**InferProperties, **MetaProperties}
-            self.workflow_type = "infer"
+                try:
+                    properties = json.load(json_file)
+                    self.properties: dict = {}
+                    if workflow_type is not None and workflow_type in properties:
+                        self.properties = properties[workflow_type]
+                        if "meta" in properties:
+                            self.properties.update(properties["meta"])
+                    elif workflow_type is None:
+                        if "meta" in properties:
+                            self.properties = properties["meta"]
+                            logger.info(
+                                "No workflow type specified, default to load meta properties from property file."
+                            )
+                        else:
+                            logger.warning("No 'meta' key found in properties while workflow_type is None.")
+                except KeyError as e:
+                    raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"Error decoding JSON from property file {properties_path}") from e
         else:
-            raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+            if workflow_type == "train":
+                self.properties = {**TrainProperties, **MetaProperties}
+            elif workflow_type == "infer":
+                self.properties = {**InferProperties, **MetaProperties}
+            elif workflow_type is None:
+                self.properties = copy(MetaProperties)
+                logger.info("No workflow type and property file specified, default to 'meta' properties.")
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")

+        self.workflow_type = workflow_type
         self.meta_file = meta_file

     @abstractmethod
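Given the loading logic above, a properties file is keyed by workflow type ("train", "infer") plus an optional shared "meta" section, with entries shaped like those in `monai/bundle/properties.py`. A hypothetical minimal file, written from Python (the property names and ids here are illustrative, not from the package):

```python
import json

# hypothetical properties dict; "bundle_root" / "_meta_::version" ids are illustrative
properties = {
    "infer": {
        "bundle_root": {"description": "root path of the bundle", "required": True, "id": "bundle_root"}
    },
    "meta": {
        "version": {"description": "bundle version", "required": True, "id": "_meta_::version"}
    },
}
with open("properties.json", "w") as f:
    json.dump(properties, f, indent=4)
```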
@@ -226,6 +253,124 @@
         return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)]


+class PythonicWorkflow(BundleWorkflow):
+    """
+    Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow.
+    It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc.
+    This also provides the interface to get / set public properties to interact with a bundle workflow through
+    defined `get_<property>` accessor methods or directly defining members of the object.
+    For how to set the properties, users can define the `_set_<property>` methods or directly set the members of the object.
+    The `initialize` method is called to set up the workflow before running. This method sets up internal state
+    and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized`
+    is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is
+    properly set up with the new property values.
+
+    Args:
+        workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for only using meta properties.
+        workflow: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for common workflow.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
+        config_file: path to the config file, typically used to store hyperparameters.
+        meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
+        logging_file: config file for `logging` module in the program. for more details:
+            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
+
+    """
+
+    supported_train_type: tuple = ("train", "training")
+    supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")
+
+    def __init__(
+        self,
+        workflow_type: str | None = None,
+        properties_path: PathLike | None = None,
+        config_file: str | Sequence[str] | None = None,
+        meta_file: str | Sequence[str] | None = None,
+        logging_file: str | None = None,
+        **override: Any,
+    ):
+        meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file
+        super().__init__(
+            workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file
+        )
+        self._props_vals: dict = {}
+        self._set_props_vals: dict = {}
+        self.parser = ConfigParser()
+        if config_file is not None:
+            self.parser.read_config(f=config_file)
+        if self.meta_file is not None:
+            self.parser.read_meta(f=self.meta_file)
+
+        # the rest key-values in the _args are to override config content
+        self.parser.update(pairs=override)
+        self._is_initialized: bool = False
+
+    def initialize(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Initialize the bundle workflow before running.
+        """
+        self._props_vals = {}
+        self._is_initialized = True
+
+    def _get_property(self, name: str, property: dict) -> Any:
+        """
+        With specified property name and information, get the expected property value.
+        If the property is already generated, return from the bucket directly.
+        If user explicitly set the property, return it directly.
+        Otherwise, generate the expected property as a class private property with prefix "_".
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+        """
+        if not self._is_initialized:
+            raise RuntimeError("Please execute 'initialize' before getting any properties.")
+        value = None
+        if name in self._set_props_vals:
+            value = self._set_props_vals[name]
+        elif name in self._props_vals:
+            value = self._props_vals[name]
+        elif name in self.parser.config[self.parser.meta_key]:  # type: ignore[index]
+            id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None)
+            value = self.parser[id]
+        else:
+            try:
+                value = getattr(self, f"get_{name}")()
+            except AttributeError as e:
+                if property[BundleProperty.REQUIRED]:
+                    raise ValueError(
+                        f"unsupported property '{name}' is required in the bundle properties,"
+                        f"need to implement a method 'get_{name}' to provide the property."
+                    ) from e
+        self._props_vals[name] = value
+        return value
+
+    def _set_property(self, name: str, property: dict, value: Any) -> Any:
+        """
+        With specified property name and information, set value for the expected property.
+        Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+            value: value to set for the property.
+
+        """
+        self._set_props_vals[name] = value
+        self._is_initialized = False
+
+
 class ConfigWorkflow(BundleWorkflow):
     """
     Specification for the config-based bundle workflow.
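For orientation, a minimal hypothetical subclass following the conventions above: `run` and `finalize` remain abstract in the base class, and a property that is neither set by the user nor found in the config/metadata is produced by its `get_<property>` accessor. This assumes `PythonicWorkflow` is exported from `monai.bundle` (consistent with the `__init__.py` change in this release); all other names are illustrative.

```python
import torch
from monai.bundle import PythonicWorkflow
from monai.inferers import SimpleInferer

class MyInferWorkflow(PythonicWorkflow):
    def __init__(self, **kwargs):
        # picks up InferProperties + MetaProperties by default
        super().__init__(workflow_type="infer", **kwargs)

    def initialize(self):
        super().initialize()
        self.net = torch.nn.Identity()  # placeholder network

    def get_inferer(self):
        # called on demand the first time `self.inferer` is accessed
        return SimpleInferer()

    def run(self):
        with torch.no_grad():
            return self.inferer(torch.rand(1, 1, 8, 8), self.net)

    def finalize(self):
        pass

# looks for ./metadata.json by default; a missing file is only logged, not fatal
wf = MyInferWorkflow()
wf.initialize()
out = wf.run()
```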
@@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow):
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "train". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         override: id-value pairs to override or add the corresponding config content.
             e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
@@ -324,7 +475,6 @@ class ConfigWorkflow(BundleWorkflow):
         self.parser.read_config(f=config_file)
         if self.meta_file is not None:
             self.parser.read_meta(f=self.meta_file)
-
         # the rest key-values in the _args are to override config content
         self.parser.update(pairs=override)
         self.init_id = init_id
@@ -394,8 +544,23 @@
         ret.extend(wrong_props)
         return ret

-    def _run_expr(self, id: str, **kwargs: dict) -> Any:
-        return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
+    def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
+        """
+        Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
+        allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
+        """
+        ret = []
+        if id in self.parser:
+            # suppose all the expressions are in a list, run and reset the expressions
+            if isinstance(self.parser[id], list):
+                for i in range(len(self.parser[id])):
+                    sub_id = f"{id}{ID_SEP_KEY}{i}"
+                    ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
+                    self.parser.ref_resolver.remove_resolved_content(sub_id)
+            else:
+                ret.append(self.parser.get_parsed_content(id, **kwargs))
+                self.parser.ref_resolver.remove_resolved_content(id)
+        return ret

     def _get_prop_id(self, name: str, property: dict) -> Any:
         prop_id = property[BundlePropertyConfig.ID]
monai/losses/dice.py
CHANGED
@@ -23,6 +23,7 @@ from torch.nn.modules.loss import _Loss

 from monai.losses.focal_loss import FocalLoss
 from monai.losses.spatial_mask import MaskedLoss
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
@@ -39,8 +40,16 @@ class DiceLoss(_Loss):
     The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
     the inter-over-union calculation to smooth results respectively, these values should be small.

-    The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
-    Medical Image Segmentation. 3DV 2016.
+    The original papers:
+
+        Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
+        Medical Image Segmentation. 3DV 2016.
+
+        Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
+        Soft Labels. NeurIPS 2023.
+
+        Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+        Soft Labels. MICCAI 2023.

     """
@@ -58,6 +67,7 @@ class DiceLoss(_Loss):
         smooth_dr: float = 1e-5,
         batch: bool = False,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -89,6 +99,8 @@ class DiceLoss(_Loss):
                 of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.

         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -114,6 +126,7 @@ class DiceLoss(_Loss):
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -174,21 +187,15 @@
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis

-        intersection = torch.sum(target * input, dim=reduce_axis)
-
-        if self.squared_pred:
-            ground_o = torch.sum(target**2, dim=reduce_axis)
-            pred_o = torch.sum(input**2, dim=reduce_axis)
-        else:
-            ground_o = torch.sum(target, dim=reduce_axis)
-            pred_o = torch.sum(input, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-
-        if self.jaccard:
-            denominator = 2.0 * (denominator - intersection)
+        ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
+        if not self.jaccard:
+            fp *= 0.5
+            fn *= 0.5
+        numerator = 2 * tp + self.smooth_nr
+        denominator = 2 * (tp + fp + fn) + self.smooth_dr

-        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
+        f: torch.Tensor = 1 - numerator / denominator

         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -272,6 +279,7 @@ class GeneralizedDiceLoss(_Loss):
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -295,6 +303,8 @@ class GeneralizedDiceLoss(_Loss):
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, intersection over union is computed from each item in the batch.
                 If True, the class-weighted intersection and union areas are first summed across the batches.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.

         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -319,6 +329,7 @@ class GeneralizedDiceLoss(_Loss):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label

     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -370,13 +381,13 @@
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
-        intersection = torch.sum(target * input, reduce_axis)

-        ground_o = torch.sum(target, reduce_axis)
-        pred_o = torch.sum(input, reduce_axis)
-
-        denominator = ground_o + pred_o
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
+        fp *= 0.5
+        fn *= 0.5
+        denominator = 2 * (tp + fp + fn)

+        ground_o = torch.sum(target, reduce_axis)
         w = self.w_func(ground_o.float())
         infs = torch.isinf(w)
         if self.batch:
@@ -388,7 +399,7 @@
         w = w + infs * max_values

         final_reduce_dim = 0 if self.batch else 1
-        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
         denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
         f: torch.Tensor = 1.0 - (numer / denom)
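A short sketch of the new flag in use: with `soft_label=True` the loss accepts non-binary targets (e.g. probabilities from label smoothing or a teacher model) and uses the norm-based formulation from the Dice Semimetric Losses paper.

```python
import torch
from monai.losses import DiceLoss

pred = torch.softmax(torch.randn(2, 3, 16, 16), dim=1)         # predicted probabilities
soft_target = torch.softmax(torch.randn(2, 3, 16, 16), dim=1)  # non-binary (soft) target

loss = DiceLoss(soft_label=True)(pred, soft_target)
print(loss)  # scalar mean Dice loss
```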
monai/losses/tversky.py
CHANGED
@@ -17,6 +17,7 @@ from collections.abc import Callable
 import torch
 from torch.nn.modules.loss import _Loss

+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import LossReduction
@@ -28,6 +29,9 @@ class TverskyLoss(_Loss):
     Sadegh et al. (2017) Tversky loss function for image segmentation
     using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721)

+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
+
     Adapted from:
         https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631
@@ -46,6 +50,7 @@ class TverskyLoss(_Loss):
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -70,6 +75,8 @@ class TverskyLoss(_Loss):
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.

         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -93,6 +100,7 @@ class TverskyLoss(_Loss):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -134,20 +142,15 @@
         if target.shape != input.shape:
             raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

-        p0 = input
-        p1 = 1 - p0
-        g0 = target
-        g1 = 1 - g0
-
         # reducing only spatial dimensions (not batch nor channels)
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis

-        tp = torch.sum(p0 * g0, reduce_axis)
-        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
-        fn = self.beta * torch.sum(p1 * g0, reduce_axis)
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
+        fp *= self.alpha
+        fn *= self.beta
         numerator = tp + self.smooth_nr
         denominator = tp + fp + fn + self.smooth_dr
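A quick usage sketch (values are illustrative): `alpha` scales the false-positive term and `beta` the false-negative term, as in the code above, and `soft_label=True` enables the norm-based formulation.

```python
import torch
from monai.losses import TverskyLoss

pred = torch.softmax(torch.randn(2, 3, 16, 16), dim=1)
soft_target = torch.softmax(torch.randn(2, 3, 16, 16), dim=1)

# weight false positives more heavily than false negatives
loss = TverskyLoss(alpha=0.7, beta=0.3, soft_label=True)(pred, soft_target)
```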
monai/losses/utils.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.linalg as LA
+
+
+def compute_tp_fp_fn(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    reduce_axis: list[int],
+    ord: int,
+    soft_label: bool,
+    decoupled: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        input: the shape should be BNH[WD], where N is the number of classes.
+        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+        reduce_axis: the axis to be reduced.
+        ord: the order of the vector norm.
+        soft_label: whether the target contains non-binary values (soft labels) or not.
+            If True a soft label formulation of the loss will be used.
+        decoupled: whether the input and the target should be decoupled when computing fp and fn.
+            Only for the original implementation when soft_label is False.
+
+    Adapted from:
+        https://github.com/zifuwanggg/JDTLosses
+    """
+
+    # the original implementation that is erroneous with soft labels
+    if ord == 1 and not soft_label:
+        tp = torch.sum(input * target, dim=reduce_axis)
+        # the original implementation of Dice and Jaccard loss
+        if decoupled:
+            fp = torch.sum(input, dim=reduce_axis) - tp
+            fn = torch.sum(target, dim=reduce_axis) - tp
+        # the original implementation of Tversky loss
+        else:
+            fp = torch.sum(input * (1 - target), dim=reduce_axis)
+            fn = torch.sum((1 - input) * target, dim=reduce_axis)
+    # the new implementation that is correct with soft labels
+    # and it is identical to the original implementation with hard labels
+    else:
+        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
+        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
+        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)
+
+        if ord > 1:
+            pred_o = torch.pow(pred_o, exponent=ord)
+            ground_o = torch.pow(ground_o, exponent=ord)
+            difference = torch.pow(difference, exponent=ord)
+
+        tp = (pred_o + ground_o - difference) / 2
+        fp = pred_o - tp
+        fn = ground_o - tp
+
+    return tp, fp, fn
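A quick sanity check one might run (not from the package): with hard labels and predictions in [0, 1], the L1-norm branch reduces to the classic products, so both paths should agree.

```python
import torch
from monai.losses.utils import compute_tp_fp_fn

pred = torch.rand(2, 3, 8, 8)                    # predictions in [0, 1]
target = (torch.rand(2, 3, 8, 8) > 0.5).float()  # hard labels
axes = [2, 3]

tp0, fp0, fn0 = compute_tp_fp_fn(pred, target, axes, ord=1, soft_label=False)
tp1, fp1, fn1 = compute_tp_fp_fn(pred, target, axes, ord=1, soft_label=True)

assert torch.allclose(tp0, tp1, atol=1e-5)
assert torch.allclose(fp0, fp1, atol=1e-5) and torch.allclose(fn0, fn1, atol=1e-5)
```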
monai/networks/blocks/pos_embed_utils.py
CHANGED
@@ -56,7 +56,7 @@ def build_sincos_position_embedding(
     grid_h = torch.arange(h, dtype=torch.float32)
     grid_w = torch.arange(w, dtype=torch.float32)

-    grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")
+    grid_h, grid_w = torch.meshgrid(grid_h, grid_w)

     if embed_dim % 4 != 0:
         raise AssertionError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
@@ -75,7 +75,7 @@ def build_sincos_position_embedding(
     grid_w = torch.arange(w, dtype=torch.float32)
     grid_d = torch.arange(d, dtype=torch.float32)

-    grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")
+    grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d)

     if embed_dim % 6 != 0:
         raise AssertionError("Embed dimension must be divisible by 6 for 3D sin-cos position embedding")
monai/networks/blocks/selfattention.py
CHANGED
@@ -11,7 +11,7 @@

 from __future__ import annotations

-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ class SABlock(nn.Module):
         )
         self.input_size = input_size

-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.

         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ class SABlock(nn.Module):

         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@ class SABlock(nn.Module):
             att_mat = self.rel_positional_embedding(x, att_mat, q)

         if self.causal:
+            if attn_mask is not None:
+                raise ValueError("Causal attention does not support attention masks.")
             att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

-        att_mat = att_mat.softmax(dim=-1)
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+            attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+            att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))

+        att_mat = att_mat.softmax(dim=-1)
         if self.save_attn:
             # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
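A sketch of the new mask argument, per the docstring above: a padding-style mask of shape B x seq_len where zeros mark positions to ignore (the constructor arguments shown are the standard ones).

```python
import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=64, num_heads=4)
x = torch.randn(2, 16, 64)      # B x seq_len x hidden_size
mask = torch.ones(2, 16)
mask[:, -4:] = 0                # ignore the last four tokens
out = block(x, attn_mask=mask)  # B x 16 x 64
```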
monai/networks/blocks/transformerblock.py
CHANGED
@@ -90,8 +90,10 @@ class TransformerBlock(nn.Module):
             use_flash_attention=use_flash_attention,
         )

-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
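And the block-level counterpart: the mask is simply forwarded to the self-attention layer. A usage sketch with standard constructor arguments:

```python
import torch
from monai.networks.blocks import TransformerBlock

blk = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4)
x = torch.randn(2, 16, 64)
mask = torch.ones(2, 16)
mask[:, -4:] = 0
out = blk(x, attn_mask=mask)  # mask is passed through to the SABlock
```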