monai-weekly 1.4.dev2435__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/__init__.py CHANGED
@@ -13,9 +13,51 @@ from __future__ import annotations
 
 import os
 import sys
-
+import logging
+import warnings
 
 from ._version import get_versions
 
+
+old_showwarning = warnings.showwarning
+
+
+def custom_warning_handler(message, category, filename, lineno, file=None, line=None):
+    ignore_files = ["ignite/handlers/checkpoint", "modelopt/torch/quantization/tensor_quant"]
+    if any(ignore in filename for ignore in ignore_files):
+        return
+    old_showwarning(message, category, filename, lineno, file, line)
+
+
+class DeprecatedTypesWarningFilter(logging.Filter):
+    def filter(self, record):
+        message_bodies_to_ignore = [
+            "np.bool8",
+            "np.object0",
+            "np.int0",
+            "np.uint0",
+            "np.void0",
+            "np.str0",
+            "np.bytes0",
+            "@validator",
+            "@root_validator",
+            "class-based `config`",
+            "pkg_resources",
+            "Implicitly cleaning up",
+        ]
+        for message in message_bodies_to_ignore:
+            if message in record.getMessage():
+                return False
+        return True
+
+
+# workaround for https://github.com/Project-MONAI/MONAI/issues/8060
+# TODO: remove this workaround after upstream fixed the warning
+# Set the custom warning handler to filter warning
+warnings.showwarning = custom_warning_handler
+# Get the logger for warnings and add the filter to the logger
+logging.getLogger("py.warnings").addFilter(DeprecatedTypesWarningFilter())
+
+
 PY_REQUIRED_MAJOR = 3
 PY_REQUIRED_MINOR = 9
 
@@ -93,4 +135,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "fa1ef8be157d5eb96de17aa78642384f68d99397"
+__commit_id__ = "d02ba11d8069870d71316a616f047c499627c71c"
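The filter added above only sees warnings that are routed through the logging system. A minimal, self-contained sketch of the same mechanism (filter name and messages are hypothetical; `logging.captureWarnings(True)` stands in for however the hosting application routes warnings into the "py.warnings" logger):

import logging
import warnings

logging.basicConfig()  # ensure a handler exists for the demo


class DeprecationNoiseFilter(logging.Filter):  # hypothetical, mirrors DeprecatedTypesWarningFilter
    def filter(self, record):
        # returning False drops the record, True lets it through
        return "np.bool8" not in record.getMessage()


logging.captureWarnings(True)  # route warnings.warn() through the "py.warnings" logger
logging.getLogger("py.warnings").addFilter(DeprecationNoiseFilter())

warnings.warn("np.bool8 is a deprecated alias")  # silenced by the filter
warnings.warn("unrelated warning")               # still logged

Note that `custom_warning_handler` above takes the complementary path: it replaces `warnings.showwarning` directly, so it filters by the file that raised the warning rather than by the message body.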
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-09-01T02:28:54+0000",
+ "date": "2024-09-08T02:25:56+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "d311b1d7b12a95dd7de995b507ffbb5ed413bab6",
- "version": "1.4.dev2435"
+ "full-revisionid": "0d9ab7da5ba0cbc2df3de3f7397c58ac1fe80598",
+ "version": "1.4.dev2436"
 }
 '''  # END VERSION_JSON
 
monai/apps/vista3d/inferer.py CHANGED
@@ -100,7 +100,7 @@ def point_based_window_inferer(
             point_labels=point_labels,
             class_vector=class_vector,
             prompt_class=prompt_class,
-            patch_coords=unravel_slice,
+            patch_coords=[unravel_slice],
             prev_mask=prev_mask,
             **kwargs,
         )
monai/bundle/config_parser.py CHANGED
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any
 
 from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
 from monai.bundle.reference_resolver import ReferenceResolver
-from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
+from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv
 from monai.config import PathLike
 from monai.utils import ensure_tuple, look_up_option, optional_import
 from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates
@@ -423,8 +423,10 @@ class ConfigParser:
         if isinstance(files, str) and not Path(files).is_file() and "," in files:
             files = files.split(",")
         for i in ensure_tuple(files):
-            for k, v in (cls.load_config_file(i, **kwargs)).items():
-                parser[k] = v
+            config_dict = cls.load_config_file(i, **kwargs)
+            for k, v in config_dict.items():
+                merge_kv(parser, k, v)
+
         return parser.get()  # type: ignore
 
     @classmethod
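With `merge_kv` in the loading loop, a key prefixed with `+` in a later config file extends the entry from an earlier file instead of replacing it. A hedged sketch of the intended effect (file names and contents are hypothetical):

from monai.bundle import ConfigParser

# base.yaml:     {"transforms": ["LoadImaged"], "lr": 0.001}
# override.yaml: {"+transforms": ["ScaleIntensityd"], "lr": 0.0001}
cfg = ConfigParser.load_config_files(["base.yaml", "override.yaml"])
# expected: cfg["transforms"] == ["LoadImaged", "ScaleIntensityd"]  # "+" concatenates the lists
# expected: cfg["lr"] == 0.0001                                     # un-prefixed keys still override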
monai/bundle/scripts.py CHANGED
@@ -32,7 +32,7 @@ from monai._version import get_versions
 from monai.apps.utils import _basename, download_url, extractall, get_logger
 from monai.bundle.config_item import ConfigComponent
 from monai.bundle.config_parser import ConfigParser
-from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
+from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
 from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
 from monai.config import IgniteInfo, PathLike
 from monai.data import load_net_with_metadata, save_net_with_metadata
@@ -105,7 +105,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
         if isinstance(v, dict) and isinstance(args_.get(k), dict):
             args_[k] = update_kwargs(args_[k], ignore_none, **v)
         else:
-            args_[k] = v
+            merge_kv(args_, k, v)
     return args_
 
 
monai/bundle/utils.py CHANGED
@@ -13,6 +13,7 @@ from __future__ import annotations
 
 import json
 import os
+import warnings
 import zipfile
 from typing import Any
 
@@ -21,12 +22,21 @@ from monai.utils import optional_import
 
 yaml, _ = optional_import("yaml")
 
-__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"]
+__all__ = [
+    "ID_REF_KEY",
+    "ID_SEP_KEY",
+    "EXPR_KEY",
+    "MACRO_KEY",
+    "MERGE_KEY",
+    "DEFAULT_MLFLOW_SETTINGS",
+    "DEFAULT_EXP_MGMT_SETTINGS",
+]
 
 ID_REF_KEY = "@"  # start of a reference to a ConfigItem
 ID_SEP_KEY = "::"  # separator for the ID of a ConfigItem
 EXPR_KEY = "$"  # start of a ConfigExpression
 MACRO_KEY = "%"  # start of a macro of a config
+MERGE_KEY = "+"  # prefix indicating merge instead of override in case of multiple configs.
 
 _conf_values = get_config_values()
 
@@ -233,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
         parser.read_config(f=cdata)
 
     return parser
+
+
+def merge_kv(args: dict | Any, k: str, v: Any) -> None:
+    """
+    Update the `args` dict-like object with the key/value pair `k` and `v`.
+    """
+    if k.startswith(MERGE_KEY):
+        """
+        Both values associated with a `+`-prefixed key must be of `dict` or `list` type.
+        `dict` values will be merged, `list` values - concatenated.
+        """
+        id = k[1:]
+        if id in args:
+            if isinstance(v, dict) and isinstance(args[id], dict):
+                args[id].update(v)
+            elif isinstance(v, list) and isinstance(args[id], list):
+                args[id].extend(v)
+            else:
+                raise ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")
+        else:
+            warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.")
+            args[id] = v
+    else:
+        args[k] = v
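A short illustration of the three `merge_kv` branches, using a plain dict as the target (values are hypothetical):

from monai.bundle.utils import merge_kv

target = {"transforms": ["LoadImaged"], "optimizer": {"lr": 1e-4}}
merge_kv(target, "+transforms", ["ScaleIntensityd"])    # lists are concatenated
merge_kv(target, "+optimizer", {"weight_decay": 1e-5})  # dicts are merged
merge_kv(target, "device", "cuda:0")                    # un-prefixed keys are set as-is
# target == {"transforms": ["LoadImaged", "ScaleIntensityd"],
#            "optimizer": {"lr": 1e-4, "weight_decay": 1e-5},
#            "device": "cuda:0"}

A `+`-prefixed key whose value type does not match the existing entry raises `ValueError`, and a `+`-prefixed key with no existing counterpart falls back to a plain copy with a warning.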
monai/handlers/__init__.py CHANGED
@@ -40,5 +40,6 @@ from .smartcache_handler import SmartCacheHandler
 from .stats_handler import StatsHandler
 from .surface_distance import SurfaceDistance
 from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
+from .trt_handler import TrtHandler
 from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
 from .validation_handler import ValidationHandler
monai/handlers/trt_handler.py ADDED
@@ -0,0 +1,61 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from monai.config import IgniteInfo
+from monai.networks import trt_compile
+from monai.utils import min_version, optional_import
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+if TYPE_CHECKING:
+    from ignite.engine import Engine
+else:
+    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+
+class TrtHandler:
+    """
+    TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.
+    Usage example::
+        handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
+        handler.attach(engine)
+        engine.run()
+    """
+
+    def __init__(self, model, base_path, args=None, submodule=None):
+        """
+        Args:
+            base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
+            args: passed to trt_compile(). See trt_compile() for details.
+            submodule: hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'
+        """
+        self.model = model
+        self.base_path = base_path
+        self.args = args
+        self.submodule = submodule
+
+    def attach(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        self.logger = engine.logger
+        engine.add_event_handler(Events.STARTED, self)
+
+    def __call__(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger)
monai/metrics/generalized_dice.py CHANGED
@@ -14,34 +14,47 @@ from __future__ import annotations
 import torch
 
 from monai.metrics.utils import do_metric_reduction, ignore_background
-from monai.utils import MetricReduction, Weight, look_up_option
+from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option
 
 from .metric import CumulativeIterationMetric
 
 
 class GeneralizedDiceScore(CumulativeIterationMetric):
-    """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
+    """
+    Compute the Generalized Dice Score metric between tensors.
 
+    This metric is the complement of the Generalized Dice Loss defined in:
     Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
-    loss function for highly unbalanced segmentations. DLMIA 2017.
+        loss function for highly unbalanced segmentations. DLMIA 2017.
 
-    The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
-    or batch-first tensors, i.e., CHW[D] or BCHW[D].
+    The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
 
     Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
 
     Args:
-        include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
+        include_background: Whether to include the background class (assumed to be in channel 0) in the
             score computation. Defaults to True.
-        reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
-            {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
-        weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
+        reduction: Define mode of reduction to the metrics. Available reduction modes:
+            {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. If "none", will not do reduction.
+        weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
             ground truth volume into a weight factor. Defaults to ``"square"``.
 
     Raises:
-        ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}.
+        ValueError: When the `reduction` is not one of the MetricReduction enum values.
     """
 
+    @deprecated_arg_default(
+        "reduction",
+        old_default=MetricReduction.MEAN_BATCH,
+        new_default=MetricReduction.MEAN,
+        since="1.4.0",
+        replaced="1.5.0",
+        msg_suffix=(
+            "Old versions computed `mean` when `mean_batch` was provided due to a bug in reduction. "
+            "If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
+        ),
+    )
     def __init__(
         self,
         include_background: bool = True,
@@ -50,79 +63,90 @@ class GeneralizedDiceScore(CumulativeIterationMetric):
     ) -> None:
         super().__init__()
         self.include_background = include_background
-        reduction_options = [
-            "none",
-            "mean_batch",
-            "sum_batch",
-            MetricReduction.NONE,
-            MetricReduction.MEAN_BATCH,
-            MetricReduction.SUM_BATCH,
-        ]
-        self.reduction = reduction
-        if self.reduction not in reduction_options:
-            raise ValueError(f"reduction must be one of {reduction_options}")
+        self.reduction = look_up_option(reduction, MetricReduction)
         self.weight_type = look_up_option(weight_type, Weight)
+        self.sum_over_classes = self.reduction in {
+            MetricReduction.SUM,
+            MetricReduction.MEAN,
+            MetricReduction.MEAN_CHANNEL,
+            MetricReduction.SUM_CHANNEL,
+        }
 
     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
-        """Computes the Generalized Dice Score and returns a tensor with its per image values.
+        """
+        Computes the Generalized Dice Score and returns a tensor with its per image values.
 
         Args:
-            y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
+            y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
                 where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
-            y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
+            y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
+
+        Returns:
+            torch.Tensor: Generalized Dice Score averaged across batch and class.
 
         Raises:
-            ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
+            ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
         """
         return compute_generalized_dice(
-            y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
+            y_pred=y_pred,
+            y=y,
+            include_background=self.include_background,
+            weight_type=self.weight_type,
+            sum_over_classes=self.sum_over_classes,
         )
 
+    @deprecated_arg(
+        "reduction",
+        since="1.3.3",
+        removed="1.7.0",
+        msg_suffix="`reduction` will be ignored. Set the reduction during init, as generalized dice needs it during compute.",
+    )
     def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
         """
         Execute reduction logic for the output of `compute_generalized_dice`.
 
-        Args:
-            reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
-                Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
-                Defaults to ``"mean"``. If "none", will not do reduction.
+        Returns:
+            torch.Tensor: Aggregated metric value.
+
+        Raises:
+            ValueError: If the data to aggregate is not a PyTorch Tensor.
         """
         data = self.get_buffer()
         if not isinstance(data, torch.Tensor):
             raise ValueError("The data to aggregate must be a PyTorch Tensor.")
 
-        # Validate reduction argument if specified
-        if reduction is not None:
-            reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
-            if reduction not in reduction_options:
-                raise ValueError(f"reduction must be one of {reduction_options}")
-
         # Do metric reduction and return
-        f, _ = do_metric_reduction(data, reduction or self.reduction)
+        f, _ = do_metric_reduction(data, self.reduction)
 
         return f
 
 
 def compute_generalized_dice(
-    y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE
+    y_pred: torch.Tensor,
+    y: torch.Tensor,
+    include_background: bool = True,
+    weight_type: Weight | str = Weight.SQUARE,
+    sum_over_classes: bool = False,
 ) -> torch.Tensor:
-    """Computes the Generalized Dice Score and returns a tensor with its per image values.
+    """
+    Computes the Generalized Dice Score and returns a tensor with its per image values.
 
     Args:
-        y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
+        y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
             and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
             remaining are the spatial dimensions.
-        y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
-        include_background (bool, optional): whether to include score computation on the first channel of the
+        y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
+        include_background: Whether to include score computation on the first channel of the
             predicted output. Defaults to True.
         weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
             transform ground truth volume into a weight factor. Defaults to ``"square"``.
+        sum_over_classes (bool): Whether to sum the numerator and denominator across all classes before the final computation.
 
     Returns:
-        torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
+        torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
 
     Raises:
-        ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
+        ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
             or `y_pred` and `y` don't have the same shape.
     """
     # Ensure tensors have at least 3 dimensions and have the same shape
@@ -158,16 +182,21 @@ def compute_generalized_dice(
     b[infs] = 0
     b[infs] = torch.max(b)
 
-    # Compute the weighted numerator and denominator, summing along the class axis
-    numer = 2.0 * (intersection * w).sum(dim=1)
-    denom = (denominator * w).sum(dim=1)
+    # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
+    if sum_over_classes:
+        numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
+        denom = (denominator * w).sum(dim=1, keepdim=True)
+        y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
+    else:
+        numer = 2.0 * (intersection * w)
+        denom = denominator * w
+        y_pred_o = y_pred_o
 
     # Compute the score
     generalized_dice_score = numer / denom
 
     # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
     # Where denom == 0 but the prediction volume is not 0, score is 0
-    y_pred_o = y_pred_o.sum(dim=-1)
     denom_zeros = denom == 0
     generalized_dice_score[denom_zeros] = torch.where(
         (y_pred_o == 0)[denom_zeros],
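A hedged usage sketch of the updated metric, with synthetic one-hot tensors: `reduction="mean_batch"` keeps per-class scores and averages them over the batch, while the new default `"mean"` sums numerator and denominator over classes before dividing.

import torch

from monai.metrics import GeneralizedDiceScore

y = torch.zeros(2, 3, 8, 8, 8)  # batch of 2, 3 one-hot classes
y[:, 0] = 1.0                   # all-background ground truth
y_pred = y.clone()              # perfect prediction

metric = GeneralizedDiceScore(include_background=True, reduction="mean_batch")
metric(y_pred=y_pred, y=y)  # accumulates per-image, per-class scores
print(metric.aggregate())   # shape [3]: per-class scores averaged over the batch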
monai/networks/__init__.py CHANGED
@@ -11,7 +11,9 @@
 
 from __future__ import annotations
 
+from .trt_compiler import trt_compile
 from .utils import (
+    add_casts_around_norms,
     convert_to_onnx,
     convert_to_torchscript,
     convert_to_trt,
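`trt_compile` is now re-exported at the `monai.networks` level. A hedged sketch of direct use (the network choice, plan path, and precision argument are illustrative; a TensorRT-enabled GPU environment is assumed):

import torch

from monai.networks import trt_compile
from monai.networks.nets import SegResNet

model = SegResNet(spatial_dims=3, in_channels=1, out_channels=2).cuda()
# Wraps the model so a TRT engine is built and cached under the given base path
# on first forward; args are forwarded to the TRT builder (see trt_compile for details).
trt_compile(model, "/models/segresnet", args={"precision": "fp16"})
out = model(torch.randn(1, 1, 96, 96, 96, device="cuda"))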
monai/networks/nets/swin_unetr.py CHANGED
@@ -320,7 +320,7 @@ class SwinUNETR(nn.Module):
         )
 
     def forward(self, x_in):
-        if not torch.jit.is_scripting():
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             self._check_input_size(x_in.shape[2:])
         hidden_states_out = self.swinViT(x_in, self.normalize)
         enc0 = self.encoder1(x_in)
@@ -1046,14 +1046,14 @@ class SwinTransformer(nn.Module):
 
     def proj_out(self, x, normalize=False):
         if normalize:
-            x_shape = x.size()
+            x_shape = x.shape
+            # Force trace() to generate a constant by casting to int
+            ch = int(x_shape[1])
             if len(x_shape) == 5:
-                n, ch, d, h, w = x_shape
                 x = rearrange(x, "n c d h w -> n d h w c")
                 x = F.layer_norm(x, [ch])
                 x = rearrange(x, "n d h w c -> n c d h w")
             elif len(x_shape) == 4:
-                n, ch, h, w = x_shape
                 x = rearrange(x, "n c h w -> n h w c")
                 x = F.layer_norm(x, [ch])
                 x = rearrange(x, "n h w c -> n c h w")
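Both SwinUNETR changes serve `torch.jit.trace`: the input-size check is skipped while tracing, and `int(x_shape[1])` turns the channel dimension into a Python constant so `F.layer_norm(x, [ch])` does not capture a traced tensor shape. A minimal sketch of the guard pattern (toy function, not MONAI code):

import torch


def forward(x):
    # Eager-only validation: torch.jit.is_tracing() is True inside torch.jit.trace,
    # torch.jit.is_scripting() is True under torch.jit.script compilation.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        if x.shape[-1] % 2:
            raise ValueError("last dim must be even")
    return x * 2


traced = torch.jit.trace(forward, torch.randn(2, 4))  # the check is skipped during tracing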
monai/networks/nets/vista3d.py CHANGED
@@ -336,7 +336,7 @@ class VISTA3D(nn.Module):
     def forward(
         self,
         input_images: torch.Tensor,
-        patch_coords: Sequence[slice] | None = None,
+        patch_coords: list[Sequence[slice]] | None = None,
         point_coords: torch.Tensor | None = None,
         point_labels: torch.Tensor | None = None,
         class_vector: torch.Tensor | None = None,
@@ -364,8 +364,12 @@ class VISTA3D(nn.Module):
                 the points are for zero-shot or supported class. When class_vector and point_coords are both
                 provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
                 will be considered novel class.
-            patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
-                This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
+            patch_coords: a list of sequences of the python slice objects representing the patch coordinates during sliding window
+                inference. This value is passed from sliding_window_inferer.
+                This is an indicator for training phase or validation phase.
+                Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will include
+                coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
+                functions using patch_coords will by default use patch_coords[0].
             labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
             label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
                 this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
@@ -395,14 +399,14 @@ class VISTA3D(nn.Module):
             if val_point_sampler is None:
                 # TODO: think about how to refactor this part.
                 val_point_sampler = self.sample_points_patch_val
-            point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
+            point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
             if prompt_class[0].item() == 0:  # type: ignore
                 point_labels[0] = -1  # type: ignore
             labels, prev_mask = None, None
         elif point_coords is not None:
             # If not performing patch-based point-only validation, use user-provided click points for inference.
             # The point clicks are in original image space; convert them to the current patch-coordinate space.
-            point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels)  # type: ignore
+            point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels)  # type: ignore
 
         if point_coords is not None and point_labels is not None:
             # remove points that used for padding purposes (point_label = -1)
@@ -455,7 +459,7 @@ class VISTA3D(nn.Module):
             logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
         if prev_mask is not None and patch_coords is not None:
             logits = self.connected_components_combine(
-                prev_mask[patch_coords].transpose(1, 0).to(logits.device),
+                prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
                 logits[mapping_index],
                 point_coords,  # type: ignore
                 point_labels,  # type: ignore
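The signature change ripples through every `patch_coords` consumer shown above: single-patch callers (such as `point_based_window_inferer`, whose `patch_coords=[unravel_slice]` change appears earlier in this diff) now wrap their coordinates in a list, and point-prompt code paths read `patch_coords[0]`. A type-level sketch of the new convention:

from typing import Sequence

# old annotation: patch_coords: Sequence[slice] | None
# new annotation: patch_coords: list[Sequence[slice]] | None
patch: Sequence[slice] = (slice(0, 1), slice(0, 1), slice(0, 96), slice(0, 96), slice(0, 96))
patch_coords: list[Sequence[slice]] = [patch]  # one entry per patch in the sliding-window batch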