monai-weekly 1.5.dev2447__py3-none-any.whl → 1.5.dev2449__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/__init__.py CHANGED
@@ -136,4 +136,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "b1e915c323a8065cfe9e92de3013476f2f67c1b2"
+__commit_id__ = "e604d1841fe60c0ffb6978ae4116535ca8d8f34f"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-11-24T02:30:43+0000",
+ "date": "2024-12-08T02:32:52+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "2d6751b01bf78fddabe03b2c53645c6bc9808ed8",
- "version": "1.5.dev2447"
+ "full-revisionid": "8cad248c8b374702245989507da1dd8430ef863f",
+ "version": "1.5.dev2449"
 }
 '''  # END VERSION_JSON
 
monai/bundle/__init__.py CHANGED
@@ -43,4 +43,4 @@ from .utils import (
     MACRO_KEY,
     load_bundle_config,
 )
-from .workflows import BundleWorkflow, ConfigWorkflow
+from .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow
@@ -192,6 +192,16 @@ class ReferenceResolver:
         """
         return self._resolve_one_item(id=id, **kwargs)
 
+    def remove_resolved_content(self, id: str) -> Any | None:
+        """
+        Remove the resolved ``ConfigItem`` by id.
+
+        Args:
+            id: id name of the expected item.
+
+        """
+        return self.resolved_content.pop(id) if id in self.resolved_content else None
+
     @classmethod
     def normalize_id(cls, id: str | int) -> str:
         """
monai/bundle/workflows.py CHANGED
@@ -44,12 +44,18 @@ class BundleWorkflow(ABC):
         workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
-            default to `train` for train workflow.
+            default to `None` for only using meta properties.
         workflow: specifies the workflow type: "train" or "training" for a training workflow,
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
         logging_file: config file for `logging` module in the program. for more details:
             https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
@@ -97,29 +103,50 @@ class BundleWorkflow(ABC):
                         meta_file = None
 
         workflow_type = workflow if workflow is not None else workflow_type
-        if workflow_type is None and properties_path is None:
-            self.properties = copy(MetaProperties)
-            self.workflow_type = None
-            self.meta_file = meta_file
-            return
+        if workflow_type is not None:
+            if workflow_type.lower() in self.supported_train_type:
+                workflow_type = "train"
+            elif workflow_type.lower() in self.supported_infer_type:
+                workflow_type = "infer"
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+
         if properties_path is not None:
             properties_path = Path(properties_path)
             if not properties_path.is_file():
                 raise ValueError(f"Property file {properties_path} does not exist.")
             with open(properties_path) as json_file:
-                self.properties = json.load(json_file)
-            self.workflow_type = None
-            self.meta_file = meta_file
-            return
-        if workflow_type.lower() in self.supported_train_type:  # type: ignore[union-attr]
-            self.properties = {**TrainProperties, **MetaProperties}
-            self.workflow_type = "train"
-        elif workflow_type.lower() in self.supported_infer_type:  # type: ignore[union-attr]
-            self.properties = {**InferProperties, **MetaProperties}
-            self.workflow_type = "infer"
+                try:
+                    properties = json.load(json_file)
+                    self.properties: dict = {}
+                    if workflow_type is not None and workflow_type in properties:
+                        self.properties = properties[workflow_type]
+                        if "meta" in properties:
+                            self.properties.update(properties["meta"])
+                    elif workflow_type is None:
+                        if "meta" in properties:
+                            self.properties = properties["meta"]
+                            logger.info(
+                                "No workflow type specified, default to load meta properties from property file."
+                            )
+                        else:
+                            logger.warning("No 'meta' key found in properties while workflow_type is None.")
+                except KeyError as e:
+                    raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"Error decoding JSON from property file {properties_path}") from e
         else:
-            raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+            if workflow_type == "train":
+                self.properties = {**TrainProperties, **MetaProperties}
+            elif workflow_type == "infer":
+                self.properties = {**InferProperties, **MetaProperties}
+            elif workflow_type is None:
+                self.properties = copy(MetaProperties)
+                logger.info("No workflow type and property file specified, default to 'meta' properties.")
+            else:
+                raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
 
+        self.workflow_type = workflow_type
         self.meta_file = meta_file
 
     @abstractmethod
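The new loading branch above expects a custom property file keyed by workflow type, with an optional "meta" section that is merged into the selected section. A minimal hedged sketch of such a file and how it could be passed in; the property names, ids and file paths below are illustrative only, not taken from the bundle specification:

    import json

    from monai.bundle import ConfigWorkflow

    # hypothetical property file: one section per workflow type plus a "meta" section
    properties = {
        "infer": {"network_def": {"description": "network", "required": True, "id": "network_def"}},
        "meta": {"version": {"description": "bundle version", "required": False, "id": "_meta_::version"}},
    }
    with open("properties.json", "w") as f:
        json.dump(properties, f)

    # workflow_type="infer" selects the "infer" section and merges "meta" into it
    workflow = ConfigWorkflow(
        config_file="configs/inference.json", workflow_type="infer", properties_path="properties.json"
    )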
@@ -226,6 +253,124 @@ class BundleWorkflow(ABC):
         return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)]
 
 
+class PythonicWorkflow(BundleWorkflow):
+    """
+    Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow.
+    It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc.
+    This also provides the interface to get / set public properties to interact with a bundle workflow through
+    defined `get_<property>` accessor methods or directly defining members of the object.
+    For how to set the properties, users can define the `_set_<property>` methods or directly set the members of the object.
+    The `initialize` method is called to set up the workflow before running. This method sets up internal state
+    and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized`
+    is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is
+    properly set up with the new property values.
+
+    Args:
+        workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for only using meta properties.
+        workflow: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for a inference workflow,
+            other unsupported string will raise a ValueError.
+            default to `None` for common workflow.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "meta". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
+        config_file: path to the config file, typically used to store hyperparameters.
+        meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
+        logging_file: config file for `logging` module in the program. for more details:
+            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
+
+    """
+
+    supported_train_type: tuple = ("train", "training")
+    supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")
+
+    def __init__(
+        self,
+        workflow_type: str | None = None,
+        properties_path: PathLike | None = None,
+        config_file: str | Sequence[str] | None = None,
+        meta_file: str | Sequence[str] | None = None,
+        logging_file: str | None = None,
+        **override: Any,
+    ):
+        meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file
+        super().__init__(
+            workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file
+        )
+        self._props_vals: dict = {}
+        self._set_props_vals: dict = {}
+        self.parser = ConfigParser()
+        if config_file is not None:
+            self.parser.read_config(f=config_file)
+        if self.meta_file is not None:
+            self.parser.read_meta(f=self.meta_file)
+
+        # the rest key-values in the _args are to override config content
+        self.parser.update(pairs=override)
+        self._is_initialized: bool = False
+
+    def initialize(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Initialize the bundle workflow before running.
+        """
+        self._props_vals = {}
+        self._is_initialized = True
+
+    def _get_property(self, name: str, property: dict) -> Any:
+        """
+        With specified property name and information, get the expected property value.
+        If the property is already generated, return from the bucket directly.
+        If user explicitly set the property, return it directly.
+        Otherwise, generate the expected property as a class private property with prefix "_".
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+        """
+        if not self._is_initialized:
+            raise RuntimeError("Please execute 'initialize' before getting any properties.")
+        value = None
+        if name in self._set_props_vals:
+            value = self._set_props_vals[name]
+        elif name in self._props_vals:
+            value = self._props_vals[name]
+        elif name in self.parser.config[self.parser.meta_key]:  # type: ignore[index]
+            id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None)
+            value = self.parser[id]
+        else:
+            try:
+                value = getattr(self, f"get_{name}")()
+            except AttributeError as e:
+                if property[BundleProperty.REQUIRED]:
+                    raise ValueError(
+                        f"unsupported property '{name}' is required in the bundle properties,"
+                        f"need to implement a method 'get_{name}' to provide the property."
+                    ) from e
+        self._props_vals[name] = value
+        return value
+
+    def _set_property(self, name: str, property: dict, value: Any) -> Any:
+        """
+        With specified property name and information, set value for the expected property.
+        Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+            value: value to set for the property.
+
+        """
+        self._set_props_vals[name] = value
+        self._is_initialized = False
+
+
 class ConfigWorkflow(BundleWorkflow):
     """
     Specification for the config-based bundle workflow.
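A rough usage sketch of the new class: properties are exposed through `get_<property>` methods and cached after `initialize()`. The file names and method bodies below are hypothetical, and `run`/`finalize` are assumed to be the remaining abstract methods of `BundleWorkflow` that a concrete subclass still has to provide:

    from monai.bundle import PythonicWorkflow
    from monai.inferers import SimpleInferer


    class MyInferWorkflow(PythonicWorkflow):
        def get_inferer(self):
            # called on first access of `self.inferer` when it is not set in the config/meta
            return SimpleInferer()

        def run(self):
            inferer = self.inferer  # resolved and cached through _get_property
            ...

        def finalize(self):
            pass


    # a metadata.json is expected in the working directory unless meta_file is given
    workflow = MyInferWorkflow(workflow_type="infer", config_file="configs/hyperparams.yaml")
    workflow.initialize()
    workflow.run()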
@@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow):
             or "infer", "inference", "eval", "evaluation" for a inference workflow,
             other unsupported string will raise a ValueError.
             default to `None` for common workflow.
-        properties_path: the path to the JSON file of properties.
+        properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be
+            loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified,
+            properties will default to loading from "train". If `properties_path` is None, default properties
+            will be sourced from "monai/bundle/properties.py" based on the workflow_type:
+            For a training workflow, properties load from `TrainProperties` and `MetaProperties`.
+            For a inference workflow, properties load from `InferProperties` and `MetaProperties`.
+            For workflow_type = None : only `MetaProperties` will be loaded.
         override: id-value pairs to override or add the corresponding config content.
             e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
 
@@ -324,7 +475,6 @@ class ConfigWorkflow(BundleWorkflow):
         self.parser.read_config(f=config_file)
         if self.meta_file is not None:
             self.parser.read_meta(f=self.meta_file)
-
         # the rest key-values in the _args are to override config content
         self.parser.update(pairs=override)
         self.init_id = init_id
@@ -394,8 +544,23 @@ class ConfigWorkflow(BundleWorkflow):
             ret.extend(wrong_props)
         return ret
 
-    def _run_expr(self, id: str, **kwargs: dict) -> Any:
-        return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
+    def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
+        """
+        Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
+        allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
+        """
+        ret = []
+        if id in self.parser:
+            # suppose all the expressions are in a list, run and reset the expressions
+            if isinstance(self.parser[id], list):
+                for i in range(len(self.parser[id])):
+                    sub_id = f"{id}{ID_SEP_KEY}{i}"
+                    ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
+                    self.parser.ref_resolver.remove_resolved_content(sub_id)
+            else:
+                ret.append(self.parser.get_parsed_content(id, **kwargs))
+                self.parser.ref_resolver.remove_resolved_content(id)
+        return ret
 
     def _get_prop_id(self, name: str, property: dict) -> Any:
         prop_id = property[BundlePropertyConfig.ID]
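`_run_expr` relies on the new `ReferenceResolver.remove_resolved_content` to drop the cached result after each evaluation, so the same expression can be run repeatedly. A rough illustration with a bare `ConfigParser`; the config keys and the expression are made up for this sketch:

    from monai.bundle import ConfigParser

    parser = ConfigParser({"counter": {"_target_": "itertools.count"}, "step": "$next(@counter)"})
    print(parser.get_parsed_content("step"))           # expression evaluated and cached
    parser.ref_resolver.remove_resolved_content("step")
    print(parser.get_parsed_content("step"))           # evaluated again instead of returning the cached value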
monai/losses/dice.py CHANGED
@@ -23,6 +23,7 @@ from torch.nn.modules.loss import _Loss
 
 from monai.losses.focal_loss import FocalLoss
 from monai.losses.spatial_mask import MaskedLoss
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
 
@@ -39,8 +40,16 @@ class DiceLoss(_Loss):
     The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
     the inter-over-union calculation to smooth results respectively, these values should be small.
 
-    The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric
-    Medical Image Segmentation, 3DV, 2016.
+    The original papers:
+
+    Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
+    Medical Image Segmentation. 3DV 2016.
+
+    Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
+    Soft Labels. NeurIPS 2023.
+
+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
 
     """
 
@@ -58,6 +67,7 @@ class DiceLoss(_Loss):
         smooth_dr: float = 1e-5,
         batch: bool = False,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -89,6 +99,8 @@ class DiceLoss(_Loss):
                 of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -114,6 +126,7 @@ class DiceLoss(_Loss):
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -174,21 +187,15 @@ class DiceLoss(_Loss):
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        intersection = torch.sum(target * input, dim=reduce_axis)
-
-        if self.squared_pred:
-            ground_o = torch.sum(target**2, dim=reduce_axis)
-            pred_o = torch.sum(input**2, dim=reduce_axis)
-        else:
-            ground_o = torch.sum(target, dim=reduce_axis)
-            pred_o = torch.sum(input, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-
-        if self.jaccard:
-            denominator = 2.0 * (denominator - intersection)
+        ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
+        if not self.jaccard:
+            fp *= 0.5
+            fn *= 0.5
+        numerator = 2 * tp + self.smooth_nr
+        denominator = 2 * (tp + fp + fn) + self.smooth_dr
 
-        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
+        f: torch.Tensor = 1 - numerator / denominator
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
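The `soft_label` flag is a plain constructor argument, so existing call sites keep their behaviour. A small hedged example with probabilistic targets (shapes and settings below are arbitrary):

    import torch

    from monai.losses import DiceLoss

    pred = torch.randn(2, 3, 16, 16)                           # raw network output, 3 classes
    target = torch.softmax(torch.randn(2, 3, 16, 16), dim=1)   # soft labels summing to 1 per voxel

    loss = DiceLoss(softmax=True, soft_label=True)
    print(loss(pred, target))                                  # scalar loss with mean reduction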
@@ -272,6 +279,7 @@ class GeneralizedDiceLoss(_Loss):
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -295,6 +303,8 @@ class GeneralizedDiceLoss(_Loss):
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, intersection over union is computed from each item in the batch.
                 If True, the class-weighted intersection and union areas are first summed across the batches.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -319,6 +329,7 @@ class GeneralizedDiceLoss(_Loss):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -370,13 +381,13 @@ class GeneralizedDiceLoss(_Loss):
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
-        intersection = torch.sum(target * input, reduce_axis)
 
-        ground_o = torch.sum(target, reduce_axis)
-        pred_o = torch.sum(input, reduce_axis)
-
-        denominator = ground_o + pred_o
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
+        fp *= 0.5
+        fn *= 0.5
+        denominator = 2 * (tp + fp + fn)
 
+        ground_o = torch.sum(target, reduce_axis)
         w = self.w_func(ground_o.float())
         infs = torch.isinf(w)
         if self.batch:
@@ -388,7 +399,7 @@ class GeneralizedDiceLoss(_Loss):
             w = w + infs * max_values
 
         final_reduce_dim = 0 if self.batch else 1
-        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
         denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
         f: torch.Tensor = 1.0 - (numer / denom)
 
monai/losses/tversky.py CHANGED
@@ -17,6 +17,7 @@ from collections.abc import Callable
 import torch
 from torch.nn.modules.loss import _Loss
 
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import LossReduction
 
@@ -28,6 +29,9 @@ class TverskyLoss(_Loss):
     Sadegh et al. (2017) Tversky loss function for image segmentation
     using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721)
 
+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
+
     Adapted from:
     https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631
 
@@ -46,6 +50,7 @@ class TverskyLoss(_Loss):
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -70,6 +75,8 @@ class TverskyLoss(_Loss):
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -93,6 +100,7 @@ class TverskyLoss(_Loss):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
  def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
98
106
  """
@@ -134,20 +142,15 @@ class TverskyLoss(_Loss):
134
142
  if target.shape != input.shape:
135
143
  raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
136
144
 
137
- p0 = input
138
- p1 = 1 - p0
139
- g0 = target
140
- g1 = 1 - g0
141
-
142
145
  # reducing only spatial dimensions (not batch nor channels)
143
146
  reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
144
147
  if self.batch:
145
148
  # reducing spatial dimensions and batch
146
149
  reduce_axis = [0] + reduce_axis
147
150
 
148
- tp = torch.sum(p0 * g0, reduce_axis)
149
- fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
150
- fn = self.beta * torch.sum(p1 * g0, reduce_axis)
151
+ tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
152
+ fp *= self.alpha
153
+ fn *= self.beta
151
154
  numerator = tp + self.smooth_nr
152
155
  denominator = tp + fp + fn + self.smooth_dr
153
156
 
monai/losses/utils.py ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.linalg as LA
+
+
+def compute_tp_fp_fn(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    reduce_axis: list[int],
+    ord: int,
+    soft_label: bool,
+    decoupled: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        input: the shape should be BNH[WD], where N is the number of classes.
+        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+        reduce_axis: the axis to be reduced.
+        ord: the order of the vector norm.
+        soft_label: whether the target contains non-binary values (soft labels) or not.
+            If True a soft label formulation of the loss will be used.
+        decoupled: whether the input and the target should be decoupled when computing fp and fn.
+            Only for the original implementation when soft_label is False.
+
+    Adapted from:
+        https://github.com/zifuwanggg/JDTLosses
+    """
+
+    # the original implementation that is erroneous with soft labels
+    if ord == 1 and not soft_label:
+        tp = torch.sum(input * target, dim=reduce_axis)
+        # the original implementation of Dice and Jaccard loss
+        if decoupled:
+            fp = torch.sum(input, dim=reduce_axis) - tp
+            fn = torch.sum(target, dim=reduce_axis) - tp
+        # the original implementation of Tversky loss
+        else:
+            fp = torch.sum(input * (1 - target), dim=reduce_axis)
+            fn = torch.sum((1 - input) * target, dim=reduce_axis)
+    # the new implementation that is correct with soft labels
+    # and it is identical to the original implementation with hard labels
+    else:
+        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
+        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
+        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)
+
+        if ord > 1:
+            pred_o = torch.pow(pred_o, exponent=ord)
+            ground_o = torch.pow(ground_o, exponent=ord)
+            difference = torch.pow(difference, exponent=ord)
+
+        tp = (pred_o + ground_o - difference) / 2
+        fp = pred_o - tp
+        fn = ground_o - tp
+
+    return tp, fp, fn
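With `ord=1` the norm-based branch gives `tp = (||p||_1 + ||g||_1 - ||p - g||_1) / 2 = sum(min(p, g))`, which coincides with the original `sum(p * g)` whenever the target is binary, as the comments state. A quick hedged sanity check of that equivalence:

    import torch

    from monai.losses.utils import compute_tp_fp_fn

    pred = torch.rand(2, 3, 8, 8)                    # probabilities in [0, 1)
    hard = (torch.rand(2, 3, 8, 8) > 0.5).float()    # binary target
    axes = [2, 3]

    tp_old, fp_old, fn_old = compute_tp_fp_fn(pred, hard, axes, ord=1, soft_label=False)
    tp_new, fp_new, fn_new = compute_tp_fp_fn(pred, hard, axes, ord=1, soft_label=True)
    assert torch.allclose(tp_old, tp_new, atol=1e-6)
    assert torch.allclose(fp_old, fp_new, atol=1e-6)
    assert torch.allclose(fn_old, fn_new, atol=1e-6)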
@@ -56,7 +56,7 @@ def build_sincos_position_embedding(
         grid_h = torch.arange(h, dtype=torch.float32)
         grid_w = torch.arange(w, dtype=torch.float32)
 
-        grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")
+        grid_h, grid_w = torch.meshgrid(grid_h, grid_w)
 
         if embed_dim % 4 != 0:
             raise AssertionError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
@@ -75,7 +75,7 @@ def build_sincos_position_embedding(
         grid_w = torch.arange(w, dtype=torch.float32)
         grid_d = torch.arange(d, dtype=torch.float32)
 
-        grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")
+        grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d)
 
         if embed_dim % 6 != 0:
             raise AssertionError("Embed dimension must be divisible by 6 for 3D sin-cos position embedding")
@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ class SABlock(nn.Module):
         )
         self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ class SABlock(nn.Module):
 
         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@ class SABlock(nn.Module):
                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
             if self.causal:
+                if attn_mask is not None:
+                    raise ValueError("Causal attention does not support attention masks.")
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
-            att_mat = att_mat.softmax(dim=-1)
+            if attn_mask is not None:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
+            att_mat = att_mat.softmax(dim=-1)
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
@@ -90,8 +90,10 @@ class TransformerBlock(nn.Module):
             use_flash_attention=use_flash_attention,
         )
 
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
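A short hedged example of the new mask argument: a padding mask of shape B x sequence_length with zeros at positions that must not be attended to (the block hyperparameters below are arbitrary):

    import torch

    from monai.networks.blocks import TransformerBlock

    block = TransformerBlock(hidden_size=64, mlp_dim=128, num_heads=4)
    x = torch.randn(2, 10, 64)        # B x tokens x hidden_size
    attn_mask = torch.ones(2, 10)
    attn_mask[:, 7:] = 0              # ignore the last three (padded) tokens

    out = block(x, attn_mask=attn_mask)
    print(out.shape)                  # torch.Size([2, 10, 64])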