monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
- monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
- monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/scripts.py +29 -17
- monai/data/utils.py +1 -1
- monai/data/wsi_datasets.py +3 -3
- monai/inferers/utils.py +1 -0
- monai/losses/__init__.py +1 -0
- monai/losses/dice.py +10 -1
- monai/losses/nacl_loss.py +139 -0
- monai/networks/blocks/crossattention.py +48 -26
- monai/networks/blocks/mlp.py +16 -4
- monai/networks/blocks/selfattention.py +75 -23
- monai/networks/blocks/spatialattention.py +16 -1
- monai/networks/blocks/transformerblock.py +17 -2
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/__init__.py +2 -1
- monai/networks/nets/autoencoderkl.py +55 -22
- monai/networks/nets/cell_sam_wrapper.py +92 -0
- monai/networks/nets/controlnet.py +24 -22
- monai/networks/nets/diffusion_model_unet.py +159 -19
- monai/networks/nets/segresnet_ds.py +127 -1
- monai/networks/nets/spade_autoencoderkl.py +22 -0
- monai/networks/nets/spade_diffusion_model_unet.py +39 -2
- monai/networks/nets/transformer.py +17 -17
- monai/networks/nets/vista3d.py +946 -0
- monai/networks/utils.py +4 -4
- monai/transforms/__init__.py +13 -2
- monai/transforms/io/array.py +59 -3
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/functional.py +1 -1
- monai/transforms/transform.py +2 -2
- monai/transforms/utility/dictionary.py +4 -0
- monai/transforms/utils.py +230 -1
- monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
- monai/transforms/utils_pytorch_numpy_unification.py +2 -2
- monai/utils/enums.py +1 -0
- monai/utils/module.py +7 -6
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/networks/utils.py
CHANGED
@@ -822,7 +822,7 @@ def _onnx_trt_compile(
|
|
822
822
|
output_names = [] if not output_names else output_names
|
823
823
|
|
824
824
|
# set up the TensorRT builder
|
825
|
-
|
825
|
+
torch.cuda.set_device(device)
|
826
826
|
logger = trt.Logger(trt.Logger.WARNING)
|
827
827
|
builder = trt.Builder(logger)
|
828
828
|
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
@@ -851,7 +851,7 @@ def _onnx_trt_compile(
|
|
851
851
|
# wrap the serialized TensorRT engine back to a TorchScript module.
|
852
852
|
trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
|
853
853
|
f.getvalue(),
|
854
|
-
device=
|
854
|
+
device=torch_tensorrt.Device(f"cuda:{device}"),
|
855
855
|
input_binding_names=input_names,
|
856
856
|
output_binding_names=output_names,
|
857
857
|
)
|
@@ -931,7 +931,7 @@ def convert_to_trt(
|
|
931
931
|
warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")
|
932
932
|
|
933
933
|
device = device if device else 0
|
934
|
-
target_device = torch.device(f"cuda:{device}")
|
934
|
+
target_device = torch.device(f"cuda:{device}")
|
935
935
|
convert_precision = torch.float32 if precision == "fp32" else torch.half
|
936
936
|
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
|
937
937
|
|
@@ -986,7 +986,7 @@ def convert_to_trt(
|
|
986
986
|
ir_model,
|
987
987
|
inputs=input_placeholder,
|
988
988
|
enabled_precisions=convert_precision,
|
989
|
-
device=
|
989
|
+
device=torch_tensorrt.Device(f"cuda:{device}"),
|
990
990
|
ir="torchscript",
|
991
991
|
**kwargs,
|
992
992
|
)
|
monai/transforms/__init__.py
CHANGED
@@ -238,8 +238,18 @@ from .intensity.dictionary import (
|
|
238
238
|
)
|
239
239
|
from .inverse import InvertibleTransform, TraceableTransform
|
240
240
|
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
|
241
|
-
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
|
242
|
-
from .io.dictionary import
|
241
|
+
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
|
242
|
+
from .io.dictionary import (
|
243
|
+
LoadImaged,
|
244
|
+
LoadImageD,
|
245
|
+
LoadImageDict,
|
246
|
+
SaveImaged,
|
247
|
+
SaveImageD,
|
248
|
+
SaveImageDict,
|
249
|
+
WriteFileMappingd,
|
250
|
+
WriteFileMappingD,
|
251
|
+
WriteFileMappingDict,
|
252
|
+
)
|
243
253
|
from .lazy.array import ApplyPending
|
244
254
|
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
|
245
255
|
from .lazy.functional import apply_pending
|
@@ -688,6 +698,7 @@ from .utils import (
|
|
688
698
|
weighted_patch_samples,
|
689
699
|
zero_margins,
|
690
700
|
)
|
701
|
+
from .utils_morphological_ops import dilate, erode
|
691
702
|
from .utils_pytorch_numpy_unification import (
|
692
703
|
allclose,
|
693
704
|
any_np_pt,
|
monai/transforms/io/array.py
CHANGED
@@ -15,6 +15,7 @@ A collection of "vanilla" transforms for IO functions.
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
import inspect
|
18
|
+
import json
|
18
19
|
import logging
|
19
20
|
import sys
|
20
21
|
import traceback
|
@@ -45,11 +46,19 @@ from monai.transforms.transform import Transform
|
|
45
46
|
from monai.transforms.utility.array import EnsureChannelFirst
|
46
47
|
from monai.utils import GridSamplePadMode
|
47
48
|
from monai.utils import ImageMetaKey as Key
|
48
|
-
from monai.utils import
|
49
|
+
from monai.utils import (
|
50
|
+
MetaKeys,
|
51
|
+
OptionalImportError,
|
52
|
+
convert_to_dst_type,
|
53
|
+
ensure_tuple,
|
54
|
+
look_up_option,
|
55
|
+
optional_import,
|
56
|
+
)
|
49
57
|
|
50
58
|
nib, _ = optional_import("nibabel")
|
51
59
|
Image, _ = optional_import("PIL.Image")
|
52
60
|
nrrd, _ = optional_import("nrrd")
|
61
|
+
FileLock, has_filelock = optional_import("filelock", name="FileLock")
|
53
62
|
|
54
63
|
__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]
|
55
64
|
|
@@ -86,7 +95,7 @@ def switch_endianness(data, new="<"):
|
|
86
95
|
if new not in ("<", ">"):
|
87
96
|
raise NotImplementedError(f"Not implemented option new={new}.")
|
88
97
|
if current_ != new:
|
89
|
-
data = data.byteswap().newbyteorder(new)
|
98
|
+
data = data.byteswap().view(data.dtype.newbyteorder(new))
|
90
99
|
elif isinstance(data, tuple):
|
91
100
|
data = tuple(switch_endianness(x, new) for x in data)
|
92
101
|
elif isinstance(data, list):
|
@@ -505,7 +514,7 @@ class SaveImage(Transform):
|
|
505
514
|
else:
|
506
515
|
self._data_index += 1
|
507
516
|
if self.savepath_in_metadict and meta_data is not None:
|
508
|
-
meta_data[
|
517
|
+
meta_data[MetaKeys.SAVED_TO] = filename
|
509
518
|
return img
|
510
519
|
msg = "\n".join([f"{e}" for e in err])
|
511
520
|
raise RuntimeError(
|
@@ -514,3 +523,50 @@ class SaveImage(Transform):
|
|
514
523
|
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
|
515
524
|
f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
|
516
525
|
)
|
526
|
+
|
527
|
+
|
528
|
+
class WriteFileMapping(Transform):
|
529
|
+
"""
|
530
|
+
Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
|
531
|
+
This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
|
532
|
+
|
533
|
+
Args:
|
534
|
+
mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
|
535
|
+
"""
|
536
|
+
|
537
|
+
def __init__(self, mapping_file_path: Path | str = "mapping.json"):
|
538
|
+
self.mapping_file_path = Path(mapping_file_path)
|
539
|
+
|
540
|
+
def __call__(self, img: NdarrayOrTensor):
|
541
|
+
"""
|
542
|
+
Args:
|
543
|
+
img: The input image with metadata.
|
544
|
+
"""
|
545
|
+
if isinstance(img, MetaTensor):
|
546
|
+
meta_data = img.meta
|
547
|
+
|
548
|
+
if MetaKeys.SAVED_TO not in meta_data:
|
549
|
+
raise KeyError(
|
550
|
+
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
|
551
|
+
)
|
552
|
+
|
553
|
+
input_path = meta_data[Key.FILENAME_OR_OBJ]
|
554
|
+
output_path = meta_data[MetaKeys.SAVED_TO]
|
555
|
+
log_data = {"input": input_path, "output": output_path}
|
556
|
+
|
557
|
+
if has_filelock:
|
558
|
+
with FileLock(str(self.mapping_file_path) + ".lock"):
|
559
|
+
self._write_to_file(log_data)
|
560
|
+
else:
|
561
|
+
self._write_to_file(log_data)
|
562
|
+
return img
|
563
|
+
|
564
|
+
def _write_to_file(self, log_data):
|
565
|
+
try:
|
566
|
+
with self.mapping_file_path.open("r") as f:
|
567
|
+
existing_log_data = json.load(f)
|
568
|
+
except (FileNotFoundError, json.JSONDecodeError):
|
569
|
+
existing_log_data = []
|
570
|
+
existing_log_data.append(log_data)
|
571
|
+
with self.mapping_file_path.open("w") as f:
|
572
|
+
json.dump(existing_log_data, f, indent=4)
|
@@ -17,16 +17,17 @@ Class names are ended with 'd' to denote dictionary-based transforms.
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
from collections.abc import Hashable, Mapping
|
20
21
|
from pathlib import Path
|
21
22
|
from typing import Callable
|
22
23
|
|
23
24
|
import numpy as np
|
24
25
|
|
25
26
|
import monai
|
26
|
-
from monai.config import DtypeLike, KeysCollection
|
27
|
+
from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
|
27
28
|
from monai.data import image_writer
|
28
29
|
from monai.data.image_reader import ImageReader
|
29
|
-
from monai.transforms.io.array import LoadImage, SaveImage
|
30
|
+
from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
|
30
31
|
from monai.transforms.transform import MapTransform, Transform
|
31
32
|
from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
|
32
33
|
from monai.utils.enums import PostFix
|
@@ -320,5 +321,31 @@ class SaveImaged(MapTransform):
|
|
320
321
|
return d
|
321
322
|
|
322
323
|
|
324
|
+
class WriteFileMappingd(MapTransform):
|
325
|
+
"""
|
326
|
+
Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
|
327
|
+
|
328
|
+
Args:
|
329
|
+
keys: keys of the corresponding items to be transformed.
|
330
|
+
See also: :py:class:`monai.transforms.compose.MapTransform`
|
331
|
+
mapping_file_path: Path to the JSON file where the mappings will be saved.
|
332
|
+
Defaults to "mapping.json".
|
333
|
+
allow_missing_keys: don't raise exception if key is missing.
|
334
|
+
"""
|
335
|
+
|
336
|
+
def __init__(
|
337
|
+
self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
|
338
|
+
) -> None:
|
339
|
+
super().__init__(keys, allow_missing_keys)
|
340
|
+
self.mapping = WriteFileMapping(mapping_file_path)
|
341
|
+
|
342
|
+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
|
343
|
+
d = dict(data)
|
344
|
+
for key in self.key_iterator(d):
|
345
|
+
d[key] = self.mapping(d[key])
|
346
|
+
return d
|
347
|
+
|
348
|
+
|
323
349
|
LoadImageD = LoadImageDict = LoadImaged
|
324
350
|
SaveImageD = SaveImageDict = SaveImaged
|
351
|
+
WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
|
@@ -373,7 +373,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l
|
|
373
373
|
if output_shape is None:
|
374
374
|
corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1))
|
375
375
|
corners = transform[:-1, :-1] @ corners # type: ignore
|
376
|
-
output_shape = np.asarray(
|
376
|
+
output_shape = np.asarray(np.ptp(corners, axis=1) + 0.5, dtype=int)
|
377
377
|
else:
|
378
378
|
output_shape = np.asarray(output_shape, dtype=int)
|
379
379
|
shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist())
|
monai/transforms/transform.py
CHANGED
@@ -203,8 +203,8 @@ class Randomizable(ThreadUnsafe, RandomizableTrait):
|
|
203
203
|
|
204
204
|
"""
|
205
205
|
if seed is not None:
|
206
|
-
_seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed
|
207
|
-
_seed = _seed % MAX_SEED
|
206
|
+
_seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed)
|
207
|
+
_seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
|
208
208
|
self.R = np.random.RandomState(_seed)
|
209
209
|
return self
|
210
210
|
|
@@ -1714,6 +1714,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
|
|
1714
1714
|
Probability the transform is applied to the data
|
1715
1715
|
allow_missing_keys:
|
1716
1716
|
Don't raise exception if key is missing.
|
1717
|
+
|
1718
|
+
Note:
|
1719
|
+
- This transform does not scale output image values automatically to match the range of the input.
|
1720
|
+
The output should be scaled by later transforms to match the input if this is desired.
|
1717
1721
|
"""
|
1718
1722
|
|
1719
1723
|
backend = ImageFilter.backend
|
monai/transforms/utils.py
CHANGED
@@ -22,6 +22,7 @@ from typing import Any
|
|
22
22
|
|
23
23
|
import numpy as np
|
24
24
|
import torch
|
25
|
+
from torch import Tensor
|
25
26
|
|
26
27
|
import monai
|
27
28
|
from monai.config import DtypeLike, IndexSelection
|
@@ -30,6 +31,7 @@ from monai.networks.layers import GaussianFilter
|
|
30
31
|
from monai.networks.utils import meshgrid_ij
|
31
32
|
from monai.transforms.compose import Compose
|
32
33
|
from monai.transforms.transform import MapTransform, Transform, apply_transform
|
34
|
+
from monai.transforms.utils_morphological_ops import erode
|
33
35
|
from monai.transforms.utils_pytorch_numpy_unification import (
|
34
36
|
any_np_pt,
|
35
37
|
ascontiguousarray,
|
@@ -65,6 +67,8 @@ from monai.utils import (
|
|
65
67
|
min_version,
|
66
68
|
optional_import,
|
67
69
|
pytorch_after,
|
70
|
+
unsqueeze_left,
|
71
|
+
unsqueeze_right,
|
68
72
|
)
|
69
73
|
from monai.utils.enums import TransformBackends
|
70
74
|
from monai.utils.type_conversion import (
|
@@ -103,6 +107,9 @@ __all__ = [
|
|
103
107
|
"generate_spatial_bounding_box",
|
104
108
|
"get_extreme_points",
|
105
109
|
"get_largest_connected_component_mask",
|
110
|
+
"keep_merge_components_with_points",
|
111
|
+
"keep_components_with_positive_points",
|
112
|
+
"convert_points_to_disc",
|
106
113
|
"remove_small_objects",
|
107
114
|
"img_bounds",
|
108
115
|
"in_bounds",
|
@@ -1172,6 +1179,227 @@ def get_largest_connected_component_mask(
|
|
1172
1179
|
return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
|
1173
1180
|
|
1174
1181
|
|
1182
|
+
def keep_merge_components_with_points(
|
1183
|
+
img_pos: NdarrayTensor,
|
1184
|
+
img_neg: NdarrayTensor,
|
1185
|
+
point_coords: NdarrayTensor,
|
1186
|
+
point_labels: NdarrayTensor,
|
1187
|
+
pos_val: Sequence[int] = (1, 3),
|
1188
|
+
neg_val: Sequence[int] = (0, 2),
|
1189
|
+
margins: int = 3,
|
1190
|
+
) -> NdarrayTensor:
|
1191
|
+
"""
|
1192
|
+
Keep connected regions of img_pos and img_neg that include the positive points and
|
1193
|
+
negative points separately. The function is used for merging automatic results with interactive
|
1194
|
+
results in VISTA3D.
|
1195
|
+
|
1196
|
+
Args:
|
1197
|
+
img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image.
|
1198
|
+
img_neg: same format as img_pos but corresponds to negative points.
|
1199
|
+
pos_val: positive point label values.
|
1200
|
+
neg_val: negative point label values.
|
1201
|
+
point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.
|
1202
|
+
point_labels: the label of each point, shape [B, N].
|
1203
|
+
margins: include points outside of the region but within the margin.
|
1204
|
+
"""
|
1205
|
+
|
1206
|
+
cucim_skimage, has_cucim = optional_import("cucim.skimage")
|
1207
|
+
|
1208
|
+
use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu")
|
1209
|
+
if use_cp:
|
1210
|
+
img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore
|
1211
|
+
img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore
|
1212
|
+
label = cucim_skimage.measure.label
|
1213
|
+
lib = cp
|
1214
|
+
else:
|
1215
|
+
if not has_measure:
|
1216
|
+
raise RuntimeError("skimage.measure required.")
|
1217
|
+
img_pos_, *_ = convert_data_type(img_pos, np.ndarray)
|
1218
|
+
img_neg_, *_ = convert_data_type(img_neg, np.ndarray)
|
1219
|
+
# for skimage.measure.label, the input must be bool type
|
1220
|
+
if img_pos_.dtype != bool or img_neg_.dtype != bool:
|
1221
|
+
raise ValueError("img_pos and img_neg must be bool type.")
|
1222
|
+
label = measure.label
|
1223
|
+
lib = np
|
1224
|
+
|
1225
|
+
features_pos, _ = label(img_pos_, connectivity=3, return_num=True)
|
1226
|
+
features_neg, _ = label(img_neg_, connectivity=3, return_num=True)
|
1227
|
+
|
1228
|
+
outs = np.zeros_like(img_pos_)
|
1229
|
+
for bs in range(point_coords.shape[0]):
|
1230
|
+
for i, p in enumerate(point_coords[bs]):
|
1231
|
+
if point_labels[bs, i] in pos_val:
|
1232
|
+
features = features_pos
|
1233
|
+
elif point_labels[bs, i] in neg_val:
|
1234
|
+
features = features_neg
|
1235
|
+
else:
|
1236
|
+
# if -1 padding point, skip
|
1237
|
+
continue
|
1238
|
+
for margin in range(margins):
|
1239
|
+
if isinstance(p, np.ndarray):
|
1240
|
+
x, y, z = np.round(p).astype(int).tolist()
|
1241
|
+
else:
|
1242
|
+
x, y, z = p.float().round().int().tolist()
|
1243
|
+
l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3])
|
1244
|
+
t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2])
|
1245
|
+
f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1])
|
1246
|
+
if (features[bs, 0, l:r, t:d, f:b] > 0).any():
|
1247
|
+
index = features[bs, 0, l:r, t:d, f:b].max()
|
1248
|
+
outs[[bs]] += lib.isin(features[[bs]], index)
|
1249
|
+
break
|
1250
|
+
outs[outs > 1] = 1
|
1251
|
+
return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]
|
1252
|
+
|
1253
|
+
|
1254
|
+
def keep_components_with_positive_points(
|
1255
|
+
img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor
|
1256
|
+
) -> torch.Tensor:
|
1257
|
+
"""
|
1258
|
+
Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove
|
1259
|
+
regions without positive points.
|
1260
|
+
Args:
|
1261
|
+
img: [1, B, H, W, D]. Output prediction from VISTA3D. Value is before sigmoid and contain NaN value.
|
1262
|
+
point_coords: [B, N, 3]. Point click coordinates
|
1263
|
+
point_labels: [B, N]. Point click labels.
|
1264
|
+
"""
|
1265
|
+
if not has_measure:
|
1266
|
+
raise RuntimeError("skimage.measure required.")
|
1267
|
+
outs = torch.zeros_like(img)
|
1268
|
+
for c in range(len(point_coords)):
|
1269
|
+
if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()):
|
1270
|
+
# skip if no positive points.
|
1271
|
+
continue
|
1272
|
+
coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist()
|
1273
|
+
not_nan_mask = ~torch.isnan(img[0, c])
|
1274
|
+
img_ = torch.nan_to_num(img[0, c] > 0, 0)
|
1275
|
+
img_, *_ = convert_data_type(img_, np.ndarray) # type: ignore
|
1276
|
+
label = measure.label
|
1277
|
+
features = label(img_, connectivity=3)
|
1278
|
+
pos_mask = torch.from_numpy(img_).to(img.device) > 0
|
1279
|
+
# if num features less than max desired, nothing to do.
|
1280
|
+
features = torch.from_numpy(features).to(img.device)
|
1281
|
+
# generate a map with all pos points
|
1282
|
+
idx = []
|
1283
|
+
for p in coords:
|
1284
|
+
idx.append(features[round(p[0]), round(p[1]), round(p[2])].item())
|
1285
|
+
idx = list(set(idx))
|
1286
|
+
for i in idx:
|
1287
|
+
if i == 0:
|
1288
|
+
continue
|
1289
|
+
outs[0, c] += features == i
|
1290
|
+
outs = outs > 0
|
1291
|
+
# find negative mean value
|
1292
|
+
fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean()
|
1293
|
+
img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in
|
1294
|
+
return img
|
1295
|
+
|
1296
|
+
|
1297
|
+
def convert_points_to_disc(
|
1298
|
+
image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
|
1299
|
+
):
|
1300
|
+
"""
|
1301
|
+
Convert a 3D point coordinates into image mask. The returned mask has the same spatial
|
1302
|
+
size as `image_size` while the batch dimension is the same as 'point' batch dimension.
|
1303
|
+
The point is converted to a mask ball with radius defined by `radius`. The output
|
1304
|
+
contains two channels each for negative (first channel) and positive points.
|
1305
|
+
|
1306
|
+
Args:
|
1307
|
+
image_size: The output size of the converted mask. It should be a 3D tuple.
|
1308
|
+
point: [B, N, 3], 3D point coordinates.
|
1309
|
+
point_label: [B, N], 0 or 2 means negative points, 1 or 3 means postive points.
|
1310
|
+
radius: disc ball radius size.
|
1311
|
+
disc: If true, use regular disc, other use gaussian.
|
1312
|
+
"""
|
1313
|
+
masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device)
|
1314
|
+
_array = [
|
1315
|
+
torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
|
1316
|
+
]
|
1317
|
+
coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2])
|
1318
|
+
# [1, 3, h, w, d] -> [b, 2, 3, h, w, d]
|
1319
|
+
coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)
|
1320
|
+
coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)
|
1321
|
+
for b, n in np.ndindex(*point.shape[:2]):
|
1322
|
+
point_bn = unsqueeze_right(point[b, n], 4)
|
1323
|
+
if point_label[b, n] > -1:
|
1324
|
+
channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1
|
1325
|
+
pow_diff = torch.pow(coords[b, channel] - point_bn, 2)
|
1326
|
+
if disc:
|
1327
|
+
masks[b, channel] += pow_diff.sum(0) < radius**2
|
1328
|
+
else:
|
1329
|
+
masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2))
|
1330
|
+
return masks
|
1331
|
+
|
1332
|
+
|
1333
|
+
def sample_points_from_label(
|
1334
|
+
labels: Tensor,
|
1335
|
+
label_set: Sequence[int],
|
1336
|
+
max_ppoint: int = 1,
|
1337
|
+
max_npoint: int = 0,
|
1338
|
+
device: torch.device | str | None = "cpu",
|
1339
|
+
use_center: bool = False,
|
1340
|
+
):
|
1341
|
+
"""Sample points from labels.
|
1342
|
+
|
1343
|
+
Args:
|
1344
|
+
labels: [1, 1, H, W, D]
|
1345
|
+
label_set: local index, must match values in labels.
|
1346
|
+
max_ppoint: maximum positive point samples.
|
1347
|
+
max_npoint: maximum negative point samples.
|
1348
|
+
device: returned tensor device.
|
1349
|
+
use_center: whether to sample points from center.
|
1350
|
+
|
1351
|
+
Returns:
|
1352
|
+
point: point coordinates of [B, N, 3]. B equals to the length of label_set.
|
1353
|
+
point_label: [B, N], always 0 for negative, 1 for positive.
|
1354
|
+
"""
|
1355
|
+
if not labels.shape[0] == 1:
|
1356
|
+
raise ValueError("labels must have batch size 1.")
|
1357
|
+
|
1358
|
+
if device is None:
|
1359
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
1360
|
+
|
1361
|
+
labels = labels[0, 0]
|
1362
|
+
unique_labels = labels.unique().cpu().numpy().tolist()
|
1363
|
+
_point = []
|
1364
|
+
_point_label = []
|
1365
|
+
for id in label_set:
|
1366
|
+
if id in unique_labels:
|
1367
|
+
plabels = labels == int(id)
|
1368
|
+
nlabels = ~plabels
|
1369
|
+
_plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0])
|
1370
|
+
plabelpoints = torch.nonzero(_plabels).to(device)
|
1371
|
+
if len(plabelpoints) == 0:
|
1372
|
+
plabelpoints = torch.nonzero(plabels).to(device)
|
1373
|
+
nlabelpoints = torch.nonzero(nlabels).to(device)
|
1374
|
+
num_p = min(len(plabelpoints), max_ppoint)
|
1375
|
+
num_n = min(len(nlabelpoints), max_npoint)
|
1376
|
+
pad = max_ppoint + max_npoint - num_p - num_n
|
1377
|
+
if use_center:
|
1378
|
+
pmean = plabelpoints.float().mean(0)
|
1379
|
+
pdis = ((plabelpoints - pmean) ** 2).sum(-1)
|
1380
|
+
_, sorted_indices_tensor = torch.sort(pdis)
|
1381
|
+
sorted_indices = sorted_indices_tensor.cpu().tolist()
|
1382
|
+
else:
|
1383
|
+
sorted_indices = list(range(len(plabelpoints)))
|
1384
|
+
random.shuffle(sorted_indices)
|
1385
|
+
_point.append(
|
1386
|
+
torch.stack(
|
1387
|
+
[plabelpoints[sorted_indices[i]] for i in range(num_p)]
|
1388
|
+
+ random.choices(nlabelpoints, k=num_n)
|
1389
|
+
+ [torch.tensor([0, 0, 0], device=device)] * pad
|
1390
|
+
)
|
1391
|
+
)
|
1392
|
+
_point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device))
|
1393
|
+
else:
|
1394
|
+
# pad the background labels
|
1395
|
+
_point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device))
|
1396
|
+
_point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1)
|
1397
|
+
point = torch.stack(_point)
|
1398
|
+
point_label = torch.stack(_point_label)
|
1399
|
+
|
1400
|
+
return point, point_label
|
1401
|
+
|
1402
|
+
|
1175
1403
|
def remove_small_objects(
|
1176
1404
|
img: NdarrayTensor,
|
1177
1405
|
min_size: int = 64,
|
@@ -2284,6 +2512,7 @@ def distance_transform_edt(
|
|
2284
2512
|
block_params=block_params,
|
2285
2513
|
float64_distances=float64_distances,
|
2286
2514
|
)
|
2515
|
+
torch.cuda.synchronize()
|
2287
2516
|
else:
|
2288
2517
|
if not has_ndimage:
|
2289
2518
|
raise RuntimeError("scipy.ndimage required if cupy is not available")
|
@@ -2317,7 +2546,7 @@ def distance_transform_edt(
|
|
2317
2546
|
|
2318
2547
|
r_vals = []
|
2319
2548
|
if return_distances and distances_original is None:
|
2320
|
-
r_vals.append(distances)
|
2549
|
+
r_vals.append(distances_ if use_cp else distances)
|
2321
2550
|
if return_indices and indices_original is None:
|
2322
2551
|
r_vals.append(indices)
|
2323
2552
|
if not r_vals:
|
monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py}
RENAMED
@@ -20,6 +20,8 @@ from torch import Tensor
|
|
20
20
|
from monai.config import NdarrayOrTensor
|
21
21
|
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep
|
22
22
|
|
23
|
+
__all__ = ["erode", "dilate"]
|
24
|
+
|
23
25
|
|
24
26
|
def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
|
25
27
|
"""
|
@@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
|
|
480
480
|
else:
|
481
481
|
ret = torch.max(x, int(dim), **kwargs) # type: ignore
|
482
482
|
|
483
|
-
return ret
|
483
|
+
return ret[0] if isinstance(ret, tuple) else ret
|
484
484
|
|
485
485
|
|
486
486
|
def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
|
@@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
|
|
546
546
|
else:
|
547
547
|
ret = torch.min(x, int(dim), **kwargs) # type: ignore
|
548
548
|
|
549
|
-
return ret
|
549
|
+
return ret[0] if isinstance(ret, tuple) else ret
|
550
550
|
|
551
551
|
|
552
552
|
def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor:
|
monai/utils/enums.py
CHANGED
@@ -543,6 +543,7 @@ class MetaKeys(StrEnum):
|
|
543
543
|
SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
|
544
544
|
SPACE = "space" # possible values of space type are defined in `SpaceKeys`
|
545
545
|
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan")
|
546
|
+
SAVED_TO = "saved_to"
|
546
547
|
|
547
548
|
|
548
549
|
class ColorOrder(StrEnum):
|
monai/utils/module.py
CHANGED
@@ -13,7 +13,6 @@ from __future__ import annotations
|
|
13
13
|
|
14
14
|
import enum
|
15
15
|
import functools
|
16
|
-
import importlib.util
|
17
16
|
import os
|
18
17
|
import pdb
|
19
18
|
import re
|
@@ -209,10 +208,11 @@ def load_submodules(
|
|
209
208
|
):
|
210
209
|
if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None:
|
211
210
|
try:
|
211
|
+
mod = import_module(name)
|
212
212
|
mod_spec = importer.find_spec(name) # type: ignore
|
213
213
|
if mod_spec and mod_spec.loader:
|
214
|
-
|
215
|
-
|
214
|
+
loader = mod_spec.loader
|
215
|
+
loader.exec_module(mod)
|
216
216
|
submodules.append(mod)
|
217
217
|
except OptionalImportError:
|
218
218
|
pass # could not import the optional deps., they are ignored
|
@@ -564,7 +564,7 @@ def version_leq(lhs: str, rhs: str) -> bool:
|
|
564
564
|
"""
|
565
565
|
|
566
566
|
lhs, rhs = str(lhs), str(rhs)
|
567
|
-
pkging, has_ver = optional_import("
|
567
|
+
pkging, has_ver = optional_import("packaging.Version")
|
568
568
|
if has_ver:
|
569
569
|
try:
|
570
570
|
return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
|
@@ -591,7 +591,8 @@ def version_geq(lhs: str, rhs: str) -> bool:
|
|
591
591
|
|
592
592
|
"""
|
593
593
|
lhs, rhs = str(lhs), str(rhs)
|
594
|
-
pkging, has_ver = optional_import("
|
594
|
+
pkging, has_ver = optional_import("packaging.Version")
|
595
|
+
|
595
596
|
if has_ver:
|
596
597
|
try:
|
597
598
|
return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs))
|
@@ -629,7 +630,7 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
|
|
629
630
|
if current_ver_string is None:
|
630
631
|
_env_var = os.environ.get("PYTORCH_VER", "")
|
631
632
|
current_ver_string = _env_var if _env_var else torch.__version__
|
632
|
-
ver, has_ver = optional_import("
|
633
|
+
ver, has_ver = optional_import("packaging.version", name="parse")
|
633
634
|
if has_ver:
|
634
635
|
return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore
|
635
636
|
parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3)
|