monai-weekly 1.4.dev2434__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/__init__.py CHANGED
@@ -93,4 +93,4 @@ except BaseException:
 
     if MONAIEnvVars.debug():
         raise
-__commit_id__ = "a5fbe716378948630783deef8ee435e7e3bdc918"
+__commit_id__ = "fa1ef8be157d5eb96de17aa78642384f68d99397"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-08-25T02:21:56+0000",
+ "date": "2024-09-01T02:28:54+0000",
  "dirty": false,
  "error": null,
- "full-revisionid": "dc611d231ba670004b1da1b011fe140375fb91af",
- "version": "1.4.dev2434"
+ "full-revisionid": "d311b1d7b12a95dd7de995b507ffbb5ed413bab6",
+ "version": "1.4.dev2435"
 }
 '''  # END VERSION_JSON
 
monai/apps/vista3d/inferer.py ADDED
@@ -0,0 +1,177 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+
+from monai.data.meta_tensor import MetaTensor
+from monai.utils import optional_import
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
+
+__all__ = ["point_based_window_inferer"]
+
+
+def point_based_window_inferer(
+    inputs: torch.Tensor | MetaTensor,
+    roi_size: Sequence[int],
+    predictor: torch.nn.Module,
+    point_coords: torch.Tensor,
+    point_labels: torch.Tensor,
+    class_vector: torch.Tensor | None = None,
+    prompt_class: torch.Tensor | None = None,
+    prev_mask: torch.Tensor | MetaTensor | None = None,
+    point_start: int = 0,
+    center_only: bool = True,
+    margin: int = 5,
+    **kwargs: Any,
+) -> torch.Tensor:
+    """
+    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
+    The inferer crops the input image into patches centered at the given points, runs inference on each patch,
+    stitches the outputs together by averaging the overlaps, and returns the segmented mask.
+
+    Args:
+        inputs: [1CHWD], input image to be processed.
+        roi_size: the spatial window size for inferences.
+            If a component of `roi_size` is None or non-positive, the corresponding dimension of the input
+            image will be used instead. For example, `roi_size=(32, -1)` will be adapted
+            to `(32, 64)` if the second spatial dimension size of the image is `64`.
+        predictor: the model. For vista3d, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].
+            Add transpose=True in kwargs for vista3d.
+        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each with N points.
+        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
+            2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
+        class_vector: [B]. Used for class-head automatic segmentation. Can be None.
+        prompt_class: [B]. The same as class_vector, representing the point class; it informs the point head about
+            supported classes or zero-shot. Not used for automatic segmentation. If None, the point head defaults
+            to supported class segmentation.
+        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
+        point_start: only use points starting from this index. All points before this index are used to generate
+            prev_mask. This avoids recalculating points from previous iterations if prev_mask is given.
+        center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point.
+        margin: if center_only is false, this value is the distance between the point and the patch boundary.
+    Returns:
+        stitched_output: [1, B, H, W, D]. The value is before sigmoid.
+    Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
+    """
+    if not point_coords.shape[0] == 1:
+        raise ValueError("Only supports single object point click.")
+    if not len(inputs.shape) == 5:
+        raise ValueError("Input image should be 5D.")
+    image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
+    point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
+    prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
+    stitched_output = None
+    for p in point_coords[0][point_start:]:
+        lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
+        ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
+        lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
+        for i in range(len(lx_)):
+            for j in range(len(ly_)):
+                for k in range(len(lz_)):
+                    lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
+                    unravel_slice = [
+                        slice(None),
+                        slice(None),
+                        slice(int(lx), int(rx)),
+                        slice(int(ly), int(ry)),
+                        slice(int(lz), int(rz)),
+                    ]
+                    batch_image = image[unravel_slice]
+                    output = predictor(
+                        batch_image,
+                        point_coords=point_coords,
+                        point_labels=point_labels,
+                        class_vector=class_vector,
+                        prompt_class=prompt_class,
+                        patch_coords=unravel_slice,
+                        prev_mask=prev_mask,
+                        **kwargs,
+                    )
+                    if stitched_output is None:
+                        stitched_output = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                        stitched_mask = torch.zeros(
+                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+                        )
+                    stitched_output[unravel_slice] += output.to("cpu")
+                    stitched_mask[unravel_slice] = 1
+    # if stitched_mask is 0, then NaN value
+    stitched_output = stitched_output / stitched_mask
+    # revert padding
+    stitched_output = stitched_output[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    stitched_mask = stitched_mask[
+        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+    ]
+    if prev_mask is not None:
+        prev_mask = prev_mask[
+            :,
+            :,
+            pad[4] : image.shape[-3] - pad[5],
+            pad[2] : image.shape[-2] - pad[3],
+            pad[0] : image.shape[-1] - pad[1],
+        ]
+        prev_mask = prev_mask.to("cpu")  # type: ignore
+        # for regions that were never inferred, use the previous mask
+        stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]
+    if isinstance(inputs, torch.Tensor):
+        inputs = MetaTensor(inputs)
+    if not hasattr(stitched_output, "meta"):
+        stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
+    return stitched_output
+
+
+def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
+    """Helper function to get the window index."""
+    if p - roi // 2 < 0:
+        left, right = 0, roi
+    elif p + roi // 2 > s:
+        left, right = s - roi, s
+    else:
+        left, right = int(p) - roi // 2, int(p) + roi // 2
+    return left, right
+
+
+def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
+    """Get the window index."""
+    left, right = _get_window_idx_c(p, roi, s)
+    if center_only:
+        return [left], [right]
+    left_most = max(0, p - roi + margin)
+    right_most = min(s, p + roi - margin)
+    left_list = [left_most, right_most - roi, left]
+    right_list = [left_most + roi, right_most, right]
+    return left_list, right_list
+
+
+def _pad_previous_mask(
+    inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
+) -> tuple[torch.Tensor | MetaTensor, list[int]]:
+    """Helper function to pad inputs."""
+    pad_size = []
+    for k in range(len(inputs.shape) - 1, 1, -1):
+        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
+        half = diff // 2
+        pad_size.extend([half, diff - half])
+    if any(pad_size):
+        inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue)  # type: ignore
+    return inputs, pad_size
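A minimal usage sketch of the new inferer (not from the package; the image shape, dummy predictor, and point values are illustrative assumptions):

```python
import torch
from monai.apps.vista3d.inferer import point_based_window_inferer


class DummyPredictor(torch.nn.Module):
    """Stands in for a VISTA3D-style model; returns one-channel logits per patch."""

    def forward(self, x, **kwargs):
        return torch.zeros(1, 1, *x.shape[-3:])  # [1, B=1, h, w, d] logits


image = torch.rand(1, 1, 96, 96, 96)  # [1, C, H, W, D]
points = torch.tensor([[[48, 48, 48]]], dtype=torch.float32)  # [B=1, N=1, 3]
labels = torch.tensor([[1]])  # one positive click

logits = point_based_window_inferer(
    inputs=image,
    roi_size=(64, 64, 64),
    predictor=DummyPredictor(),
    point_coords=points,
    point_labels=labels,
)
# [1, 1, 96, 96, 96]; values are pre-sigmoid, and voxels never covered by a
# patch are NaN (see the stitched_mask division in the code above).
print(logits.shape)
```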
monai/apps/vista3d/sampler.py ADDED
@@ -0,0 +1,179 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+import random
+from collections.abc import Callable, Sequence
+from typing import Any
+
+import numpy as np
+import torch
+from torch import Tensor
+
+ENABLE_SPECIAL = True
+SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
+MERGE_LIST = {
+    1: [25, 26],  # hepatic tumor and vessel merge into liver
+    4: [24],  # pancreatic tumor merges into pancreas
+    132: [57],  # overlap with trachea merges into airway
+}
+
+__all__ = ["sample_prompt_pairs"]
+
+
+def _get_point_label(id: int) -> tuple[int, int]:
+    if id in SPECIAL_INDEX and ENABLE_SPECIAL:
+        return 2, 3
+    else:
+        return 0, 1
+
+
+def sample_prompt_pairs(
+    labels: Tensor,
+    label_set: Sequence[int],
+    max_prompt: int | None = None,
+    max_foreprompt: int | None = None,
+    max_backprompt: int = 1,
+    max_point: int = 20,
+    include_background: bool = False,
+    drop_label_prob: float = 0.2,
+    drop_point_prob: float = 0.2,
+    point_sampler: Callable | None = None,
+    **point_sampler_kwargs: Any,
+) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
+    """
+    Sample training pairs for VISTA3D training.
+
+    Args:
+        labels: [1, 1, H, W, D], ground truth labels.
+        label_set: the label list for the specific dataset. Note that if 0 is included in label_set,
+            it will be added into automatic branch training. It is recommended to remove 0 from label_set
+            for multi-partially-labeled-dataset training, and to add 0 when finetuning on a specific dataset.
+            The reason is that a region labeled 0 in one partially labeled dataset may contain foreground
+            in another dataset.
+        max_prompt: max total number of prompts, including foreground and background.
+        max_foreprompt: max number of prompts from foreground.
+        max_backprompt: max number of prompts from background.
+        max_point: maximum number of points for each object.
+        include_background: whether to include 0 in the training prompts. If included, background 0 is treated
+            the same as foreground and points will be sampled. Set to true only if the user wants to segment
+            background 0 with point clicks; otherwise it should always be false.
+        drop_label_prob: probability to drop the label prompt.
+        drop_point_prob: probability to drop the point prompt.
+        point_sampler: sampler to augment masks with supervoxels.
+        point_sampler_kwargs: arguments for point_sampler.
+
+    Returns:
+        tuple:
+            - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
+              training automatic segmentation.
+            - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
+              for each class. Note that background label prompts require matching points as well
+              (e.g., [0, 0, 0] is used).
+            - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
+              labels for each point (negative or positive). -1 is used for padding the background
+              label prompt and will be ignored.
+            - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
+              for label indexing during training. If label_prompt is None, prompt_class is used to
+              identify point classes.
+
+    """
+
+    # class label number
+    if not labels.shape[0] == 1:
+        raise ValueError("only support batch size 1")
+    labels = labels[0, 0]
+    device = labels.device
+    unique_labels = labels.unique().cpu().numpy().tolist()
+    if include_background:
+        unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)))
+    else:
+        unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0})
+    background_labels = list(set(label_set) - set(unique_labels))
+    # during training, balance background and foreground prompts
+    if max_backprompt is not None:
+        if len(background_labels) > max_backprompt:
+            random.shuffle(background_labels)
+            background_labels = background_labels[:max_backprompt]
+
+    if max_foreprompt is not None:
+        if len(unique_labels) > max_foreprompt:
+            random.shuffle(unique_labels)
+            unique_labels = unique_labels[:max_foreprompt]
+
+    if max_prompt is not None:
+        if len(unique_labels) + len(background_labels) > max_prompt:
+            if len(unique_labels) > max_prompt:
+                unique_labels = random.sample(unique_labels, max_prompt)
+                background_labels = []
+            else:
+                background_labels = random.sample(background_labels, max_prompt - len(unique_labels))
+    _point = []
+    _point_label = []
+    # if use regular sampling
+    if point_sampler is None:
+        num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1)
+        num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))))
+        for id in unique_labels:
+            neg_id, pos_id = _get_point_label(id)
+            plabels = labels == int(id)
+            nlabels = ~plabels
+            plabelpoints = torch.nonzero(plabels)
+            nlabelpoints = torch.nonzero(nlabels)
+            # final sampled positive points
+            num_pa = min(len(plabelpoints), num_p)
+            # final sampled negative points
+            num_na = min(len(nlabelpoints), num_n)
+            _point.append(
+                torch.stack(
+                    random.choices(plabelpoints, k=num_pa)
+                    + random.choices(nlabelpoints, k=num_na)
+                    + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na)
+                )
+            )
+            _point_label.append(
+                torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to(
+                    device
+                )
+            )
+        for _ in background_labels:
+            # pad the background labels
+            _point.append(torch.zeros(num_p + num_n, 3).to(device))  # all 0
+            _point_label.append(torch.zeros(num_p + num_n).to(device) - 1)  # -1 not a point
+    else:
+        _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs)
+        for _ in background_labels:
+            # pad the background labels
+            _point.append(torch.zeros(len(_point_label[0]), 3).to(device))  # all 0
+            _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1)  # -1 not a point
+    if len(unique_labels) == 0 and len(background_labels) == 0:
+        # if max_backprompt is 0 and len(unique_labels) is 0, there is no effective prompt and the iteration must
+        # be skipped. Handle this in the trainer.
+        label_prompt, point, point_label, prompt_class = None, None, None, None
+    else:
+        label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long()
+        point = torch.stack(_point)
+        point_label = torch.stack(_point_label)
+        prompt_class = copy.deepcopy(label_prompt)
+        if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0:
+            label_prompt = None
+            # If the label prompt is dropped, there is no need to pad points with label -1.
+            pad = len(background_labels)
+            point = point[: len(point) - pad]  # type: ignore
+            point_label = point_label[: len(point_label) - pad]
+            prompt_class = prompt_class[: len(prompt_class) - pad]
+        else:
+            if random.uniform(0, 1) < drop_point_prob:
+                point = None
+                point_label = None
+    return label_prompt, point, point_label, prompt_class
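A short sketch of how the sampler might be called in a training loop (the toy label volume and label_set are illustrative assumptions):

```python
import torch
from monai.apps.vista3d.sampler import sample_prompt_pairs

labels = torch.zeros(1, 1, 32, 32, 32, dtype=torch.long)
labels[..., 8:16, 8:16, 8:16] = 1  # one foreground object with class index 1

label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
    labels,
    label_set=[1, 2],  # class 2 is absent here, so it can be drawn as a background prompt
    max_point=5,
    drop_label_prob=0.0,
    drop_point_prob=0.0,
)
# label_prompt/prompt_class: [B, 1]; point: [B, N, 3]; point_label: [B, N] (-1 = padding)
```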
monai/apps/vista3d/transforms.py ADDED
@@ -0,0 +1,224 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from typing import Sequence
+
+import numpy as np
+import torch
+
+from monai.config import DtypeLike, KeysCollection
+from monai.transforms import MapLabelValue
+from monai.transforms.transform import MapTransform
+from monai.transforms.utils import keep_components_with_positive_points
+from monai.utils import look_up_option
+
+__all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"]
+
+
+def _get_name_to_index_mapping(labels_dict: dict | None) -> dict:
+    """get the label name to index mapping"""
+    name_to_index_mapping = {}
+    if labels_dict is not None:
+        name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()}
+    return name_to_index_mapping
+
+
+def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None:
+    """convert the label name to index"""
+    if label_prompt is not None and isinstance(label_prompt, list):
+        converted_label_prompt = []
+        # for new classes, add to the mapping
+        for l in label_prompt:
+            if isinstance(l, str) and not l.isdigit():
+                if l.lower() not in name_to_index_mapping:
+                    name_to_index_mapping[l.lower()] = len(name_to_index_mapping)
+        for l in label_prompt:
+            if isinstance(l, (int, str)):
+                converted_label_prompt.append(
+                    name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l)
+                )
+            else:
+                converted_label_prompt.append(l)
+        return converted_label_prompt
+    return label_prompt
+
+
+class VistaPreTransformd(MapTransform):
+    def __init__(
+        self,
+        keys: KeysCollection,
+        allow_missing_keys: bool = False,
+        special_index: Sequence[int] = (25, 26, 27, 28, 29, 117),
+        labels_dict: dict | None = None,
+        subclass: dict | None = None,
+    ) -> None:
+        """
+        Pre-transform for Vista3d.
+
+        It performs two functionalities:
+
+        1. If the label prompt shows the points belong to a special class (defined by special_index, e.g. tumors,
+           vessels), convert the point labels from 0 (negative)/1 (positive) to the special 2 (negative)/3 (positive).
+
+        2. If the label prompt is within the keys of subclass, convert the label prompt to its subclasses defined
+           by subclass[key]. e.g. the "lung" label is converted to ["left lung", "right lung"].
+
+        The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B,
+        where each element is an int value of length [B, N].
+
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            special_index: the indexes that define the special classes.
+            labels_dict: the label name to index mapping for the dataset.
+            subclass: a dictionary that maps a label prompt to its subclasses.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.special_index = special_index
+        self.subclass = subclass
+        self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict)
+
+    def __call__(self, data):
+        label_prompt = data.get("label_prompt", None)
+        point_labels = data.get("point_labels", None)
+        # convert the label name to index if needed
+        label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt)
+        try:
+            # The evaluator will check the prompt. An invalid prompt will be skipped here and captured by the evaluator.
+            if self.subclass is not None and label_prompt is not None:
+                _label_prompt = []
+                subclass_keys = list(map(int, self.subclass.keys()))
+                for i in range(len(label_prompt)):
+                    if label_prompt[i] in subclass_keys:
+                        _label_prompt.extend(self.subclass[str(label_prompt[i])])
+                    else:
+                        _label_prompt.append(label_prompt[i])
+                data["label_prompt"] = _label_prompt
+            if label_prompt is not None and point_labels is not None:
+                if label_prompt[0] in self.special_index:
+                    point_labels = np.array(point_labels)
+                    point_labels[point_labels == 0] = 2
+                    point_labels[point_labels == 1] = 3
+                    point_labels = point_labels.tolist()
+                data["point_labels"] = point_labels
+        except Exception:
+            # There are specific requirements for `label_prompt` and `point_labels`.
+            # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None.
+            # Those formatting errors should be captured later.
+            warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.")
+
+        return data
+
+
+class VistaPostTransformd(MapTransform):
+    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
+        """
+        Post-transform for Vista3d. It converts the model output logits into final segmentation masks.
+        If `label_prompt` is None, the output will be thresholded to sequential indexes [0, 1, 2, ...],
+        else the indexes will be [0, label_prompt[0], label_prompt[1], ...].
+        If `label_prompt` is None while `points` are provided, the model will perform postprocessing to remove
+        regions that do not contain positive points.
+
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            allow_missing_keys: don't raise exception if key is missing.
+
+        """
+        super().__init__(keys, allow_missing_keys)
+
+    def __call__(self, data):
+        """data["label_prompt"] should not contain 0"""
+        for keys in self.keys:
+            if keys in data:
+                pred = data[keys]
+                object_num = pred.shape[0]
+                device = pred.device
+                if data.get("label_prompt", None) is None and data.get("points", None) is not None:
+                    pred = keep_components_with_positive_points(
+                        pred.unsqueeze(0),
+                        point_coords=data.get("points").to(device),
+                        point_labels=data.get("point_labels").to(device),
+                    )[0]
+                pred[pred < 0] = 0.0
+                # if it's multichannel, perform argmax
+                if object_num > 1:
+                    # concatenate the background channel. Make sure the user did not provide 0 as a prompt.
+                    is_bk = torch.all(pred <= 0, dim=0, keepdim=True)
+                    pred = pred.argmax(0).unsqueeze(0).float() + 1.0
+                    pred[is_bk] = 0.0
+                else:
+                    # AsDiscrete will remove NaN
+                    # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred)
+                    pred[pred > 0] = 1.0
+                if "label_prompt" in data and data["label_prompt"] is not None:
+                    pred += 0.5  # inplace mapping to avoid cloning pred
+                    label_prompt = data["label_prompt"].to(device)  # ensure label_prompt is on the same device
+                    for i in range(1, object_num + 1):
+                        frac = i + 0.5
+                        pred[pred == frac] = label_prompt[i - 1].to(pred.dtype)
+                    pred[pred == 0.5] = 0.0
+                data[keys] = pred
+        return data
+
+
+class Relabeld(MapTransform):
+    def __init__(
+        self,
+        keys: KeysCollection,
+        label_mappings: dict[str, list[tuple[int, int]]],
+        dtype: DtypeLike = np.int16,
+        dataset_key: str = "dataset_name",
+        allow_missing_keys: bool = False,
+    ) -> None:
+        """
+        Remap the voxel labels in the input data dictionary based on the specified mapping.
+
+        This list of local -> global label mappings will be applied to each input `data[keys]`.
+        If `data[dataset_key]` is not in `label_mappings`, `label_mappings['default']` will be used.
+        If `label_mappings[data[dataset_key]]` is None, no relabeling will be performed.
+
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            label_mappings: a dictionary that specifies how local dataset class indices are mapped to the
+                global class indices. The dictionary keys are dataset names and the values are lists of
+                (local label, global label) pairs. This list of local -> global label mappings
+                will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`,
+                `label_mappings['default']` will be used. If `label_mappings[data[dataset_key]]` is None,
+                no relabeling will be performed. Set `label_mappings={}` to skip this transform entirely.
+            dtype: convert the output data to dtype, default to np.int16.
+            dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
+            allow_missing_keys: don't raise exception if key is missing.
+
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.mappers = {}
+        self.dataset_key = dataset_key
+        for name, mapping in label_mappings.items():
+            self.mappers[name] = MapLabelValue(
+                orig_labels=[int(pair[0]) for pair in mapping],
+                target_labels=[int(pair[1]) for pair in mapping],
+                dtype=dtype,
+            )
+
+    def __call__(self, data):
+        d = dict(data)
+        dataset_name = d.get(self.dataset_key, "default")
+        _m = look_up_option(dataset_name, self.mappers, default=None)
+        if _m is None:
+            return d
+        for key in self.key_iterator(d):
+            d[key] = _m(d[key])
+        return d
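A brief sketch of `Relabeld` (the dataset names and label pairs are made-up examples):

```python
import numpy as np
from monai.apps.vista3d.transforms import Relabeld

relabel = Relabeld(
    keys="label",
    label_mappings={
        "dataset_a": [(1, 25), (2, 26)],  # local 1 -> global 25, local 2 -> global 26
        "default": [(1, 1)],  # fallback mapping for unknown dataset names
    },
)
sample = {"label": np.array([[0, 1], [2, 1]]), "dataset_name": "dataset_a"}
out = relabel(sample)  # out["label"] becomes [[0, 25], [26, 25]]
```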
monai/inferers/utils.py CHANGED
@@ -300,6 +300,7 @@ def sliding_window_inference(
 
     # remove padding if image_size smaller than roi_size
     if any(pad_size):
+        kwargs.update({"pad_size": pad_size})
         for ss, output_i in enumerate(output_image_list):
             zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
             final_slicing: list[slice] = []
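A small sketch of the kwargs path this change relies on (the predictor and the `note` kwarg are illustrative): extra keyword arguments passed to `sliding_window_inference` are forwarded verbatim to the predictor, and the change above uses the same dict to expose `"pad_size"` to models that read it.

```python
import torch
from monai.inferers import sliding_window_inference

def predictor(patch: torch.Tensor, **kwargs) -> torch.Tensor:
    # user-supplied kwargs arrive here unchanged; inferer-internal entries
    # (e.g. "pad_size") may also appear in this dict
    assert kwargs.get("note") == "hello"
    return patch  # identity model: output shape matches the window shape

img = torch.rand(1, 1, 48, 48, 48)  # smaller than roi_size, so padding occurs
out = sliding_window_inference(img, (64, 64, 64), 1, predictor, note="hello")
```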
monai/networks/layers/filtering.py CHANGED
@@ -51,6 +51,8 @@ class BilateralFilter(torch.autograd.Function):
         ctx.cs = color_sigma
         ctx.fa = fast_approx
         output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         return output_data
 
     @staticmethod
@@ -139,7 +141,8 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):
             do_dsig_y,
             do_dsig_z,
         )
-
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         return output_tensor
 
     @staticmethod
@@ -301,7 +304,8 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):
             do_dsig_z,
             guidance_img,
         )
-
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         return output_tensor
 
     @staticmethod
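An illustrative note on these synchronize calls (assumption: `custom_op` stands in for an asynchronous CUDA extension call such as `_C.bilateral_filter`): CUDA kernels launch asynchronously, so a host-side synchronize guarantees the result is fully computed before the host consumes it, which is the pattern the filtering changes above adopt.

```python
import torch

def custom_op(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # placeholder for an extension kernel launch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(1024, device=device)
y = custom_op(x)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for the kernel before using y on the host
print(float(y.sum()))
```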
monai/networks/nets/vista3d.py CHANGED
@@ -23,7 +23,7 @@ import monai
 from monai.networks.blocks import MLPBlock, UnetrBasicBlock
 from monai.networks.nets import SegResNetDS2
 from monai.transforms.utils import convert_points_to_disc
-from monai.transforms.utils import get_largest_connected_component_mask_point as lcc
+from monai.transforms.utils import keep_merge_components_with_points as lcc
 from monai.transforms.utils import sample_points_from_label
 from monai.utils import optional_import, unsqueeze_left, unsqueeze_right
 
@@ -78,6 +78,35 @@ class VISTA3D(nn.Module):
         self.NINF_VALUE = -9999
         self.PINF_VALUE = 9999
 
+    def update_slidingwindow_padding(
+        self,
+        pad_size: list | None,
+        labels: torch.Tensor | None,
+        prev_mask: torch.Tensor | None,
+        point_coords: torch.Tensor | None,
+    ):
+        """
+        The image has been padded by the sliding window inferer.
+        The related padding needs to be performed outside of the sliding window inferer.
+
+        Args:
+            pad_size: padding size passed from the sliding window inferer.
+            labels: image label ground truth.
+            prev_mask: previous segmentation mask.
+            point_coords: point click coordinates.
+        """
+        if pad_size is None:
+            return labels, prev_mask, point_coords
+        if labels is not None:
+            labels = F.pad(labels, pad=pad_size, mode="constant", value=0)
+        if prev_mask is not None:
+            prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0)
+        if point_coords is not None:
+            point_coords = point_coords + torch.tensor(
+                [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device
+            )
+        return labels, prev_mask, point_coords
+
     def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
         """Get number of foreground classes based on class and point prompt."""
         if class_vector is None:
@@ -307,16 +336,17 @@ class VISTA3D(nn.Module):
     def forward(
         self,
         input_images: torch.Tensor,
+        patch_coords: Sequence[slice] | None = None,
         point_coords: torch.Tensor | None = None,
         point_labels: torch.Tensor | None = None,
         class_vector: torch.Tensor | None = None,
         prompt_class: torch.Tensor | None = None,
-        patch_coords: Sequence[slice] | None = None,
         labels: torch.Tensor | None = None,
         label_set: Sequence[int] | None = None,
         prev_mask: torch.Tensor | None = None,
         radius: int | None = None,
         val_point_sampler: Callable | None = None,
+        transpose: bool = False,
         **kwargs,
     ):
         """
@@ -329,7 +359,7 @@ class VISTA3D(nn.Module):
             point_coords: [B, N, 3]
             point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular classes.
                 2/3 means negative/positive points for special supported classes like tumor.
-            class_vector: [B, 1], the global class index
+            class_vector: [B, 1], the global class index.
             prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
                 the points are for zero-shot or supported class. When class_vector and point_coords are both
                 provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
@@ -346,8 +376,12 @@ class VISTA3D(nn.Module):
             radius: single float value controlling the gaussian blur when combining point and auto results.
                 The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
             val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
-
+            transpose: bool. If true, the output will be transposed to [1, B, H, W, D]. Required to be true when
+                calling from the sliding window inferer/point inferer.
         """
+        labels, prev_mask, point_coords = self.update_slidingwindow_padding(
+            kwargs.get("pad_size", None), labels, prev_mask, point_coords
+        )
         image_size = input_images.shape[-3:]
         device = input_images.device
         if point_coords is None and class_vector is None:
@@ -387,7 +421,10 @@ class VISTA3D(nn.Module):
             point_coords, point_labels = None, None
 
         if point_coords is None and class_vector is None:
-            return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+            logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+            if transpose:
+                logits = logits.transpose(1, 0)
+            return logits
 
         if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
             out, out_auto = self.image_embeddings, None
@@ -424,9 +461,10 @@ class VISTA3D(nn.Module):
             point_labels,  # type: ignore
             mapping_index,
         )
-
         if kwargs.get("keep_cache", False) and class_vector is None:
             self.image_embeddings = out.detach()
+        if transpose:
+            logits = logits.transpose(1, 0)
         return logits
 
 
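A hedged sketch of how these pieces compose (assumptions: the `vista3d132` factory is importable from monai.networks.nets in this release, and randomly initialized weights are acceptable for a shape check): run VISTA3D under the sliding window inferer with `transpose=True` so patches come back as [1, B, H, W, D].

```python
import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import vista3d132

model = vista3d132(in_channels=1).eval()
image = torch.rand(1, 1, 128, 128, 128)
with torch.no_grad():
    logits = sliding_window_inference(
        image,
        roi_size=(96, 96, 96),
        sw_batch_size=1,
        predictor=model,
        class_vector=torch.tensor([1]),  # automatic segmentation for one global class
        transpose=True,  # required when calling through the inferer, per the docstring above
    )
```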
monai/networks/utils.py CHANGED
@@ -851,7 +851,7 @@ def _onnx_trt_compile(
     # wrap the serialized TensorRT engine back to a TorchScript module.
     trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
         f.getvalue(),
-        device=torch.device(f"cuda:{device}"),
+        device=torch_tensorrt.Device(f"cuda:{device}"),
        input_binding_names=input_names,
        output_binding_names=output_names,
    )
monai/transforms/__init__.py CHANGED
@@ -238,8 +238,18 @@ from .intensity.dictionary import (
 )
 from .inverse import InvertibleTransform, TraceableTransform
 from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
-from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
-from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
+from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
+from .io.dictionary import (
+    LoadImaged,
+    LoadImageD,
+    LoadImageDict,
+    SaveImaged,
+    SaveImageD,
+    SaveImageDict,
+    WriteFileMappingd,
+    WriteFileMappingD,
+    WriteFileMappingDict,
+)
 from .lazy.array import ApplyPending
 from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
 from .lazy.functional import apply_pending
monai/transforms/io/array.py CHANGED
@@ -15,6 +15,7 @@ A collection of "vanilla" transforms for IO functions.
 from __future__ import annotations
 
 import inspect
+import json
 import logging
 import sys
 import traceback
@@ -45,11 +46,19 @@ from monai.transforms.transform import Transform
 from monai.transforms.utility.array import EnsureChannelFirst
 from monai.utils import GridSamplePadMode
 from monai.utils import ImageMetaKey as Key
-from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
+from monai.utils import (
+    MetaKeys,
+    OptionalImportError,
+    convert_to_dst_type,
+    ensure_tuple,
+    look_up_option,
+    optional_import,
+)
 
 nib, _ = optional_import("nibabel")
 Image, _ = optional_import("PIL.Image")
 nrrd, _ = optional_import("nrrd")
+FileLock, has_filelock = optional_import("filelock", name="FileLock")
 
 __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]
 
@@ -505,7 +514,7 @@ class SaveImage(Transform):
             else:
                 self._data_index += 1
             if self.savepath_in_metadict and meta_data is not None:
-                meta_data["saved_to"] = filename
+                meta_data[MetaKeys.SAVED_TO] = filename
             return img
         msg = "\n".join([f"{e}" for e in err])
         raise RuntimeError(
@@ -514,3 +523,50 @@ class SaveImage(Transform):
             " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
             f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
         )
+
+
+class WriteFileMapping(Transform):
+    """
+    Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
+    This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
+
+    Args:
+        mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
+    """
+
+    def __init__(self, mapping_file_path: Path | str = "mapping.json"):
+        self.mapping_file_path = Path(mapping_file_path)
+
+    def __call__(self, img: NdarrayOrTensor):
+        """
+        Args:
+            img: The input image with metadata.
+        """
+        if isinstance(img, MetaTensor):
+            meta_data = img.meta
+
+        if MetaKeys.SAVED_TO not in meta_data:
+            raise KeyError(
+                "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
+            )
+
+        input_path = meta_data[Key.FILENAME_OR_OBJ]
+        output_path = meta_data[MetaKeys.SAVED_TO]
+        log_data = {"input": input_path, "output": output_path}
+
+        if has_filelock:
+            with FileLock(str(self.mapping_file_path) + ".lock"):
+                self._write_to_file(log_data)
+        else:
+            self._write_to_file(log_data)
+        return img
+
+    def _write_to_file(self, log_data):
+        try:
+            with self.mapping_file_path.open("r") as f:
+                existing_log_data = json.load(f)
+        except (FileNotFoundError, json.JSONDecodeError):
+            existing_log_data = []
+        existing_log_data.append(log_data)
+        with self.mapping_file_path.open("w") as f:
+            json.dump(existing_log_data, f, indent=4)
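A minimal sketch of the intended pairing (the paths are illustrative): `SaveImage` records the output path under `MetaKeys.SAVED_TO`, and `WriteFileMapping` appends the input/output pair to a JSON file.

```python
from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping

pipeline = Compose(
    [
        LoadImage(image_only=True),
        SaveImage(output_dir="./out", output_postfix="seg", savepath_in_metadict=True),
        WriteFileMapping(mapping_file_path="./out/mapping.json"),
    ]
)
pipeline("./data/image.nii.gz")  # appends {"input": ..., "output": ...} to mapping.json
```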
monai/transforms/io/dictionary.py CHANGED
@@ -17,16 +17,17 @@ Class names are ended with 'd' to denote dictionary-based transforms.
 
 from __future__ import annotations
 
+from collections.abc import Hashable, Mapping
 from pathlib import Path
 from typing import Callable
 
 import numpy as np
 
 import monai
-from monai.config import DtypeLike, KeysCollection
+from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
 from monai.data import image_writer
 from monai.data.image_reader import ImageReader
-from monai.transforms.io.array import LoadImage, SaveImage
+from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
 from monai.transforms.transform import MapTransform, Transform
 from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
 from monai.utils.enums import PostFix
@@ -320,5 +321,31 @@ class SaveImaged(MapTransform):
         return d
 
 
+class WriteFileMappingd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: :py:class:`monai.transforms.compose.MapTransform`
+        mapping_file_path: Path to the JSON file where the mappings will be saved.
+            Defaults to "mapping.json".
+        allow_missing_keys: don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.mapping = WriteFileMapping(mapping_file_path)
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.mapping(d[key])
+        return d
+
+
 LoadImageD = LoadImageDict = LoadImaged
 SaveImageD = SaveImageDict = SaveImaged
+WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
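The dictionary-based analogue of the previous sketch (the key names and paths are illustrative assumptions for a typical dict pipeline):

```python
from monai.transforms import Compose, LoadImaged, SaveImaged, WriteFileMappingd

pipeline = Compose(
    [
        LoadImaged(keys="image"),
        SaveImaged(keys="image", output_dir="./out", savepath_in_metadict=True),
        WriteFileMappingd(keys="image", mapping_file_path="./out/mapping.json"),
    ]
)
pipeline({"image": "./data/image.nii.gz"})
```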
monai/transforms/utility/dictionary.py CHANGED
@@ -1714,6 +1714,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
            Probability the transform is applied to the data
        allow_missing_keys:
            Don't raise exception if key is missing.
+
+    Note:
+        - This transform does not scale output image values automatically to match the range of the input.
+          The output should be scaled by later transforms to match the input if this is desired.
    """

    backend = ImageFilter.backend
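Following the note above, a hedged sketch of rescaling after filtering (the key name and filter choice are illustrative):

```python
from monai.transforms import Compose, RandImageFilterd, ScaleIntensityd

post_filter = Compose(
    [
        RandImageFilterd(keys="image", kernel="gauss", kernel_size=3, prob=1.0),
        ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),  # restore a known intensity range
    ]
)
# out = post_filter({"image": some_image_tensor})
```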
monai/transforms/utils.py CHANGED
@@ -107,7 +107,8 @@ __all__ = [
     "generate_spatial_bounding_box",
     "get_extreme_points",
     "get_largest_connected_component_mask",
-    "get_largest_connected_component_mask_point",
+    "keep_merge_components_with_points",
+    "keep_components_with_positive_points",
     "convert_points_to_disc",
     "remove_small_objects",
     "img_bounds",
@@ -1178,7 +1179,7 @@ def get_largest_connected_component_mask(
     return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
 
 
-def get_largest_connected_component_mask_point(
+def keep_merge_components_with_points(
     img_pos: NdarrayTensor,
     img_neg: NdarrayTensor,
     point_coords: NdarrayTensor,
@@ -1188,8 +1189,8 @@ def keep_merge_components_with_points(
     margins: int = 3,
 ) -> NdarrayTensor:
     """
-    Gets the connected component of img_pos and img_neg that include the positive points and
-    negative points separately. The function is used for combining automatic results with interactive
+    Keep the connected regions of img_pos and img_neg that include the positive points and
+    negative points, respectively. The function is used for merging automatic results with interactive
     results in VISTA3D.
 
     Args:
@@ -1199,6 +1200,7 @@ def keep_merge_components_with_points(
         neg_val: negative point label values.
         point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.
         point_labels: the label of each point, shape [B, N].
+        margins: include points outside of the region but within the margin.
     """
 
     cucim_skimage, has_cucim = optional_import("cucim.skimage")
@@ -1249,6 +1251,49 @@ def keep_merge_components_with_points(
     return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]
 
 
+def keep_components_with_positive_points(
+    img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor
+) -> torch.Tensor:
+    """
+    Keep the connected regions that include the positive points. Used for point-only inference postprocessing
+    to remove regions without positive points.
+    Args:
+        img: [1, B, H, W, D]. Output prediction from VISTA3D. Values are pre-sigmoid and may contain NaNs.
+        point_coords: [B, N, 3]. Point click coordinates.
+        point_labels: [B, N]. Point click labels.
+    """
+    if not has_measure:
+        raise RuntimeError("skimage.measure required.")
+    outs = torch.zeros_like(img)
+    for c in range(len(point_coords)):
+        if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()):
+            # skip if no positive points.
+            continue
+        coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist()
+        not_nan_mask = ~torch.isnan(img[0, c])
+        img_ = torch.nan_to_num(img[0, c] > 0, 0)
+        img_, *_ = convert_data_type(img_, np.ndarray)  # type: ignore
+        label = measure.label
+        features = label(img_, connectivity=3)
+        pos_mask = torch.from_numpy(img_).to(img.device) > 0
+        # if num features less than max desired, nothing to do.
+        features = torch.from_numpy(features).to(img.device)
+        # generate a map with all pos points
+        idx = []
+        for p in coords:
+            idx.append(features[round(p[0]), round(p[1]), round(p[2])].item())
+        idx = list(set(idx))
+        for i in idx:
+            if i == 0:
+                continue
+            outs[0, c] += features == i
+        outs = outs > 0
+        # fill the discarded positive regions with the mean of the non-kept, non-NaN values
+        fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean()
+        img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in
+    return img
+
+
 def convert_points_to_disc(
     image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
 ):
@@ -1269,7 +1314,7 @@ def convert_points_to_disc(
     _array = [
         torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
     ]
-    coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0])
+    coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2])
     # [1, 3, h, w, d] -> [b, 2, 3, h, w, d]
     coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)
     coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)
@@ -2467,6 +2512,7 @@ def distance_transform_edt(
             block_params=block_params,
             float64_distances=float64_distances,
         )
+        torch.cuda.synchronize()
     else:
         if not has_ndimage:
             raise RuntimeError("scipy.ndimage required if cupy is not available")
@@ -2500,7 +2546,7 @@ def distance_transform_edt(
 
     r_vals = []
     if return_distances and distances_original is None:
-        r_vals.append(distances)
+        r_vals.append(distances_ if use_cp else distances)
     if return_indices and indices_original is None:
         r_vals.append(indices)
     if not r_vals:
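A small demonstration of why the meshgrid fix above matters (illustrative, not from the package): with "ij" indexing, passing the arrays in axis order makes coordinate channel i vary along spatial axis i, so the stacked grids line up with an (H, W, D) volume even when the three sizes differ.

```python
import torch

hwd = (2, 3, 4)
arrays = [torch.arange(s, dtype=torch.float32) for s in hwd]
cx, cy, cz = torch.meshgrid(arrays[0], arrays[1], arrays[2], indexing="ij")
print(cx.shape)  # torch.Size([2, 3, 4]): cx[i, j, k] == i, cy[i, j, k] == j, cz[i, j, k] == k
# passing the arrays reversed, as in the old code, would yield shape (4, 3, 2),
# which mismatches the volume whenever H, W, D are not all equal
```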
monai/utils/enums.py CHANGED
@@ -543,6 +543,7 @@ class MetaKeys(StrEnum):
     SPATIAL_SHAPE = "spatial_shape"  # optional key for the length in each spatial dimension
     SPACE = "space"  # possible values of space type are defined in `SpaceKeys`
     ORIGINAL_CHANNEL_DIM = "original_channel_dim"  # an integer or float("nan")
+    SAVED_TO = "saved_to"
 
 
 class ColorOrder(StrEnum):
monai_weekly-1.4.dev2435.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: monai-weekly
-Version: 1.4.dev2434
+Version: 1.4.dev2435
 Summary: AI Toolkit for Healthcare Imaging
 Home-page: https://monai.io/
 Author: MONAI Consortium
@@ -167,6 +167,7 @@ Requires-Dist: zarr; extra == "zarr"
 [![postmerge](https://img.shields.io/github/checks-status/project-monai/monai/dev?label=postmerge)](https://github.com/Project-MONAI/MONAI/actions?query=branch%3Adev)
 [![Documentation Status](https://readthedocs.org/projects/monai/badge/?version=latest)](https://docs.monai.io/en/latest/)
 [![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg?token=6FTC7U1JJ4)](https://codecov.io/gh/Project-MONAI/MONAI)
+[![monai Downloads Last Month](https://assets.piptrends.com/get-last-month-downloads-badge/monai.svg 'monai Downloads Last Month by pip Trends')](https://piptrends.com/package/monai)
 
 MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
 Its ambitions are:
monai_weekly-1.4.dev2435.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
-monai/__init__.py,sha256=rMaIh5LiiohWuFjwscl3Y9xGRH2kESEME5WjZBl7o1g,2722
-monai/_version.py,sha256=OSpdYFEM7VmFqOrWc2W_5Ypzg2cajYgcoY5Pcbce52I,503
+monai/__init__.py,sha256=FR4z28zBAHUoK_i-6_1WiPzJfVhqZk6Z4t06Tz305Tg,2722
+monai/_version.py,sha256=Mbte5KgHGPqFVpE9bmQsl31M2shIDZrt8WLkgkm37f4,503
 monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
 monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -52,7 +52,6 @@ monai/apps/generation/maisi/networks/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30
 monai/apps/generation/maisi/networks/autoencoderkl_maisi.py,sha256=Jbj5w9_p_xOLWYgfta26H22zgcC01BR4dmRmDdi13EU,36695
 monai/apps/generation/maisi/networks/controlnet_maisi.py,sha256=jaTbpvttLybOq6KzC64CQl92BhlOi39zD48Zkdb7zBE,7698
 monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py,sha256=XFOiy6GngXC_OKM1dUiel_gp71yUFWgPErYdgrVLQAU,19072
-monai/apps/generation/maisi/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 monai/apps/mmars/__init__.py,sha256=BolpgEi9jNBgrOQd3Kwp-9QQLeWQwQtlN_MJkK1eu5s,726
 monai/apps/mmars/mmars.py,sha256=AYsx5FDmJ0dT0hAkWGYhM470aPIG23PYloHihDZfOKE,13115
 monai/apps/mmars/model_desc.py,sha256=k7WSMRuyQN8xPax8aUmGKiTNZmcVatdqPYCgxDih-x4,9996
@@ -99,6 +98,10 @@ monai/apps/reconstruction/transforms/dictionary.py,sha256=3NGkie0WYZdsWWx1_h9Orr
 monai/apps/tcia/__init__.py,sha256=2uu3nP1j3mDs2AeG-9zmXicD33eQs1g0VHCN8KysEbQ,824
 monai/apps/tcia/label_desc.py,sha256=B8l9mVmRzLysLmEIIYVeenly_68okCt461qeLQSxCJ8,1582
 monai/apps/tcia/utils.py,sha256=iyLXr5_51rolbRUZFN_Fwc6TIhAbeSl6XZ2m5RYpzTw,6303
+monai/apps/vista3d/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
+monai/apps/vista3d/inferer.py,sha256=m5jW456eXCgeq8PAnUoJS-SWt0gmLA-aEiLFuS7vbEk,8710
+monai/apps/vista3d/sampler.py,sha256=1uZQIRCO9HY8Rs2FtZ1v0XtPQGZ9RyEjxUycMgIqx7A,8274
+monai/apps/vista3d/transforms.py,sha256=SLsVVRJty5R8X2oeeyPUQCej83__3yKv8qvAUpGxr3s,10641
 monai/auto3dseg/__init__.py,sha256=DbZC7wqx4zBNcguLQGu8bGmAiKnk9LvjtQDtwdwG19I,1164
 monai/auto3dseg/algo_gen.py,sha256=_BscoAnUzQKRqz5jHvdsuCe3tTxq7PUQYPMLX0WuxCc,4286
 monai/auto3dseg/analyzer.py,sha256=7l8QT36lG68b8rK23CC2omz6PO1fxmDwOljxXMn5clQ,41351
@@ -193,7 +196,7 @@ monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,11
 monai/inferers/inferer.py,sha256=aZwCmM6WGj49SHi_jIkQeGDstMz45frvM1Lomoeqzm4,92669
 monai/inferers/merger.py,sha256=Ch-qoGUVTTDWN9z_LXBRxElvyuZxOmuqAcecpg1xxAg,15566
 monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
-monai/inferers/utils.py,sha256=dloXtQY_zI_h-_ppoJ2P-0ij9j2vCVEiq5VyL1k-Bs0,20386
+monai/inferers/utils.py,sha256=hKiudomhQL9mbcq1rVWRpRy55Fz1bCD5egv4J3QgLNQ,20432
 monai/losses/__init__.py,sha256=igy7BjoQzM3McmJPD2tmeiW2ljSXfB2HBdc4YiDzYEg,1778
 monai/losses/adversarial_loss.py,sha256=9w47lPYU3clj2w9UZ_ZcXCKnmlMfA74YkjFOCVfhF0E,7722
 monai/losses/barlow_twins.py,sha256=prDdaY0vXAXMuVDmc9Tv6svRZzNwKA0LdsmRaUmusiI,3613
@@ -237,7 +240,7 @@ monai/metrics/surface_distance.py,sha256=bKDTm7ulhjfiphHLrDJoA3OKI3npwQy2Z5wY-Jk
 monai/metrics/utils.py,sha256=jJiIFGGa-iwvz1otHAKqPKTNmfZqd2dI7_Hsfblgxqk,46914
 monai/metrics/wrapper.py,sha256=c1zg-xcypQyZ840TEuhhLgr4sClYMWTxlv1OieJTtvE,11781
 monai/networks/__init__.py,sha256=X-z-kmVt9kwoNPgfYITGycnvG_9HC3_RSRKD2YC35Ag,1020
-monai/networks/utils.py,sha256=XQKXogddrhxGz06ZfPGqO8j4VcqRUvaUY-TVrL4vKuA,50290
+monai/networks/utils.py,sha256=fk-jyV6U9UI5V23a7rS4lZ-Oge0x0Rfu2SoxEaCcB9w,50299
 monai/networks/blocks/__init__.py,sha256=-LMGPMN-eHzwsjkb88H66kImpr4v2hYATZ2y-mRm_K0,2264
 monai/networks/blocks/acti_norm.py,sha256=bVGXbTZ_ssRvmED5R7LOQ7jj4V6WbVFl8JMO-4iZ2Dk,4275
 monai/networks/blocks/activation.py,sha256=S5k3zcP2PsHBkeIxgWgNg8ppW80tTResVP2j9ZsvTFw,5839
@@ -276,7 +279,7 @@ monai/networks/layers/conjugate_gradient.py,sha256=kCAwjtX_j5wrgR8x52WdGl4yCwZmc
 monai/networks/layers/convutils.py,sha256=zwbYK4WJO1Tj2KASnOfxwYnb3p4pizXxdZRm6I1P3j4,8288
 monai/networks/layers/drop_path.py,sha256=SZtRNa1bDwk1rXWbUe70YDaw6H_NKeplm_Wk5Ye1L4Y,1802
 monai/networks/layers/factories.py,sha256=dMj-y3LRV5P_FmqMCZuf_A8P8l_fge3TVAXWzNhONuo,15795
-monai/networks/layers/filtering.py,sha256=7ru9Yt3yOM-ko-UqzYp-2tMpb8VHt5d767F-KkzrqYY,17992
+monai/networks/layers/filtering.py,sha256=294TaEF_oF-IuL7NQzh64iwW28bRezbPGwp9KynP_ks,18215
 monai/networks/layers/gmm.py,sha256=Aq-YCHgUalgOZQ0x5mwYKJe1G7aiCiJybdkPTiiT120,3325
 monai/networks/layers/simplelayers.py,sha256=ciUdKrj_DpEdT3AKs70aPySh73UMsyhoOCTiR2qk8Js,28478
 monai/networks/layers/spatial_transforms.py,sha256=fz2t7-ibijNLqTYpAn4ZgdXtzBSIyWlaF35mQtqWRY4,25581
@@ -324,7 +327,7 @@ monai/networks/nets/transformer.py,sha256=-nzl20Z5xdtn7xChOd_cRbbPVoPIFGVfTQw3fI
 monai/networks/nets/unet.py,sha256=riKWB8iEEgO4CIiVTOo532726HWWBfuBcIHeoLvvN0w,13627
 monai/networks/nets/unetr.py,sha256=wQC3mpn_jEcZb0RXef0ueTe4WGjmnZqQVKKdnemFjnc,8545
 monai/networks/nets/varautoencoder.py,sha256=Pd9BdXW1iVjmAVCZIc2ElGtSDAWRBaLwEKxLDicyxZI,6282
-monai/networks/nets/vista3d.py,sha256=hL9w6bzZntMFYtkKBPSlOo0qcB5ZKE6wdAb6zPqLVQc,41271
+monai/networks/nets/vista3d.py,sha256=hFXgJmQ3nHxci9M5SiKaw0koWKyLUduT-gO22fo0cio,42955
 monai/networks/nets/vit.py,sha256=SJ5MCJcVAQ2iTqkc1-AFF7oBgCkE7xcNr_ziGc8n_t8,6250
 monai/networks/nets/vitautoenc.py,sha256=tTX-JHNl2H4y9e5Wk9rrtR6i_ebJHq90O61DnbBFhek,6033
 monai/networks/nets/vnet.py,sha256=zaJi5kSiTLAuFHThSZfhJvHP6zKh3oBWsTWG-328O_g,10820
@@ -340,7 +343,7 @@ monai/optimizers/lr_finder.py,sha256=tbVi6qd-LLI6pENM9cDUv-Hh1HqziO3Wb9aI6JoaPng
 monai/optimizers/lr_scheduler.py,sha256=YPY5MWgCTmExuIOBsVJrgfErkCT1ELBekcH0XeRP6Kk,4082
 monai/optimizers/novograd.py,sha256=dgjyM-WGqrEHsSKNdI3Lw1wJ2YNG3oKCYotfPsDBE80,5677
 monai/optimizers/utils.py,sha256=GVsJsZWO2aAP9IzwhXgca_9gUNHFClup6qG4ZFs42z4,4133
-monai/transforms/__init__.py,sha256=uBhfs9wlZDWjJ_5OHrHQBeLlLy7sse3hsVCJBrNKuS4,16142
+monai/transforms/__init__.py,sha256=CnywmbXXGdM56BhhkqNrcfdRia81M5i3cEiOFj6ftjg,16261
 monai/transforms/adaptors.py,sha256=jqh7cVvIj4h7-UndP7CNuwxgIUXWY_5kiMzjGC5jFBs,8950
 monai/transforms/compose.py,sha256=zQa_hf8gIater3Bo_XW1IVYgX7aFa_Co6-BZPwoeaQw,37663
 monai/transforms/inverse.py,sha256=Wg8UnMJru41G3eHGipUemAWziHGU-qdd-Flfi3eOpeo,18746
@@ -348,7 +351,7 @@ monai/transforms/inverse_batch_transform.py,sha256=fMbukZq2P99BhqqMuWZFJ9uboZ5dN
 monai/transforms/nvtx.py,sha256=1EKEXZIhTUFKoIrJmd_fevwrHwo731dVFUFJQFiOk3w,3386
 monai/transforms/traits.py,sha256=F8kmhnekTyaAdo8wIFjO3-uqpVtmFym3mNxbYbyvkFI,3563
 monai/transforms/transform.py,sha256=DqWyfuI-FDBxjqern33R6Ia1iAfHb3Kh56u-__tp1Kw,21614
-monai/transforms/utils.py,sha256=2D56fVCr4SyKUDAPsd2x0-QUM_unH3I2XepvEkUJE2o,102957
+monai/transforms/utils.py,sha256=KX-ikiVZUFytzoRC18AbyoVimXtRp1zXzViZisSDWXU,105015
 monai/transforms/utils_create_transform_ims.py,sha256=QEJVHsCZX7ZxsBArk6NjgCzSZuuokf8l1uFqiUZBBys,31155
 monai/transforms/utils_morphological_ops.py,sha256=abaFYSvCfH4k7jk3R_YLtUxgwRYgsz6zj6sOEGM1K5w,6758
 monai/transforms/utils_pytorch_numpy_unification.py,sha256=PvNO1QeBLTcpLhvuO25ctGr2nIM4B0sTRvnA5TpxJ4Q,18855
@@ -361,8 +364,8 @@ monai/transforms/intensity/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJ
 monai/transforms/intensity/array.py,sha256=bhKIAMgJu-QMQA8df9QdyancMJMShOIOGHjE__4XdXo,121574
 monai/transforms/intensity/dictionary.py,sha256=RXZeQG9dPvdvjoiWWlNkYec4NDWBxYXjfct4fywv1Ic,85059
 monai/transforms/io/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
-monai/transforms/io/array.py,sha256=HUFnM3c6yxGkZXfXCYcNLFdFSIhCPSvxSvz4nzkHjrc,25665
-monai/transforms/io/dictionary.py,sha256=O1fMHYJUFIgSGE1x0sGXN9Tqn5uPc1cnenfVMbRly-g,17602
+monai/transforms/io/array.py,sha256=z4aOxK44IhztN-LzG2uROYDwg_u1C6gcpx9ZH-ZhoVA,27482
+monai/transforms/io/dictionary.py,sha256=64M9KUsKyzwXopDcarXT7JKIv9rHP8Ae-fYRvI0yBuM,18716
 monai/transforms/lazy/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 monai/transforms/lazy/array.py,sha256=2jNLmQ3_sMX7DdbfcT3Extpwe5FgOBbbz2RqlDlyNcw,1211
 monai/transforms/lazy/dictionary.py,sha256=bgpZ5CPh5rjdf1T5eQVqxlLh0B57xTWHWaBUUxiQAu4,1571
@@ -388,14 +391,14 @@ monai/transforms/spatial/dictionary.py,sha256=mvP_skSEI1sMl9y-AS3PZqNHhTLK6iOVOf
 monai/transforms/spatial/functional.py,sha256=4sLTp5ggCJrePg1TQjFhOxdVf1It4-PA6hiv7vMkrBI,31253
 monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 monai/transforms/utility/array.py,sha256=Pcg0nJEAHR60jydZTyueTSss9kaOiM4v6UFF1Fnj0PY,70600
-monai/transforms/utility/dictionary.py,sha256=hF90-R2wAMLjYZiGz8xjTVhz4z4hmmrNDXZ5DEC7zLs,73114
+monai/transforms/utility/dictionary.py,sha256=GmJHQpxEpB3lhIuib9K1V4vuPTRoMV2xpwPjSOvMC78,73329
 monai/utils/__init__.py,sha256=QbMAngvOTgxcwIUpo-LRRBF8PtgG3bzgqXLGVlcUGnc,3757
 monai/utils/aliases.py,sha256=uBxkLudRfy3Rts9RZo4NDPGoq4e3Ymcaihk6lT92GFo,4096
 monai/utils/component_store.py,sha256=VMF7CtPu5Wi_eX_qFtm9iWo5kvoWFuCUIxdRzk90zZo,4498
 monai/utils/decorators.py,sha256=YRK5iEMdbc2INrWnBNDSMTaHge_0ezRf2b9yJGL-opg,3129
 monai/utils/deprecate_utils.py,sha256=gKeEV4MsI51qeQ5gci2me_C-0e-tDwa3VZzd3XPQqLk,14759
 monai/utils/dist.py,sha256=mVaKlBTQJdWAG910sh5pGLEbb_KhRAXV5cPz7amH88Y,8639
-monai/utils/enums.py,sha256=Gdo9WBrFODIYz5zt6c00hGz0bqjUQbhCWsfGSgKlnAU,19674
+monai/utils/enums.py,sha256=f__RhrrG4cxxzmICHnmM9riiCvsmUIIk9fYN12Q33lE,19700
 monai/utils/jupyter_utils.py,sha256=QqcKhJxzEf6YwM8Ik_HvfVDr7gNfrfzCXdzd2urEH8M,15651
 monai/utils/misc.py,sha256=GJIDxr42juFjnzUTvLtYndcpBQ-EDz6EVXIc7anBoNo,31380
 monai/utils/module.py,sha256=D9KWFrZ8sS2LrGaLzHnw9MMEbrPI9pHHfHc0XrTLob0,25105
@@ -412,8 +415,8 @@ monai/visualize/img2tensorboard.py,sha256=NnMcyfIFqX-jD7TBO3Rn02zt5uug79d_7pIIaV
 monai/visualize/occlusion_sensitivity.py,sha256=OQHEJLyIhB8zWqQsfKaX-1kvCjWFVYtLfS4dFC0nKFI,18160
 monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
 monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
-monai_weekly-1.4.dev2434.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-monai_weekly-1.4.dev2434.dist-info/METADATA,sha256=O9IZ2AW17KFlfzo3x51bjmKm3aazemQavbcr-q2SP4o,10913
-monai_weekly-1.4.dev2434.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
-monai_weekly-1.4.dev2434.dist-info/top_level.txt,sha256=UaNwRzLGORdus41Ip446s3bBfViLkdkDsXDo34J2P44,6
-monai_weekly-1.4.dev2434.dist-info/RECORD,,
+monai_weekly-1.4.dev2435.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+monai_weekly-1.4.dev2435.dist-info/METADATA,sha256=-VRQSgkdacB1TGTzFz7kHZnoowBGiceQuol8LIRj3Hc,11096
+monai_weekly-1.4.dev2435.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
+monai_weekly-1.4.dev2435.dist-info/top_level.txt,sha256=UaNwRzLGORdus41Ip446s3bBfViLkdkDsXDo34J2P44,6
+monai_weekly-1.4.dev2435.dist-info/RECORD,,
monai_weekly-1.4.dev2435.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (73.0.1)
+Generator: setuptools (74.0.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 