monai-weekly 1.4.dev2434__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/inferers/utils.py +1 -0
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/vista3d.py +44 -6
- monai/networks/utils.py +1 -1
- monai/transforms/__init__.py +12 -2
- monai/transforms/io/array.py +58 -2
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/utility/dictionary.py +4 -0
- monai/transforms/utils.py +52 -6
- monai/utils/enums.py +1 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +2 -1
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +21 -18
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2024-
|
11
|
+
"date": "2024-09-01T02:28:54+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.4.
|
14
|
+
"full-revisionid": "d311b1d7b12a95dd7de995b507ffbb5ed413bab6",
|
15
|
+
"version": "1.4.dev2435"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
@@ -0,0 +1,177 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import copy
|
15
|
+
from collections.abc import Sequence
|
16
|
+
from typing import Any
|
17
|
+
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from monai.data.meta_tensor import MetaTensor
|
21
|
+
from monai.utils import optional_import
|
22
|
+
|
23
|
+
tqdm, _ = optional_import("tqdm", name="tqdm")
|
24
|
+
|
25
|
+
__all__ = ["point_based_window_inferer"]
|
26
|
+
|
27
|
+
|
28
|
+
def point_based_window_inferer(
    inputs: torch.Tensor | MetaTensor,
    roi_size: Sequence[int],
    predictor: torch.nn.Module,
    point_coords: torch.Tensor,
    point_labels: torch.Tensor,
    class_vector: torch.Tensor | None = None,
    prompt_class: torch.Tensor | None = None,
    prev_mask: torch.Tensor | MetaTensor | None = None,
    point_start: int = 0,
    center_only: bool = True,
    margin: int = 5,
    **kwargs: Any,
) -> torch.Tensor:
    """
    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
    The inferer crops the input image into patches centered at the point set, runs patch inference, averages the
    outputs of overlapping patches, and finally returns the stitched mask.

    Args:
        inputs: [1CHWD], input image to be processed.
        roi_size: the spatial window size for inferences.
            If a component is None or non-positive, the corresponding spatial size of `inputs` is used instead.
            For example, `roi_size=(32, -1, 64)` will be adapted to `(32, 64, 64)` if the second spatial
            dimension size of `inputs` is `64`.
        predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].
            Add transpose=True in kwargs for vista3d.
        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.
        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
            2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
        class_vector: [B]. Used for class-head automatic segmentation. Can be None value.
        prompt_class: [B]. The same as class_vector representing the point class and inform point head about
            supported class or zeroshot, not used for automatic segmentation. If None, point head is default
            to supported class segmentation.
        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
        point_start: only use points starting from this number. All points before this number is used to generate
            prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask.
        center_only: for each point, only crop the patch centered at this point. If false, crop 3 windows along each
            spatial axis for each point (up to 27 patches per point).
        margin: if center_only is false, this value is the distance between point to the patch boundary.
        kwargs: extra keyword arguments forwarded to `predictor`.

    Returns:
        stitched_output: [1, B, H, W, D]. The value is before sigmoid. Voxels covered by several patches hold the
        average of the patch outputs; voxels not covered by any patch are taken from `prev_mask` when it is given,
        otherwise they are NaN.

    Raises:
        ValueError: if `point_coords` holds more than one object, `inputs` is not 5D, or no point remains after
            `point_start`.

    Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
    """
    if not point_coords.shape[0] == 1:
        raise ValueError("Only supports single object point click.")
    if not len(inputs.shape) == 5:
        raise ValueError("Input image should be 5D.")
    # Resolve None / non-positive roi components to the input's own spatial size, as documented above.
    roi_size = [s if r is None or r <= 0 else r for r, s in zip(roi_size, inputs.shape[2:])]
    image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
    # `pad` follows torch.nn.functional.pad ordering (last spatial dim first, [left, right] pairs),
    # so the left pads of the three spatial dims are pad[-2], pad[-4], pad[-6].
    point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
    prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
    stitched_output = None
    stitched_count = None
    for p in point_coords[0][point_start:]:
        lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
        ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
        lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
        for lx, rx in zip(lx_, rx_):
            for ly, ry in zip(ly_, ry_):
                for lz, rz in zip(lz_, rz_):
                    unravel_slice = [
                        slice(None),
                        slice(None),
                        slice(int(lx), int(rx)),
                        slice(int(ly), int(ry)),
                        slice(int(lz), int(rz)),
                    ]
                    patch = tuple(unravel_slice)  # index with a tuple, not a list of slices
                    batch_image = image[patch]
                    output = predictor(
                        batch_image,
                        point_coords=point_coords,
                        point_labels=point_labels,
                        class_vector=class_vector,
                        prompt_class=prompt_class,
                        patch_coords=unravel_slice,
                        prev_mask=prev_mask,
                        **kwargs,
                    )
                    if stitched_output is None:
                        full_shape = [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]]
                        stitched_output = torch.zeros(full_shape, device="cpu")
                        stitched_count = torch.zeros(full_shape, device="cpu")
                    stitched_output[patch] += output.to("cpu")
                    # count how many patches cover each voxel so the division below averages overlaps
                    stitched_count[patch] += 1
    if stitched_output is None:
        raise ValueError("No point to infer: `point_start` must be smaller than the number of points.")
    # voxels never covered by a patch become 0 / 0 = NaN here
    stitched_output = stitched_output / stitched_count
    # revert padding
    crop = (
        slice(None),
        slice(None),
        slice(pad[4], image.shape[-3] - pad[5]),
        slice(pad[2], image.shape[-2] - pad[3]),
        slice(pad[0], image.shape[-1] - pad[1]),
    )
    stitched_output = stitched_output[crop]
    stitched_count = stitched_count[crop]
    if prev_mask is not None:
        prev_mask = prev_mask[crop].to("cpu")  # type: ignore
        # for un-calculated places, use the previous mask
        uncovered = stitched_count < 1
        stitched_output[uncovered] = prev_mask[uncovered]
    if isinstance(inputs, torch.Tensor):
        inputs = MetaTensor(inputs)
    if not hasattr(stitched_output, "meta"):
        stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
    return stitched_output
|
141
|
+
|
142
|
+
|
143
|
+
def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
    """
    Return the ``[left, right)`` bounds of a window of size ``roi`` centred on ``p`` along an axis of length ``s``.

    The window is shifted (never shrunk) so that it stays inside ``[0, s)``; its width is exactly ``roi``
    whenever ``roi <= s``, for both even and odd ``roi``.
    """
    left = int(p) - roi // 2
    # clamp to the upper bound first, then to 0, so that roi > s still yields a window starting at 0
    left = max(0, min(left, s - roi))
    return left, left + roi


def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
    """
    Get the window start/end indices along one axis for point ``p``.

    Args:
        p: point coordinate along the axis.
        roi: window size along the axis.
        s: axis length.
        center_only: if True, only return the window centred on ``p``.
        margin: if ``center_only`` is False, the distance between ``p`` and the window boundary for the two
            extra windows (one ending ``margin`` after ``p``, one starting ``margin`` before ``p``).

    Returns:
        Two lists ``(lefts, rights)``. Every window lies inside ``[0, s)`` and has width ``roi`` when
        ``roi <= s``. The centred window is always the last entry.
    """
    left, right = _get_window_idx_c(p, roi, s)
    if center_only:
        return [left], [right]
    left_most = max(0, p - roi + margin)
    right_most = min(s, p + roi - margin)
    max_start = max(s - roi, 0)
    # clamp each start so the window keeps its full size; an unclamped `right_most - roi` can go negative
    # (a wrapping slice start) and `left_most + roi` can run past the end of the axis.
    starts = [int(min(max(start, 0), max_start)) for start in (left_most, right_most - roi)]
    return starts + [left], [start + roi for start in starts] + [right]
|
164
|
+
|
165
|
+
|
166
|
+
def _pad_previous_mask(
    inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
) -> tuple[torch.Tensor | MetaTensor, list[int]]:
    """
    Symmetrically pad the spatial dims (dim 2 onwards) of ``inputs`` so each is at least ``roi_size``.

    Returns:
        The padded tensor (the very same ``inputs`` object when no padding is needed) and the pad
        amounts in ``torch.nn.functional.pad`` order: last dim first, as ``[left, right]`` pairs.
    """
    pad_size: list[int] = []
    for dim in reversed(range(2, len(inputs.shape))):
        shortfall = max(roi_size[dim - 2] - inputs.shape[dim], 0)
        before = shortfall // 2
        pad_size += [before, shortfall - before]
    if not any(pad_size):
        return inputs, pad_size
    padded = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue)  # type: ignore
    return padded, pad_size
|
@@ -0,0 +1,179 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import copy
|
15
|
+
import random
|
16
|
+
from collections.abc import Callable, Sequence
|
17
|
+
from typing import Any
|
18
|
+
|
19
|
+
import numpy as np
|
20
|
+
import torch
|
21
|
+
from torch import Tensor
|
22
|
+
|
23
|
+
# Global switch: when True, classes in SPECIAL_INDEX get point labels 2 (negative) / 3 (positive)
# instead of 0 / 1 (see _get_point_label).
ENABLE_SPECIAL = True
# Class indices treated as "special" classes for point prompts.
SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
# Class index -> list of class indices merged into it.
MERGE_LIST = {
    1: [25, 26],  # hepatic tumor and vessel merge into liver
    4: [24],  # pancreatic tumor merge into pancreas
    132: [57],  # overlap with trachea merge into airway
}
|
30
|
+
|
31
|
+
__all__ = ["sample_prompt_pairs"]
|
32
|
+
|
33
|
+
|
34
|
+
def _get_point_label(id: int) -> tuple[int, int]:
    """
    Return the ``(negative, positive)`` point-label pair used for class ``id``.

    Classes listed in ``SPECIAL_INDEX`` get ``(2, 3)`` while ``ENABLE_SPECIAL`` is set; all others get ``(0, 1)``.
    """
    is_special = id in SPECIAL_INDEX and ENABLE_SPECIAL
    return (2, 3) if is_special else (0, 1)
|
39
|
+
|
40
|
+
|
41
|
+
def sample_prompt_pairs(
    labels: Tensor,
    label_set: Sequence[int],
    max_prompt: int | None = None,
    max_foreprompt: int | None = None,
    max_backprompt: int = 1,
    max_point: int = 20,
    include_background: bool = False,
    drop_label_prob: float = 0.2,
    drop_point_prob: float = 0.2,
    point_sampler: Callable | None = None,
    **point_sampler_kwargs: Any,
) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
    """
    Sample label/point prompt pairs for VISTA3D training.

    Classes of ``label_set`` that occur in ``labels`` become foreground prompts, with points sampled inside
    (positive) and outside (negative) each class. Classes of ``label_set`` that do not occur become background
    prompts, padded with dummy ``[0, 0, 0]`` points labelled -1. The number of prompts is capped at random by
    ``max_backprompt``, ``max_foreprompt`` and ``max_prompt``, and then either the label prompt or the point
    prompt may be randomly dropped.

    Args:
        labels: [1, 1, H, W, D], ground truth labels. Only batch size 1 is supported.
        label_set: the label list for the specific dataset. Note if 0 is included in label_set,
            it will be added into automatic branch training. Recommend removing 0 from label_set
            for multi-partially-labeled-dataset training, and adding 0 for finetuning specific dataset.
            The reason is region with 0 in one partially labeled dataset may contain foregrounds in
            another dataset.
        max_prompt: max number of total prompts, including foreground and background. None disables the cap.
        max_foreprompt: max number of foreground prompts. None disables the cap.
        max_backprompt: max number of background prompts. None disables the cap.
        max_point: maximum number of positive points, and separately of negative points, for each object.
        include_background: if include 0 into training prompt. If included, background 0 is treated
            the same as foreground and points will be sampled. Can be true only if user want to segment
            background 0 with point clicks, otherwise always be false.
        drop_label_prob: probability to drop the label prompt (only possible when a foreground prompt exists).
        drop_point_prob: probability to drop the point prompt (only evaluated when the label prompt is kept).
        point_sampler: optional callable ``point_sampler(foreground_labels, **point_sampler_kwargs)`` returning
            lists of per-class points and point labels. It replaces the default random point sampling.
        point_sampler_kwargs: arguments for point_sampler.

    Returns:
        tuple:
            - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
              training automatic segmentation.
            - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
              for each class. Note that background label prompts require matching points as well
              (e.g., [0, 0, 0] is used).
            - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
              labels for each point (negative or positive). -1 is used for padding the background
              label prompt and will be ignored.
            - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
              for label indexing during training. If label_prompt is None, prompt_class is used to
              identify point classes.
        All four are None when there is neither a foreground nor a background prompt.

    Raises:
        ValueError: if the batch size of ``labels`` is not 1.
    """
    if labels.shape[0] != 1:
        raise ValueError("only support batch size 1")
    volume = labels[0, 0]
    device = volume.device

    # classes present in the volume, restricted to label_set (and without 0 unless include_background)
    observed = volume.unique().cpu().numpy().tolist()
    kept = set(observed) - (set(observed) - set(label_set))
    if not include_background:
        kept = kept - {0}
    foreground = list(kept)
    background = list(set(label_set) - set(foreground))

    # during training, balance background and foreground prompts
    if max_backprompt is not None and len(background) > max_backprompt:
        random.shuffle(background)
        background = background[:max_backprompt]
    if max_foreprompt is not None and len(foreground) > max_foreprompt:
        random.shuffle(foreground)
        foreground = foreground[:max_foreprompt]
    if max_prompt is not None and len(foreground) + len(background) > max_prompt:
        if len(foreground) > max_prompt:
            foreground = random.sample(foreground, max_prompt)
            background = []
        else:
            background = random.sample(background, max_prompt - len(foreground))

    if point_sampler is not None:
        points, point_labels = point_sampler(foreground, **point_sampler_kwargs)
        for _ in background:
            # pad background prompts with all-zero points labelled -1 (not a point)
            width = len(point_labels[0])
            points.append(torch.zeros(width, 3).to(device))
            point_labels.append(torch.zeros(width).to(device) - 1)
    else:
        # random numbers of positive (at least 1) and negative points, shared by all classes
        pos_budget = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1)
        neg_budget = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))))
        width = pos_budget + neg_budget
        points, point_labels = [], []
        for cls in foreground:
            neg_id, pos_id = _get_point_label(cls)
            inside = volume == int(cls)
            pos_candidates = torch.nonzero(inside)
            neg_candidates = torch.nonzero(~inside)
            n_pos = min(len(pos_candidates), pos_budget)
            n_neg = min(len(neg_candidates), neg_budget)
            n_pad = width - n_pos - n_neg
            picked = random.choices(pos_candidates, k=n_pos) + random.choices(neg_candidates, k=n_neg)
            picked += [torch.tensor([0, 0, 0], device=device)] * n_pad
            points.append(torch.stack(picked))
            point_labels.append(torch.tensor([pos_id] * n_pos + [neg_id] * n_neg + [-1] * n_pad).to(device))
        for _ in background:
            # pad background prompts with all-zero points labelled -1 (not a point)
            points.append(torch.zeros(width, 3).to(device))
            point_labels.append(torch.zeros(width).to(device) - 1)

    if not foreground and not background:
        # No foreground class and no background prompt (e.g. max_backprompt is 0): there is no effective
        # prompt and the iteration must be skipped. Handle this in trainer.
        return None, None, None, None

    label_prompt = torch.tensor(foreground + background).unsqueeze(-1).to(device).long()
    point = torch.stack(points)
    point_label = torch.stack(point_labels)
    prompt_class = copy.deepcopy(label_prompt)
    if random.uniform(0, 1) < drop_label_prob and len(foreground) > 0:
        label_prompt = None
        # without a label prompt the -1 padded background entries serve no purpose; trim them
        n_bg = len(background)
        point = point[: len(point) - n_bg]  # type: ignore
        point_label = point_label[: len(point_label) - n_bg]
        prompt_class = prompt_class[: len(prompt_class) - n_bg]
    elif random.uniform(0, 1) < drop_point_prob:
        point = None
        point_label = None
    return label_prompt, point, point_label, prompt_class
|
@@ -0,0 +1,224 @@
|
|
1
|
+
# Copyright (c) MONAI Consortium
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
|
12
|
+
from __future__ import annotations
|
13
|
+
|
14
|
+
import warnings
|
15
|
+
from typing import Sequence
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
import torch
|
19
|
+
|
20
|
+
from monai.config import DtypeLike, KeysCollection
|
21
|
+
from monai.transforms import MapLabelValue
|
22
|
+
from monai.transforms.transform import MapTransform
|
23
|
+
from monai.transforms.utils import keep_components_with_positive_points
|
24
|
+
from monai.utils import look_up_option
|
25
|
+
|
26
|
+
__all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"]
|
27
|
+
|
28
|
+
|
29
|
+
def _get_name_to_index_mapping(labels_dict: dict | None) -> dict:
    """Build a lower-cased ``{label name: int index}`` lookup from an ``{index: name}`` dict (empty when None)."""
    if labels_dict is None:
        return {}
    return {name.lower(): int(index) for index, name in labels_dict.items()}
|
35
|
+
|
36
|
+
|
37
|
+
def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None:
    """
    Convert label names in ``label_prompt`` to integer indices.

    Only list prompts are converted; anything else (including None) is returned unchanged.
    Names are matched case-insensitively. A name that is not yet in ``name_to_index_mapping`` is added to it
    (the mapping is modified in place) with an index one larger than every existing index. Digit strings
    become ints, ints are passed through ``int()``, and other items are kept as they are.
    """
    if label_prompt is None or not isinstance(label_prompt, list):
        return label_prompt
    # for new classes, add to the mapping first so repeated new names share one index
    for item in label_prompt:
        if isinstance(item, str) and not item.isdigit():
            key = item.lower()
            if key not in name_to_index_mapping:
                # `len(mapping)` would collide with an existing index when indices are not exactly 0..N-1
                name_to_index_mapping[key] = max(name_to_index_mapping.values(), default=-1) + 1
    converted_label_prompt = []
    for item in label_prompt:
        if isinstance(item, str):
            converted_label_prompt.append(name_to_index_mapping.get(item.lower(), int(item) if item.isdigit() else 0))
        elif isinstance(item, int):
            converted_label_prompt.append(int(item))
        else:
            converted_label_prompt.append(item)
    return converted_label_prompt
|
55
|
+
|
56
|
+
|
57
|
+
class VistaPreTransformd(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        allow_missing_keys: bool = False,
        special_index: Sequence[int] = (25, 26, 27, 28, 29, 117),
        labels_dict: dict | None = None,
        subclass: dict | None = None,
    ) -> None:
        """
        Pre-transform for Vista3d.

        It performs three functionalities:

        1. If `labels_dict` is given, label names in `label_prompt` are converted to integer indices
           (case-insensitive). Names unknown to `labels_dict` are assigned a new index.

        2. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels),
           convert point labels from 0 (negative), 1 (positive) to special 2 (negative), 3 (positive).
           Only the first entry of `label_prompt` is checked.

        3. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key].
           e.g. "lung" label is converted to ["left lung", "right lung"].

        The `label_prompt` is a list of length B (int indices, or label names when `labels_dict` is given) and
        `point_labels` is a nested list of shape [B, N] with int values.

        Args:
            keys: keys of the corresponding items to be transformed.
            allow_missing_keys: don't raise exception if key is missing.
            special_index: the index that defines the special class.
            labels_dict: optional dictionary mapping label index (int or digit string) to label name,
                used to convert label names in `label_prompt` to indices.
            subclass: a dictionary that maps a label prompt (as a string key, e.g. "2") to a list of its subclasses.
        """
        super().__init__(keys, allow_missing_keys)
        # NOTE(review): this default differs from SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128) in
        # monai/apps/vista3d/sampler.py; confirm both refer to the same label index space.
        self.special_index = special_index
        self.subclass = subclass
        self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict)

    def __call__(self, data):
        """
        Update ``data["label_prompt"]`` and ``data["point_labels"]`` in place and return ``data``.

        These two entries are read directly; ``self.keys`` is not consulted.
        """
        label_prompt = data.get("label_prompt", None)
        point_labels = data.get("point_labels", None)
        # convert the label name to index if needed
        label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt)
        if label_prompt is not None:
            # Store the converted prompt even when no subclass expansion happens below; otherwise label
            # names were passed on unconverted whenever `subclass` is None.
            data["label_prompt"] = label_prompt
        try:
            # The evaluator will check prompt. The invalid prompt will be skipped here and captured by evaluator.
            if self.subclass is not None and label_prompt is not None:
                _label_prompt = []
                subclass_keys = list(map(int, self.subclass.keys()))
                for prompt in label_prompt:
                    if prompt in subclass_keys:
                        _label_prompt.extend(self.subclass[str(prompt)])
                    else:
                        _label_prompt.append(prompt)
                data["label_prompt"] = _label_prompt
            if label_prompt is not None and point_labels is not None:
                if label_prompt[0] in self.special_index:
                    point_labels = np.array(point_labels)
                    point_labels[point_labels == 0] = 2
                    point_labels[point_labels == 1] = 3
                    point_labels = point_labels.tolist()
                data["point_labels"] = point_labels
        except Exception:
            # There is specific requirements for `label_prompt` and `point_labels`.
            # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None.
            # Those formatting errors should be captured later.
            warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.")

        return data
|
121
|
+
|
122
|
+
|
123
|
+
class VistaPostTransformd(MapTransform):
    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
        """
        Post-transform for Vista3d that converts model output logits into final segmentation masks.

        Negative logits are treated as background. With several output channels, every foreground voxel
        receives the 1-based index of its highest channel; with a single channel, positive voxels become 1.
        If `label_prompt` is None these sequential indexes [0, 1, 2, ...] are kept, otherwise index i is
        replaced by `label_prompt[i - 1]`, giving [0, label_prompt[0], label_prompt[1], ...].
        If `label_prompt` is None while `points` are provided, regions that do not contain positive points
        are removed first.

        Args:
            keys: keys of the corresponding items to be transformed.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        """Post-process each ``data[key]`` and return ``data``. data["label_prompt"] should not contain 0."""
        for key in self.keys:
            if key not in data:
                continue
            pred = data[key]
            n_objects = pred.shape[0]
            device = pred.device
            has_label_prompt = data.get("label_prompt", None) is not None
            if not has_label_prompt and data.get("points", None) is not None:
                # point-only prompt: keep just the connected regions that contain a positive click
                pred = keep_components_with_positive_points(
                    pred.unsqueeze(0),
                    point_coords=data.get("points").to(device),
                    point_labels=data.get("point_labels").to(device),
                )[0]
            pred[pred < 0] = 0.0
            if n_objects > 1:
                # multi-channel: argmax, shifted by one so that 0 stays reserved for background.
                # Make sure user did not provide 0 as prompt.
                background = torch.all(pred <= 0, dim=0, keepdim=True)
                pred = pred.argmax(0).unsqueeze(0).float() + 1.0
                pred[background] = 0.0
            else:
                # single channel: binarise; NaN values are neither < 0 nor > 0 and are left untouched
                pred[pred > 0] = 1.0
            if has_label_prompt:
                # work on half-integers in place (avoids cloning pred) so remapped values cannot
                # collide with indices that are still waiting to be replaced
                pred += 0.5
                prompts = data["label_prompt"].to(device)  # Ensure label_prompt is on the same device
                for idx in range(n_objects):
                    pred[pred == idx + 1.5] = prompts[idx].to(pred.dtype)
                pred[pred == 0.5] = 0.0
            data[key] = pred
        return data
|
175
|
+
|
176
|
+
|
177
|
+
class Relabeld(MapTransform):
    def __init__(
        self,
        keys: KeysCollection,
        label_mappings: dict[str, list[tuple[int, int]] | None],
        dtype: DtypeLike = np.int16,
        dataset_key: str = "dataset_name",
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Remap the voxel labels in the input data dictionary based on the specified mapping.

        The dataset name is read from `data[dataset_key]` (`"default"` is used when that key is missing from
        the data), and the matching list of local -> global label mappings is applied to each input `data[keys]`.
        If the dataset name has no entry in `label_mappings`, or its entry is None, no relabeling is performed.

        Args:
            keys: keys of the corresponding items to be transformed.
            label_mappings: a dictionary that specifies how local dataset class indices are mapped to the
                global class indices. The dictionary keys are dataset names and the values are lists of
                (local label, global label) pairs, or None to leave that dataset unchanged.
                Please set `label_mappings={}` to completely skip this transform.
            dtype: convert the output data to dtype, default to np.int16.
            dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.mappers = {}
        self.dataset_key = dataset_key
        for name, mapping in label_mappings.items():
            if mapping is None:
                # documented "no relabeling" entry; iterating None below would raise a TypeError
                self.mappers[name] = None
                continue
            self.mappers[name] = MapLabelValue(
                orig_labels=[int(pair[0]) for pair in mapping],
                target_labels=[int(pair[1]) for pair in mapping],
                dtype=dtype,
            )

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the dataset's label mapping applied to each key."""
        d = dict(data)
        dataset_name = d.get(self.dataset_key, "default")
        # unknown dataset names (and None mappings) yield None: leave the data unchanged
        _m = look_up_option(dataset_name, self.mappers, default=None)
        if _m is None:
            return d
        for key in self.key_iterator(d):
            d[key] = _m(d[key])
        return d
|
monai/inferers/utils.py
CHANGED
@@ -300,6 +300,7 @@ def sliding_window_inference(
|
|
300
300
|
|
301
301
|
# remove padding if image_size smaller than roi_size
|
302
302
|
if any(pad_size):
|
303
|
+
kwargs.update({"pad_size": pad_size})
|
303
304
|
for ss, output_i in enumerate(output_image_list):
|
304
305
|
zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
|
305
306
|
final_slicing: list[slice] = []
|
@@ -51,6 +51,8 @@ class BilateralFilter(torch.autograd.Function):
|
|
51
51
|
ctx.cs = color_sigma
|
52
52
|
ctx.fa = fast_approx
|
53
53
|
output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
|
54
|
+
if torch.cuda.is_available():
|
55
|
+
torch.cuda.synchronize()
|
54
56
|
return output_data
|
55
57
|
|
56
58
|
@staticmethod
|
@@ -139,7 +141,8 @@ class TrainableBilateralFilterFunction(torch.autograd.Function):
|
|
139
141
|
do_dsig_y,
|
140
142
|
do_dsig_z,
|
141
143
|
)
|
142
|
-
|
144
|
+
if torch.cuda.is_available():
|
145
|
+
torch.cuda.synchronize()
|
143
146
|
return output_tensor
|
144
147
|
|
145
148
|
@staticmethod
|
@@ -301,7 +304,8 @@ class TrainableJointBilateralFilterFunction(torch.autograd.Function):
|
|
301
304
|
do_dsig_z,
|
302
305
|
guidance_img,
|
303
306
|
)
|
304
|
-
|
307
|
+
if torch.cuda.is_available():
|
308
|
+
torch.cuda.synchronize()
|
305
309
|
return output_tensor
|
306
310
|
|
307
311
|
@staticmethod
|
monai/networks/nets/vista3d.py
CHANGED
@@ -23,7 +23,7 @@ import monai
|
|
23
23
|
from monai.networks.blocks import MLPBlock, UnetrBasicBlock
|
24
24
|
from monai.networks.nets import SegResNetDS2
|
25
25
|
from monai.transforms.utils import convert_points_to_disc
|
26
|
-
from monai.transforms.utils import
|
26
|
+
from monai.transforms.utils import keep_merge_components_with_points as lcc
|
27
27
|
from monai.transforms.utils import sample_points_from_label
|
28
28
|
from monai.utils import optional_import, unsqueeze_left, unsqueeze_right
|
29
29
|
|
@@ -78,6 +78,35 @@ class VISTA3D(nn.Module):
|
|
78
78
|
self.NINF_VALUE = -9999
|
79
79
|
self.PINF_VALUE = 9999
|
80
80
|
|
81
|
+
def update_slidingwindow_padding(
|
82
|
+
self,
|
83
|
+
pad_size: list | None,
|
84
|
+
labels: torch.Tensor | None,
|
85
|
+
prev_mask: torch.Tensor | None,
|
86
|
+
point_coords: torch.Tensor | None,
|
87
|
+
):
|
88
|
+
"""
|
89
|
+
Image has been padded by sliding window inferer.
|
90
|
+
The related padding need to be performed outside of slidingwindow inferer.
|
91
|
+
|
92
|
+
Args:
|
93
|
+
pad_size: padding size passed from sliding window inferer.
|
94
|
+
labels: image label ground truth.
|
95
|
+
prev_mask: previous segmentation mask.
|
96
|
+
point_coords: point click coordinates.
|
97
|
+
"""
|
98
|
+
if pad_size is None:
|
99
|
+
return labels, prev_mask, point_coords
|
100
|
+
if labels is not None:
|
101
|
+
labels = F.pad(labels, pad=pad_size, mode="constant", value=0)
|
102
|
+
if prev_mask is not None:
|
103
|
+
prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0)
|
104
|
+
if point_coords is not None:
|
105
|
+
point_coords = point_coords + torch.tensor(
|
106
|
+
[pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device
|
107
|
+
)
|
108
|
+
return labels, prev_mask, point_coords
|
109
|
+
|
81
110
|
def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
|
82
111
|
"""Get number of foreground classes based on class and point prompt."""
|
83
112
|
if class_vector is None:
|
@@ -307,16 +336,17 @@ class VISTA3D(nn.Module):
|
|
307
336
|
def forward(
|
308
337
|
self,
|
309
338
|
input_images: torch.Tensor,
|
339
|
+
patch_coords: Sequence[slice] | None = None,
|
310
340
|
point_coords: torch.Tensor | None = None,
|
311
341
|
point_labels: torch.Tensor | None = None,
|
312
342
|
class_vector: torch.Tensor | None = None,
|
313
343
|
prompt_class: torch.Tensor | None = None,
|
314
|
-
patch_coords: Sequence[slice] | None = None,
|
315
344
|
labels: torch.Tensor | None = None,
|
316
345
|
label_set: Sequence[int] | None = None,
|
317
346
|
prev_mask: torch.Tensor | None = None,
|
318
347
|
radius: int | None = None,
|
319
348
|
val_point_sampler: Callable | None = None,
|
349
|
+
transpose: bool = False,
|
320
350
|
**kwargs,
|
321
351
|
):
|
322
352
|
"""
|
@@ -329,7 +359,7 @@ class VISTA3D(nn.Module):
|
|
329
359
|
point_coords: [B, N, 3]
|
330
360
|
point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
|
331
361
|
2/3 means negative/postive ponits for special supported class like tumor.
|
332
|
-
class_vector: [B, 1], the global class index
|
362
|
+
class_vector: [B, 1], the global class index.
|
333
363
|
prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
|
334
364
|
the points are for zero-shot or supported class. When class_vector and point_coords are both
|
335
365
|
provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
|
@@ -346,8 +376,12 @@ class VISTA3D(nn.Module):
|
|
346
376
|
radius: single float value controling the gaussian blur when combining point and auto results.
|
347
377
|
The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
|
348
378
|
val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
|
349
|
-
|
379
|
+
transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from
|
380
|
+
sliding window inferer/point inferer.
|
350
381
|
"""
|
382
|
+
labels, prev_mask, point_coords = self.update_slidingwindow_padding(
|
383
|
+
kwargs.get("pad_size", None), labels, prev_mask, point_coords
|
384
|
+
)
|
351
385
|
image_size = input_images.shape[-3:]
|
352
386
|
device = input_images.device
|
353
387
|
if point_coords is None and class_vector is None:
|
@@ -387,7 +421,10 @@ class VISTA3D(nn.Module):
|
|
387
421
|
point_coords, point_labels = None, None
|
388
422
|
|
389
423
|
if point_coords is None and class_vector is None:
|
390
|
-
|
424
|
+
logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
|
425
|
+
if transpose:
|
426
|
+
logits = logits.transpose(1, 0)
|
427
|
+
return logits
|
391
428
|
|
392
429
|
if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
|
393
430
|
out, out_auto = self.image_embeddings, None
|
@@ -424,9 +461,10 @@ class VISTA3D(nn.Module):
|
|
424
461
|
point_labels, # type: ignore
|
425
462
|
mapping_index,
|
426
463
|
)
|
427
|
-
|
428
464
|
if kwargs.get("keep_cache", False) and class_vector is None:
|
429
465
|
self.image_embeddings = out.detach()
|
466
|
+
if transpose:
|
467
|
+
logits = logits.transpose(1, 0)
|
430
468
|
return logits
|
431
469
|
|
432
470
|
|
monai/networks/utils.py
CHANGED
@@ -851,7 +851,7 @@ def _onnx_trt_compile(
|
|
851
851
|
# wrap the serialized TensorRT engine back to a TorchScript module.
|
852
852
|
trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
|
853
853
|
f.getvalue(),
|
854
|
-
device=
|
854
|
+
device=torch_tensorrt.Device(f"cuda:{device}"),
|
855
855
|
input_binding_names=input_names,
|
856
856
|
output_binding_names=output_names,
|
857
857
|
)
|
monai/transforms/__init__.py
CHANGED
@@ -238,8 +238,18 @@ from .intensity.dictionary import (
|
|
238
238
|
)
|
239
239
|
from .inverse import InvertibleTransform, TraceableTransform
|
240
240
|
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
|
241
|
-
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
|
242
|
-
from .io.dictionary import
|
241
|
+
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
|
242
|
+
from .io.dictionary import (
|
243
|
+
LoadImaged,
|
244
|
+
LoadImageD,
|
245
|
+
LoadImageDict,
|
246
|
+
SaveImaged,
|
247
|
+
SaveImageD,
|
248
|
+
SaveImageDict,
|
249
|
+
WriteFileMappingd,
|
250
|
+
WriteFileMappingD,
|
251
|
+
WriteFileMappingDict,
|
252
|
+
)
|
243
253
|
from .lazy.array import ApplyPending
|
244
254
|
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
|
245
255
|
from .lazy.functional import apply_pending
|
monai/transforms/io/array.py
CHANGED
@@ -15,6 +15,7 @@ A collection of "vanilla" transforms for IO functions.
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
import inspect
|
18
|
+
import json
|
18
19
|
import logging
|
19
20
|
import sys
|
20
21
|
import traceback
|
@@ -45,11 +46,19 @@ from monai.transforms.transform import Transform
|
|
45
46
|
from monai.transforms.utility.array import EnsureChannelFirst
|
46
47
|
from monai.utils import GridSamplePadMode
|
47
48
|
from monai.utils import ImageMetaKey as Key
|
48
|
-
from monai.utils import
|
49
|
+
from monai.utils import (
|
50
|
+
MetaKeys,
|
51
|
+
OptionalImportError,
|
52
|
+
convert_to_dst_type,
|
53
|
+
ensure_tuple,
|
54
|
+
look_up_option,
|
55
|
+
optional_import,
|
56
|
+
)
|
49
57
|
|
50
58
|
nib, _ = optional_import("nibabel")
|
51
59
|
Image, _ = optional_import("PIL.Image")
|
52
60
|
nrrd, _ = optional_import("nrrd")
|
61
|
+
FileLock, has_filelock = optional_import("filelock", name="FileLock")
|
53
62
|
|
54
63
|
__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]
|
55
64
|
|
@@ -505,7 +514,7 @@ class SaveImage(Transform):
|
|
505
514
|
else:
|
506
515
|
self._data_index += 1
|
507
516
|
if self.savepath_in_metadict and meta_data is not None:
|
508
|
-
meta_data[
|
517
|
+
meta_data[MetaKeys.SAVED_TO] = filename
|
509
518
|
return img
|
510
519
|
msg = "\n".join([f"{e}" for e in err])
|
511
520
|
raise RuntimeError(
|
@@ -514,3 +523,50 @@ class SaveImage(Transform):
|
|
514
523
|
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
|
515
524
|
f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
|
516
525
|
)
|
526
|
+
|
527
|
+
|
528
|
+
class WriteFileMapping(Transform):
|
529
|
+
"""
|
530
|
+
Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
|
531
|
+
This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
|
532
|
+
|
533
|
+
Args:
|
534
|
+
mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
|
535
|
+
"""
|
536
|
+
|
537
|
+
def __init__(self, mapping_file_path: Path | str = "mapping.json"):
|
538
|
+
self.mapping_file_path = Path(mapping_file_path)
|
539
|
+
|
540
|
+
def __call__(self, img: NdarrayOrTensor):
|
541
|
+
"""
|
542
|
+
Args:
|
543
|
+
img: The input image with metadata.
|
544
|
+
"""
|
545
|
+
if isinstance(img, MetaTensor):
|
546
|
+
meta_data = img.meta
|
547
|
+
|
548
|
+
if MetaKeys.SAVED_TO not in meta_data:
|
549
|
+
raise KeyError(
|
550
|
+
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
|
551
|
+
)
|
552
|
+
|
553
|
+
input_path = meta_data[Key.FILENAME_OR_OBJ]
|
554
|
+
output_path = meta_data[MetaKeys.SAVED_TO]
|
555
|
+
log_data = {"input": input_path, "output": output_path}
|
556
|
+
|
557
|
+
if has_filelock:
|
558
|
+
with FileLock(str(self.mapping_file_path) + ".lock"):
|
559
|
+
self._write_to_file(log_data)
|
560
|
+
else:
|
561
|
+
self._write_to_file(log_data)
|
562
|
+
return img
|
563
|
+
|
564
|
+
def _write_to_file(self, log_data):
|
565
|
+
try:
|
566
|
+
with self.mapping_file_path.open("r") as f:
|
567
|
+
existing_log_data = json.load(f)
|
568
|
+
except (FileNotFoundError, json.JSONDecodeError):
|
569
|
+
existing_log_data = []
|
570
|
+
existing_log_data.append(log_data)
|
571
|
+
with self.mapping_file_path.open("w") as f:
|
572
|
+
json.dump(existing_log_data, f, indent=4)
|
@@ -17,16 +17,17 @@ Class names are ended with 'd' to denote dictionary-based transforms.
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
from collections.abc import Hashable, Mapping
|
20
21
|
from pathlib import Path
|
21
22
|
from typing import Callable
|
22
23
|
|
23
24
|
import numpy as np
|
24
25
|
|
25
26
|
import monai
|
26
|
-
from monai.config import DtypeLike, KeysCollection
|
27
|
+
from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
|
27
28
|
from monai.data import image_writer
|
28
29
|
from monai.data.image_reader import ImageReader
|
29
|
-
from monai.transforms.io.array import LoadImage, SaveImage
|
30
|
+
from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
|
30
31
|
from monai.transforms.transform import MapTransform, Transform
|
31
32
|
from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
|
32
33
|
from monai.utils.enums import PostFix
|
@@ -320,5 +321,31 @@ class SaveImaged(MapTransform):
|
|
320
321
|
return d
|
321
322
|
|
322
323
|
|
324
|
+
class WriteFileMappingd(MapTransform):
|
325
|
+
"""
|
326
|
+
Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
|
327
|
+
|
328
|
+
Args:
|
329
|
+
keys: keys of the corresponding items to be transformed.
|
330
|
+
See also: :py:class:`monai.transforms.compose.MapTransform`
|
331
|
+
mapping_file_path: Path to the JSON file where the mappings will be saved.
|
332
|
+
Defaults to "mapping.json".
|
333
|
+
allow_missing_keys: don't raise exception if key is missing.
|
334
|
+
"""
|
335
|
+
|
336
|
+
def __init__(
|
337
|
+
self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
|
338
|
+
) -> None:
|
339
|
+
super().__init__(keys, allow_missing_keys)
|
340
|
+
self.mapping = WriteFileMapping(mapping_file_path)
|
341
|
+
|
342
|
+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
|
343
|
+
d = dict(data)
|
344
|
+
for key in self.key_iterator(d):
|
345
|
+
d[key] = self.mapping(d[key])
|
346
|
+
return d
|
347
|
+
|
348
|
+
|
323
349
|
LoadImageD = LoadImageDict = LoadImaged
|
324
350
|
SaveImageD = SaveImageDict = SaveImaged
|
351
|
+
WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
|
@@ -1714,6 +1714,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
|
|
1714
1714
|
Probability the transform is applied to the data
|
1715
1715
|
allow_missing_keys:
|
1716
1716
|
Don't raise exception if key is missing.
|
1717
|
+
|
1718
|
+
Note:
|
1719
|
+
- This transform does not scale output image values automatically to match the range of the input.
|
1720
|
+
The output should be scaled by later transforms to match the input if this is desired.
|
1717
1721
|
"""
|
1718
1722
|
|
1719
1723
|
backend = ImageFilter.backend
|
monai/transforms/utils.py
CHANGED
@@ -107,7 +107,8 @@ __all__ = [
|
|
107
107
|
"generate_spatial_bounding_box",
|
108
108
|
"get_extreme_points",
|
109
109
|
"get_largest_connected_component_mask",
|
110
|
-
"
|
110
|
+
"keep_merge_components_with_points",
|
111
|
+
"keep_components_with_positive_points",
|
111
112
|
"convert_points_to_disc",
|
112
113
|
"remove_small_objects",
|
113
114
|
"img_bounds",
|
@@ -1178,7 +1179,7 @@ def get_largest_connected_component_mask(
|
|
1178
1179
|
return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
|
1179
1180
|
|
1180
1181
|
|
1181
|
-
def
|
1182
|
+
def keep_merge_components_with_points(
|
1182
1183
|
img_pos: NdarrayTensor,
|
1183
1184
|
img_neg: NdarrayTensor,
|
1184
1185
|
point_coords: NdarrayTensor,
|
@@ -1188,8 +1189,8 @@ def get_largest_connected_component_mask_point(
|
|
1188
1189
|
margins: int = 3,
|
1189
1190
|
) -> NdarrayTensor:
|
1190
1191
|
"""
|
1191
|
-
|
1192
|
-
negative points separately. The function is used for
|
1192
|
+
Keep connected regions of img_pos and img_neg that include the positive points and
|
1193
|
+
negative points separately. The function is used for merging automatic results with interactive
|
1193
1194
|
results in VISTA3D.
|
1194
1195
|
|
1195
1196
|
Args:
|
@@ -1199,6 +1200,7 @@ def get_largest_connected_component_mask_point(
|
|
1199
1200
|
neg_val: negative point label values.
|
1200
1201
|
point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.
|
1201
1202
|
point_labels: the label of each point, shape [B, N].
|
1203
|
+
margins: include points outside of the region but within the margin.
|
1202
1204
|
"""
|
1203
1205
|
|
1204
1206
|
cucim_skimage, has_cucim = optional_import("cucim.skimage")
|
@@ -1249,6 +1251,49 @@ def get_largest_connected_component_mask_point(
|
|
1249
1251
|
return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]
|
1250
1252
|
|
1251
1253
|
|
1254
|
+
def keep_components_with_positive_points(
|
1255
|
+
img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor
|
1256
|
+
) -> torch.Tensor:
|
1257
|
+
"""
|
1258
|
+
Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove
|
1259
|
+
regions without positive points.
|
1260
|
+
Args:
|
1261
|
+
img: [1, B, H, W, D]. Output prediction from VISTA3D. Value is before sigmoid and contain NaN value.
|
1262
|
+
point_coords: [B, N, 3]. Point click coordinates
|
1263
|
+
point_labels: [B, N]. Point click labels.
|
1264
|
+
"""
|
1265
|
+
if not has_measure:
|
1266
|
+
raise RuntimeError("skimage.measure required.")
|
1267
|
+
outs = torch.zeros_like(img)
|
1268
|
+
for c in range(len(point_coords)):
|
1269
|
+
if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()):
|
1270
|
+
# skip if no positive points.
|
1271
|
+
continue
|
1272
|
+
coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist()
|
1273
|
+
not_nan_mask = ~torch.isnan(img[0, c])
|
1274
|
+
img_ = torch.nan_to_num(img[0, c] > 0, 0)
|
1275
|
+
img_, *_ = convert_data_type(img_, np.ndarray) # type: ignore
|
1276
|
+
label = measure.label
|
1277
|
+
features = label(img_, connectivity=3)
|
1278
|
+
pos_mask = torch.from_numpy(img_).to(img.device) > 0
|
1279
|
+
# if num features less than max desired, nothing to do.
|
1280
|
+
features = torch.from_numpy(features).to(img.device)
|
1281
|
+
# generate a map with all pos points
|
1282
|
+
idx = []
|
1283
|
+
for p in coords:
|
1284
|
+
idx.append(features[round(p[0]), round(p[1]), round(p[2])].item())
|
1285
|
+
idx = list(set(idx))
|
1286
|
+
for i in idx:
|
1287
|
+
if i == 0:
|
1288
|
+
continue
|
1289
|
+
outs[0, c] += features == i
|
1290
|
+
outs = outs > 0
|
1291
|
+
# find negative mean value
|
1292
|
+
fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean()
|
1293
|
+
img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in
|
1294
|
+
return img
|
1295
|
+
|
1296
|
+
|
1252
1297
|
def convert_points_to_disc(
|
1253
1298
|
image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
|
1254
1299
|
):
|
@@ -1269,7 +1314,7 @@ def convert_points_to_disc(
|
|
1269
1314
|
_array = [
|
1270
1315
|
torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
|
1271
1316
|
]
|
1272
|
-
coord_rows, coord_cols, coord_z = torch.meshgrid(_array[
|
1317
|
+
coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2])
|
1273
1318
|
# [1, 3, h, w, d] -> [b, 2, 3, h, w, d]
|
1274
1319
|
coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)
|
1275
1320
|
coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)
|
@@ -2467,6 +2512,7 @@ def distance_transform_edt(
|
|
2467
2512
|
block_params=block_params,
|
2468
2513
|
float64_distances=float64_distances,
|
2469
2514
|
)
|
2515
|
+
torch.cuda.synchronize()
|
2470
2516
|
else:
|
2471
2517
|
if not has_ndimage:
|
2472
2518
|
raise RuntimeError("scipy.ndimage required if cupy is not available")
|
@@ -2500,7 +2546,7 @@ def distance_transform_edt(
|
|
2500
2546
|
|
2501
2547
|
r_vals = []
|
2502
2548
|
if return_distances and distances_original is None:
|
2503
|
-
r_vals.append(distances)
|
2549
|
+
r_vals.append(distances_ if use_cp else distances)
|
2504
2550
|
if return_indices and indices_original is None:
|
2505
2551
|
r_vals.append(indices)
|
2506
2552
|
if not r_vals:
|
monai/utils/enums.py
CHANGED
@@ -543,6 +543,7 @@ class MetaKeys(StrEnum):
|
|
543
543
|
SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
|
544
544
|
SPACE = "space" # possible values of space type are defined in `SpaceKeys`
|
545
545
|
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan")
|
546
|
+
SAVED_TO = "saved_to"
|
546
547
|
|
547
548
|
|
548
549
|
class ColorOrder(StrEnum):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: monai-weekly
|
3
|
-
Version: 1.4.
|
3
|
+
Version: 1.4.dev2435
|
4
4
|
Summary: AI Toolkit for Healthcare Imaging
|
5
5
|
Home-page: https://monai.io/
|
6
6
|
Author: MONAI Consortium
|
@@ -167,6 +167,7 @@ Requires-Dist: zarr; extra == "zarr"
|
|
167
167
|
[](https://github.com/Project-MONAI/MONAI/actions?query=branch%3Adev)
|
168
168
|
[](https://docs.monai.io/en/latest/)
|
169
169
|
[](https://codecov.io/gh/Project-MONAI/MONAI)
|
170
|
+
[](https://piptrends.com/package/monai)
|
170
171
|
|
171
172
|
MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
|
172
173
|
Its ambitions are:
|
@@ -1,5 +1,5 @@
|
|
1
|
-
monai/__init__.py,sha256=
|
2
|
-
monai/_version.py,sha256=
|
1
|
+
monai/__init__.py,sha256=FR4z28zBAHUoK_i-6_1WiPzJfVhqZk6Z4t06Tz305Tg,2722
|
2
|
+
monai/_version.py,sha256=Mbte5KgHGPqFVpE9bmQsl31M2shIDZrt8WLkgkm37f4,503
|
3
3
|
monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
|
5
5
|
monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
|
@@ -52,7 +52,6 @@ monai/apps/generation/maisi/networks/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30
|
|
52
52
|
monai/apps/generation/maisi/networks/autoencoderkl_maisi.py,sha256=Jbj5w9_p_xOLWYgfta26H22zgcC01BR4dmRmDdi13EU,36695
|
53
53
|
monai/apps/generation/maisi/networks/controlnet_maisi.py,sha256=jaTbpvttLybOq6KzC64CQl92BhlOi39zD48Zkdb7zBE,7698
|
54
54
|
monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py,sha256=XFOiy6GngXC_OKM1dUiel_gp71yUFWgPErYdgrVLQAU,19072
|
55
|
-
monai/apps/generation/maisi/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
56
55
|
monai/apps/mmars/__init__.py,sha256=BolpgEi9jNBgrOQd3Kwp-9QQLeWQwQtlN_MJkK1eu5s,726
|
57
56
|
monai/apps/mmars/mmars.py,sha256=AYsx5FDmJ0dT0hAkWGYhM470aPIG23PYloHihDZfOKE,13115
|
58
57
|
monai/apps/mmars/model_desc.py,sha256=k7WSMRuyQN8xPax8aUmGKiTNZmcVatdqPYCgxDih-x4,9996
|
@@ -99,6 +98,10 @@ monai/apps/reconstruction/transforms/dictionary.py,sha256=3NGkie0WYZdsWWx1_h9Orr
|
|
99
98
|
monai/apps/tcia/__init__.py,sha256=2uu3nP1j3mDs2AeG-9zmXicD33eQs1g0VHCN8KysEbQ,824
|
100
99
|
monai/apps/tcia/label_desc.py,sha256=B8l9mVmRzLysLmEIIYVeenly_68okCt461qeLQSxCJ8,1582
|
101
100
|
monai/apps/tcia/utils.py,sha256=iyLXr5_51rolbRUZFN_Fwc6TIhAbeSl6XZ2m5RYpzTw,6303
|
101
|
+
monai/apps/vista3d/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
102
|
+
monai/apps/vista3d/inferer.py,sha256=m5jW456eXCgeq8PAnUoJS-SWt0gmLA-aEiLFuS7vbEk,8710
|
103
|
+
monai/apps/vista3d/sampler.py,sha256=1uZQIRCO9HY8Rs2FtZ1v0XtPQGZ9RyEjxUycMgIqx7A,8274
|
104
|
+
monai/apps/vista3d/transforms.py,sha256=SLsVVRJty5R8X2oeeyPUQCej83__3yKv8qvAUpGxr3s,10641
|
102
105
|
monai/auto3dseg/__init__.py,sha256=DbZC7wqx4zBNcguLQGu8bGmAiKnk9LvjtQDtwdwG19I,1164
|
103
106
|
monai/auto3dseg/algo_gen.py,sha256=_BscoAnUzQKRqz5jHvdsuCe3tTxq7PUQYPMLX0WuxCc,4286
|
104
107
|
monai/auto3dseg/analyzer.py,sha256=7l8QT36lG68b8rK23CC2omz6PO1fxmDwOljxXMn5clQ,41351
|
@@ -193,7 +196,7 @@ monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,11
|
|
193
196
|
monai/inferers/inferer.py,sha256=aZwCmM6WGj49SHi_jIkQeGDstMz45frvM1Lomoeqzm4,92669
|
194
197
|
monai/inferers/merger.py,sha256=Ch-qoGUVTTDWN9z_LXBRxElvyuZxOmuqAcecpg1xxAg,15566
|
195
198
|
monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
|
196
|
-
monai/inferers/utils.py,sha256=
|
199
|
+
monai/inferers/utils.py,sha256=hKiudomhQL9mbcq1rVWRpRy55Fz1bCD5egv4J3QgLNQ,20432
|
197
200
|
monai/losses/__init__.py,sha256=igy7BjoQzM3McmJPD2tmeiW2ljSXfB2HBdc4YiDzYEg,1778
|
198
201
|
monai/losses/adversarial_loss.py,sha256=9w47lPYU3clj2w9UZ_ZcXCKnmlMfA74YkjFOCVfhF0E,7722
|
199
202
|
monai/losses/barlow_twins.py,sha256=prDdaY0vXAXMuVDmc9Tv6svRZzNwKA0LdsmRaUmusiI,3613
|
@@ -237,7 +240,7 @@ monai/metrics/surface_distance.py,sha256=bKDTm7ulhjfiphHLrDJoA3OKI3npwQy2Z5wY-Jk
|
|
237
240
|
monai/metrics/utils.py,sha256=jJiIFGGa-iwvz1otHAKqPKTNmfZqd2dI7_Hsfblgxqk,46914
|
238
241
|
monai/metrics/wrapper.py,sha256=c1zg-xcypQyZ840TEuhhLgr4sClYMWTxlv1OieJTtvE,11781
|
239
242
|
monai/networks/__init__.py,sha256=X-z-kmVt9kwoNPgfYITGycnvG_9HC3_RSRKD2YC35Ag,1020
|
240
|
-
monai/networks/utils.py,sha256=
|
243
|
+
monai/networks/utils.py,sha256=fk-jyV6U9UI5V23a7rS4lZ-Oge0x0Rfu2SoxEaCcB9w,50299
|
241
244
|
monai/networks/blocks/__init__.py,sha256=-LMGPMN-eHzwsjkb88H66kImpr4v2hYATZ2y-mRm_K0,2264
|
242
245
|
monai/networks/blocks/acti_norm.py,sha256=bVGXbTZ_ssRvmED5R7LOQ7jj4V6WbVFl8JMO-4iZ2Dk,4275
|
243
246
|
monai/networks/blocks/activation.py,sha256=S5k3zcP2PsHBkeIxgWgNg8ppW80tTResVP2j9ZsvTFw,5839
|
@@ -276,7 +279,7 @@ monai/networks/layers/conjugate_gradient.py,sha256=kCAwjtX_j5wrgR8x52WdGl4yCwZmc
|
|
276
279
|
monai/networks/layers/convutils.py,sha256=zwbYK4WJO1Tj2KASnOfxwYnb3p4pizXxdZRm6I1P3j4,8288
|
277
280
|
monai/networks/layers/drop_path.py,sha256=SZtRNa1bDwk1rXWbUe70YDaw6H_NKeplm_Wk5Ye1L4Y,1802
|
278
281
|
monai/networks/layers/factories.py,sha256=dMj-y3LRV5P_FmqMCZuf_A8P8l_fge3TVAXWzNhONuo,15795
|
279
|
-
monai/networks/layers/filtering.py,sha256=
|
282
|
+
monai/networks/layers/filtering.py,sha256=294TaEF_oF-IuL7NQzh64iwW28bRezbPGwp9KynP_ks,18215
|
280
283
|
monai/networks/layers/gmm.py,sha256=Aq-YCHgUalgOZQ0x5mwYKJe1G7aiCiJybdkPTiiT120,3325
|
281
284
|
monai/networks/layers/simplelayers.py,sha256=ciUdKrj_DpEdT3AKs70aPySh73UMsyhoOCTiR2qk8Js,28478
|
282
285
|
monai/networks/layers/spatial_transforms.py,sha256=fz2t7-ibijNLqTYpAn4ZgdXtzBSIyWlaF35mQtqWRY4,25581
|
@@ -324,7 +327,7 @@ monai/networks/nets/transformer.py,sha256=-nzl20Z5xdtn7xChOd_cRbbPVoPIFGVfTQw3fI
|
|
324
327
|
monai/networks/nets/unet.py,sha256=riKWB8iEEgO4CIiVTOo532726HWWBfuBcIHeoLvvN0w,13627
|
325
328
|
monai/networks/nets/unetr.py,sha256=wQC3mpn_jEcZb0RXef0ueTe4WGjmnZqQVKKdnemFjnc,8545
|
326
329
|
monai/networks/nets/varautoencoder.py,sha256=Pd9BdXW1iVjmAVCZIc2ElGtSDAWRBaLwEKxLDicyxZI,6282
|
327
|
-
monai/networks/nets/vista3d.py,sha256=
|
330
|
+
monai/networks/nets/vista3d.py,sha256=hFXgJmQ3nHxci9M5SiKaw0koWKyLUduT-gO22fo0cio,42955
|
328
331
|
monai/networks/nets/vit.py,sha256=SJ5MCJcVAQ2iTqkc1-AFF7oBgCkE7xcNr_ziGc8n_t8,6250
|
329
332
|
monai/networks/nets/vitautoenc.py,sha256=tTX-JHNl2H4y9e5Wk9rrtR6i_ebJHq90O61DnbBFhek,6033
|
330
333
|
monai/networks/nets/vnet.py,sha256=zaJi5kSiTLAuFHThSZfhJvHP6zKh3oBWsTWG-328O_g,10820
|
@@ -340,7 +343,7 @@ monai/optimizers/lr_finder.py,sha256=tbVi6qd-LLI6pENM9cDUv-Hh1HqziO3Wb9aI6JoaPng
|
|
340
343
|
monai/optimizers/lr_scheduler.py,sha256=YPY5MWgCTmExuIOBsVJrgfErkCT1ELBekcH0XeRP6Kk,4082
|
341
344
|
monai/optimizers/novograd.py,sha256=dgjyM-WGqrEHsSKNdI3Lw1wJ2YNG3oKCYotfPsDBE80,5677
|
342
345
|
monai/optimizers/utils.py,sha256=GVsJsZWO2aAP9IzwhXgca_9gUNHFClup6qG4ZFs42z4,4133
|
343
|
-
monai/transforms/__init__.py,sha256=
|
346
|
+
monai/transforms/__init__.py,sha256=CnywmbXXGdM56BhhkqNrcfdRia81M5i3cEiOFj6ftjg,16261
|
344
347
|
monai/transforms/adaptors.py,sha256=jqh7cVvIj4h7-UndP7CNuwxgIUXWY_5kiMzjGC5jFBs,8950
|
345
348
|
monai/transforms/compose.py,sha256=zQa_hf8gIater3Bo_XW1IVYgX7aFa_Co6-BZPwoeaQw,37663
|
346
349
|
monai/transforms/inverse.py,sha256=Wg8UnMJru41G3eHGipUemAWziHGU-qdd-Flfi3eOpeo,18746
|
@@ -348,7 +351,7 @@ monai/transforms/inverse_batch_transform.py,sha256=fMbukZq2P99BhqqMuWZFJ9uboZ5dN
|
|
348
351
|
monai/transforms/nvtx.py,sha256=1EKEXZIhTUFKoIrJmd_fevwrHwo731dVFUFJQFiOk3w,3386
|
349
352
|
monai/transforms/traits.py,sha256=F8kmhnekTyaAdo8wIFjO3-uqpVtmFym3mNxbYbyvkFI,3563
|
350
353
|
monai/transforms/transform.py,sha256=DqWyfuI-FDBxjqern33R6Ia1iAfHb3Kh56u-__tp1Kw,21614
|
351
|
-
monai/transforms/utils.py,sha256=
|
354
|
+
monai/transforms/utils.py,sha256=KX-ikiVZUFytzoRC18AbyoVimXtRp1zXzViZisSDWXU,105015
|
352
355
|
monai/transforms/utils_create_transform_ims.py,sha256=QEJVHsCZX7ZxsBArk6NjgCzSZuuokf8l1uFqiUZBBys,31155
|
353
356
|
monai/transforms/utils_morphological_ops.py,sha256=abaFYSvCfH4k7jk3R_YLtUxgwRYgsz6zj6sOEGM1K5w,6758
|
354
357
|
monai/transforms/utils_pytorch_numpy_unification.py,sha256=PvNO1QeBLTcpLhvuO25ctGr2nIM4B0sTRvnA5TpxJ4Q,18855
|
@@ -361,8 +364,8 @@ monai/transforms/intensity/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJ
|
|
361
364
|
monai/transforms/intensity/array.py,sha256=bhKIAMgJu-QMQA8df9QdyancMJMShOIOGHjE__4XdXo,121574
|
362
365
|
monai/transforms/intensity/dictionary.py,sha256=RXZeQG9dPvdvjoiWWlNkYec4NDWBxYXjfct4fywv1Ic,85059
|
363
366
|
monai/transforms/io/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
364
|
-
monai/transforms/io/array.py,sha256=
|
365
|
-
monai/transforms/io/dictionary.py,sha256=
|
367
|
+
monai/transforms/io/array.py,sha256=z4aOxK44IhztN-LzG2uROYDwg_u1C6gcpx9ZH-ZhoVA,27482
|
368
|
+
monai/transforms/io/dictionary.py,sha256=64M9KUsKyzwXopDcarXT7JKIv9rHP8Ae-fYRvI0yBuM,18716
|
366
369
|
monai/transforms/lazy/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
367
370
|
monai/transforms/lazy/array.py,sha256=2jNLmQ3_sMX7DdbfcT3Extpwe5FgOBbbz2RqlDlyNcw,1211
|
368
371
|
monai/transforms/lazy/dictionary.py,sha256=bgpZ5CPh5rjdf1T5eQVqxlLh0B57xTWHWaBUUxiQAu4,1571
|
@@ -388,14 +391,14 @@ monai/transforms/spatial/dictionary.py,sha256=mvP_skSEI1sMl9y-AS3PZqNHhTLK6iOVOf
|
|
388
391
|
monai/transforms/spatial/functional.py,sha256=4sLTp5ggCJrePg1TQjFhOxdVf1It4-PA6hiv7vMkrBI,31253
|
389
392
|
monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
390
393
|
monai/transforms/utility/array.py,sha256=Pcg0nJEAHR60jydZTyueTSss9kaOiM4v6UFF1Fnj0PY,70600
|
391
|
-
monai/transforms/utility/dictionary.py,sha256=
|
394
|
+
monai/transforms/utility/dictionary.py,sha256=GmJHQpxEpB3lhIuib9K1V4vuPTRoMV2xpwPjSOvMC78,73329
|
392
395
|
monai/utils/__init__.py,sha256=QbMAngvOTgxcwIUpo-LRRBF8PtgG3bzgqXLGVlcUGnc,3757
|
393
396
|
monai/utils/aliases.py,sha256=uBxkLudRfy3Rts9RZo4NDPGoq4e3Ymcaihk6lT92GFo,4096
|
394
397
|
monai/utils/component_store.py,sha256=VMF7CtPu5Wi_eX_qFtm9iWo5kvoWFuCUIxdRzk90zZo,4498
|
395
398
|
monai/utils/decorators.py,sha256=YRK5iEMdbc2INrWnBNDSMTaHge_0ezRf2b9yJGL-opg,3129
|
396
399
|
monai/utils/deprecate_utils.py,sha256=gKeEV4MsI51qeQ5gci2me_C-0e-tDwa3VZzd3XPQqLk,14759
|
397
400
|
monai/utils/dist.py,sha256=mVaKlBTQJdWAG910sh5pGLEbb_KhRAXV5cPz7amH88Y,8639
|
398
|
-
monai/utils/enums.py,sha256=
|
401
|
+
monai/utils/enums.py,sha256=f__RhrrG4cxxzmICHnmM9riiCvsmUIIk9fYN12Q33lE,19700
|
399
402
|
monai/utils/jupyter_utils.py,sha256=QqcKhJxzEf6YwM8Ik_HvfVDr7gNfrfzCXdzd2urEH8M,15651
|
400
403
|
monai/utils/misc.py,sha256=GJIDxr42juFjnzUTvLtYndcpBQ-EDz6EVXIc7anBoNo,31380
|
401
404
|
monai/utils/module.py,sha256=D9KWFrZ8sS2LrGaLzHnw9MMEbrPI9pHHfHc0XrTLob0,25105
|
@@ -412,8 +415,8 @@ monai/visualize/img2tensorboard.py,sha256=NnMcyfIFqX-jD7TBO3Rn02zt5uug79d_7pIIaV
|
|
412
415
|
monai/visualize/occlusion_sensitivity.py,sha256=OQHEJLyIhB8zWqQsfKaX-1kvCjWFVYtLfS4dFC0nKFI,18160
|
413
416
|
monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
|
414
417
|
monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
|
415
|
-
monai_weekly-1.4.
|
416
|
-
monai_weekly-1.4.
|
417
|
-
monai_weekly-1.4.
|
418
|
-
monai_weekly-1.4.
|
419
|
-
monai_weekly-1.4.
|
418
|
+
monai_weekly-1.4.dev2435.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
419
|
+
monai_weekly-1.4.dev2435.dist-info/METADATA,sha256=-VRQSgkdacB1TGTzFz7kHZnoowBGiceQuol8LIRj3Hc,11096
|
420
|
+
monai_weekly-1.4.dev2435.dist-info/WHEEL,sha256=UvcQYKBHoFqaQd6LKyqHw9fxEolWLQnlzP0h_LgJAfI,91
|
421
|
+
monai_weekly-1.4.dev2435.dist-info/top_level.txt,sha256=UaNwRzLGORdus41Ip446s3bBfViLkdkDsXDo34J2P44,6
|
422
|
+
monai_weekly-1.4.dev2435.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|