monai-weekly 1.4.dev2431__py3-none-any.whl → 1.4.dev2435__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +43 -25
  4. monai/apps/generation/maisi/networks/controlnet_maisi.py +15 -18
  5. monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +18 -18
  6. monai/apps/vista3d/inferer.py +177 -0
  7. monai/apps/vista3d/sampler.py +179 -0
  8. monai/apps/vista3d/transforms.py +224 -0
  9. monai/bundle/scripts.py +29 -17
  10. monai/data/utils.py +1 -1
  11. monai/data/wsi_datasets.py +3 -3
  12. monai/inferers/utils.py +1 -0
  13. monai/losses/__init__.py +1 -0
  14. monai/losses/dice.py +10 -1
  15. monai/losses/nacl_loss.py +139 -0
  16. monai/networks/blocks/crossattention.py +48 -26
  17. monai/networks/blocks/mlp.py +16 -4
  18. monai/networks/blocks/selfattention.py +75 -23
  19. monai/networks/blocks/spatialattention.py +16 -1
  20. monai/networks/blocks/transformerblock.py +17 -2
  21. monai/networks/layers/filtering.py +6 -2
  22. monai/networks/nets/__init__.py +2 -1
  23. monai/networks/nets/autoencoderkl.py +55 -22
  24. monai/networks/nets/cell_sam_wrapper.py +92 -0
  25. monai/networks/nets/controlnet.py +24 -22
  26. monai/networks/nets/diffusion_model_unet.py +159 -19
  27. monai/networks/nets/segresnet_ds.py +127 -1
  28. monai/networks/nets/spade_autoencoderkl.py +22 -0
  29. monai/networks/nets/spade_diffusion_model_unet.py +39 -2
  30. monai/networks/nets/transformer.py +17 -17
  31. monai/networks/nets/vista3d.py +946 -0
  32. monai/networks/utils.py +4 -4
  33. monai/transforms/__init__.py +13 -2
  34. monai/transforms/io/array.py +59 -3
  35. monai/transforms/io/dictionary.py +29 -2
  36. monai/transforms/spatial/functional.py +1 -1
  37. monai/transforms/transform.py +2 -2
  38. monai/transforms/utility/dictionary.py +4 -0
  39. monai/transforms/utils.py +230 -1
  40. monai/{apps/generation/maisi/utils/morphological_ops.py → transforms/utils_morphological_ops.py} +2 -0
  41. monai/transforms/utils_pytorch_numpy_unification.py +2 -2
  42. monai/utils/enums.py +1 -0
  43. monai/utils/module.py +7 -6
  44. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/METADATA +84 -81
  45. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/RECORD +49 -43
  46. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/WHEEL +1 -1
  47. monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
  48. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/LICENSE +0 -0
  49. {monai_weekly-1.4.dev2431.dist-info → monai_weekly-1.4.dev2435.dist-info}/top_level.txt +0 -0
monai/networks/utils.py CHANGED
@@ -822,7 +822,7 @@ def _onnx_trt_compile(
822
822
  output_names = [] if not output_names else output_names
823
823
 
824
824
  # set up the TensorRT builder
825
- torch_tensorrt.set_device(device)
825
+ torch.cuda.set_device(device)
826
826
  logger = trt.Logger(trt.Logger.WARNING)
827
827
  builder = trt.Builder(logger)
828
828
  network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
@@ -851,7 +851,7 @@ def _onnx_trt_compile(
851
851
  # wrap the serialized TensorRT engine back to a TorchScript module.
852
852
  trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
853
853
  f.getvalue(),
854
- device=torch.device(f"cuda:{device}"),
854
+ device=torch_tensorrt.Device(f"cuda:{device}"),
855
855
  input_binding_names=input_names,
856
856
  output_binding_names=output_names,
857
857
  )
@@ -931,7 +931,7 @@ def convert_to_trt(
931
931
  warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")
932
932
 
933
933
  device = device if device else 0
934
- target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0")
934
+ target_device = torch.device(f"cuda:{device}")
935
935
  convert_precision = torch.float32 if precision == "fp32" else torch.half
936
936
  inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
937
937
 
@@ -986,7 +986,7 @@ def convert_to_trt(
986
986
  ir_model,
987
987
  inputs=input_placeholder,
988
988
  enabled_precisions=convert_precision,
989
- device=target_device,
989
+ device=torch_tensorrt.Device(f"cuda:{device}"),
990
990
  ir="torchscript",
991
991
  **kwargs,
992
992
  )
@@ -238,8 +238,18 @@ from .intensity.dictionary import (
238
238
  )
239
239
  from .inverse import InvertibleTransform, TraceableTransform
240
240
  from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
241
- from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
242
- from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
241
+ from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
242
+ from .io.dictionary import (
243
+ LoadImaged,
244
+ LoadImageD,
245
+ LoadImageDict,
246
+ SaveImaged,
247
+ SaveImageD,
248
+ SaveImageDict,
249
+ WriteFileMappingd,
250
+ WriteFileMappingD,
251
+ WriteFileMappingDict,
252
+ )
243
253
  from .lazy.array import ApplyPending
244
254
  from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
245
255
  from .lazy.functional import apply_pending
@@ -688,6 +698,7 @@ from .utils import (
688
698
  weighted_patch_samples,
689
699
  zero_margins,
690
700
  )
701
+ from .utils_morphological_ops import dilate, erode
691
702
  from .utils_pytorch_numpy_unification import (
692
703
  allclose,
693
704
  any_np_pt,
@@ -15,6 +15,7 @@ A collection of "vanilla" transforms for IO functions.
15
15
  from __future__ import annotations
16
16
 
17
17
  import inspect
18
+ import json
18
19
  import logging
19
20
  import sys
20
21
  import traceback
@@ -45,11 +46,19 @@ from monai.transforms.transform import Transform
45
46
  from monai.transforms.utility.array import EnsureChannelFirst
46
47
  from monai.utils import GridSamplePadMode
47
48
  from monai.utils import ImageMetaKey as Key
48
- from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
49
+ from monai.utils import (
50
+ MetaKeys,
51
+ OptionalImportError,
52
+ convert_to_dst_type,
53
+ ensure_tuple,
54
+ look_up_option,
55
+ optional_import,
56
+ )
49
57
 
50
58
  nib, _ = optional_import("nibabel")
51
59
  Image, _ = optional_import("PIL.Image")
52
60
  nrrd, _ = optional_import("nrrd")
61
+ FileLock, has_filelock = optional_import("filelock", name="FileLock")
53
62
 
54
63
  __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]
55
64
 
@@ -86,7 +95,7 @@ def switch_endianness(data, new="<"):
86
95
  if new not in ("<", ">"):
87
96
  raise NotImplementedError(f"Not implemented option new={new}.")
88
97
  if current_ != new:
89
- data = data.byteswap().newbyteorder(new)
98
+ data = data.byteswap().view(data.dtype.newbyteorder(new))
90
99
  elif isinstance(data, tuple):
91
100
  data = tuple(switch_endianness(x, new) for x in data)
92
101
  elif isinstance(data, list):
@@ -505,7 +514,7 @@ class SaveImage(Transform):
505
514
  else:
506
515
  self._data_index += 1
507
516
  if self.savepath_in_metadict and meta_data is not None:
508
- meta_data["saved_to"] = filename
517
+ meta_data[MetaKeys.SAVED_TO] = filename
509
518
  return img
510
519
  msg = "\n".join([f"{e}" for e in err])
511
520
  raise RuntimeError(
@@ -514,3 +523,50 @@ class SaveImage(Transform):
514
523
  " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
515
524
  f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
516
525
  )
526
+
527
+
528
+ class WriteFileMapping(Transform):
529
+ """
530
+ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
531
+ This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
532
+
533
+ Args:
534
+ mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
535
+ """
536
+
537
+ def __init__(self, mapping_file_path: Path | str = "mapping.json"):
538
+ self.mapping_file_path = Path(mapping_file_path)
539
+
540
+ def __call__(self, img: NdarrayOrTensor):
541
+ """
542
+ Args:
543
+ img: The input image with metadata.
544
+ """
545
+ if isinstance(img, MetaTensor):
546
+ meta_data = img.meta
547
+
548
+ if MetaKeys.SAVED_TO not in meta_data:
549
+ raise KeyError(
550
+ "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
551
+ )
552
+
553
+ input_path = meta_data[Key.FILENAME_OR_OBJ]
554
+ output_path = meta_data[MetaKeys.SAVED_TO]
555
+ log_data = {"input": input_path, "output": output_path}
556
+
557
+ if has_filelock:
558
+ with FileLock(str(self.mapping_file_path) + ".lock"):
559
+ self._write_to_file(log_data)
560
+ else:
561
+ self._write_to_file(log_data)
562
+ return img
563
+
564
+ def _write_to_file(self, log_data):
565
+ try:
566
+ with self.mapping_file_path.open("r") as f:
567
+ existing_log_data = json.load(f)
568
+ except (FileNotFoundError, json.JSONDecodeError):
569
+ existing_log_data = []
570
+ existing_log_data.append(log_data)
571
+ with self.mapping_file_path.open("w") as f:
572
+ json.dump(existing_log_data, f, indent=4)
@@ -17,16 +17,17 @@ Class names are ended with 'd' to denote dictionary-based transforms.
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ from collections.abc import Hashable, Mapping
20
21
  from pathlib import Path
21
22
  from typing import Callable
22
23
 
23
24
  import numpy as np
24
25
 
25
26
  import monai
26
- from monai.config import DtypeLike, KeysCollection
27
+ from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
27
28
  from monai.data import image_writer
28
29
  from monai.data.image_reader import ImageReader
29
- from monai.transforms.io.array import LoadImage, SaveImage
30
+ from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
30
31
  from monai.transforms.transform import MapTransform, Transform
31
32
  from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
32
33
  from monai.utils.enums import PostFix
@@ -320,5 +321,31 @@ class SaveImaged(MapTransform):
320
321
  return d
321
322
 
322
323
 
324
+ class WriteFileMappingd(MapTransform):
325
+ """
326
+ Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
327
+
328
+ Args:
329
+ keys: keys of the corresponding items to be transformed.
330
+ See also: :py:class:`monai.transforms.compose.MapTransform`
331
+ mapping_file_path: Path to the JSON file where the mappings will be saved.
332
+ Defaults to "mapping.json".
333
+ allow_missing_keys: don't raise exception if key is missing.
334
+ """
335
+
336
+ def __init__(
337
+ self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
338
+ ) -> None:
339
+ super().__init__(keys, allow_missing_keys)
340
+ self.mapping = WriteFileMapping(mapping_file_path)
341
+
342
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
343
+ d = dict(data)
344
+ for key in self.key_iterator(d):
345
+ d[key] = self.mapping(d[key])
346
+ return d
347
+
348
+
323
349
  LoadImageD = LoadImageDict = LoadImaged
324
350
  SaveImageD = SaveImageDict = SaveImaged
351
+ WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
@@ -373,7 +373,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l
373
373
  if output_shape is None:
374
374
  corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1))
375
375
  corners = transform[:-1, :-1] @ corners # type: ignore
376
- output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int)
376
+ output_shape = np.asarray(np.ptp(corners, axis=1) + 0.5, dtype=int)
377
377
  else:
378
378
  output_shape = np.asarray(output_shape, dtype=int)
379
379
  shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist())
@@ -203,8 +203,8 @@ class Randomizable(ThreadUnsafe, RandomizableTrait):
203
203
 
204
204
  """
205
205
  if seed is not None:
206
- _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed
207
- _seed = _seed % MAX_SEED
206
+ _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed)
207
+ _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
208
208
  self.R = np.random.RandomState(_seed)
209
209
  return self
210
210
 
@@ -1714,6 +1714,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
1714
1714
  Probability the transform is applied to the data
1715
1715
  allow_missing_keys:
1716
1716
  Don't raise exception if key is missing.
1717
+
1718
+ Note:
1719
+ - This transform does not scale output image values automatically to match the range of the input.
1720
+ The output should be scaled by later transforms to match the input if this is desired.
1717
1721
  """
1718
1722
 
1719
1723
  backend = ImageFilter.backend
monai/transforms/utils.py CHANGED
@@ -22,6 +22,7 @@ from typing import Any
22
22
 
23
23
  import numpy as np
24
24
  import torch
25
+ from torch import Tensor
25
26
 
26
27
  import monai
27
28
  from monai.config import DtypeLike, IndexSelection
@@ -30,6 +31,7 @@ from monai.networks.layers import GaussianFilter
30
31
  from monai.networks.utils import meshgrid_ij
31
32
  from monai.transforms.compose import Compose
32
33
  from monai.transforms.transform import MapTransform, Transform, apply_transform
34
+ from monai.transforms.utils_morphological_ops import erode
33
35
  from monai.transforms.utils_pytorch_numpy_unification import (
34
36
  any_np_pt,
35
37
  ascontiguousarray,
@@ -65,6 +67,8 @@ from monai.utils import (
65
67
  min_version,
66
68
  optional_import,
67
69
  pytorch_after,
70
+ unsqueeze_left,
71
+ unsqueeze_right,
68
72
  )
69
73
  from monai.utils.enums import TransformBackends
70
74
  from monai.utils.type_conversion import (
@@ -103,6 +107,9 @@ __all__ = [
103
107
  "generate_spatial_bounding_box",
104
108
  "get_extreme_points",
105
109
  "get_largest_connected_component_mask",
110
+ "keep_merge_components_with_points",
111
+ "keep_components_with_positive_points",
112
+ "convert_points_to_disc",
106
113
  "remove_small_objects",
107
114
  "img_bounds",
108
115
  "in_bounds",
@@ -1172,6 +1179,227 @@ def get_largest_connected_component_mask(
1172
1179
  return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
1173
1180
 
1174
1181
 
1182
+ def keep_merge_components_with_points(
1183
+ img_pos: NdarrayTensor,
1184
+ img_neg: NdarrayTensor,
1185
+ point_coords: NdarrayTensor,
1186
+ point_labels: NdarrayTensor,
1187
+ pos_val: Sequence[int] = (1, 3),
1188
+ neg_val: Sequence[int] = (0, 2),
1189
+ margins: int = 3,
1190
+ ) -> NdarrayTensor:
1191
+ """
1192
+ Keep connected regions of img_pos and img_neg that include the positive points and
1193
+ negative points separately. The function is used for merging automatic results with interactive
1194
+ results in VISTA3D.
1195
+
1196
+ Args:
1197
+ img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image.
1198
+ img_neg: same format as img_pos but corresponds to negative points.
1199
+ pos_val: positive point label values.
1200
+ neg_val: negative point label values.
1201
+ point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.
1202
+ point_labels: the label of each point, shape [B, N].
1203
+ margins: include points outside of the region but within the margin.
1204
+ """
1205
+
1206
+ cucim_skimage, has_cucim = optional_import("cucim.skimage")
1207
+
1208
+ use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu")
1209
+ if use_cp:
1210
+ img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore
1211
+ img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore
1212
+ label = cucim_skimage.measure.label
1213
+ lib = cp
1214
+ else:
1215
+ if not has_measure:
1216
+ raise RuntimeError("skimage.measure required.")
1217
+ img_pos_, *_ = convert_data_type(img_pos, np.ndarray)
1218
+ img_neg_, *_ = convert_data_type(img_neg, np.ndarray)
1219
+ # for skimage.measure.label, the input must be bool type
1220
+ if img_pos_.dtype != bool or img_neg_.dtype != bool:
1221
+ raise ValueError("img_pos and img_neg must be bool type.")
1222
+ label = measure.label
1223
+ lib = np
1224
+
1225
+ features_pos, _ = label(img_pos_, connectivity=3, return_num=True)
1226
+ features_neg, _ = label(img_neg_, connectivity=3, return_num=True)
1227
+
1228
+ outs = np.zeros_like(img_pos_)
1229
+ for bs in range(point_coords.shape[0]):
1230
+ for i, p in enumerate(point_coords[bs]):
1231
+ if point_labels[bs, i] in pos_val:
1232
+ features = features_pos
1233
+ elif point_labels[bs, i] in neg_val:
1234
+ features = features_neg
1235
+ else:
1236
+ # if -1 padding point, skip
1237
+ continue
1238
+ for margin in range(margins):
1239
+ if isinstance(p, np.ndarray):
1240
+ x, y, z = np.round(p).astype(int).tolist()
1241
+ else:
1242
+ x, y, z = p.float().round().int().tolist()
1243
+ l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3])
1244
+ t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2])
1245
+ f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1])
1246
+ if (features[bs, 0, l:r, t:d, f:b] > 0).any():
1247
+ index = features[bs, 0, l:r, t:d, f:b].max()
1248
+ outs[[bs]] += lib.isin(features[[bs]], index)
1249
+ break
1250
+ outs[outs > 1] = 1
1251
+ return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]
1252
+
1253
+
1254
+ def keep_components_with_positive_points(
1255
+ img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor
1256
+ ) -> torch.Tensor:
1257
+ """
1258
+ Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove
1259
+ regions without positive points.
1260
+ Args:
1261
+ img: [1, B, H, W, D]. Output prediction from VISTA3D. Value is before sigmoid and contain NaN value.
1262
+ point_coords: [B, N, 3]. Point click coordinates
1263
+ point_labels: [B, N]. Point click labels.
1264
+ """
1265
+ if not has_measure:
1266
+ raise RuntimeError("skimage.measure required.")
1267
+ outs = torch.zeros_like(img)
1268
+ for c in range(len(point_coords)):
1269
+ if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()):
1270
+ # skip if no positive points.
1271
+ continue
1272
+ coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist()
1273
+ not_nan_mask = ~torch.isnan(img[0, c])
1274
+ img_ = torch.nan_to_num(img[0, c] > 0, 0)
1275
+ img_, *_ = convert_data_type(img_, np.ndarray) # type: ignore
1276
+ label = measure.label
1277
+ features = label(img_, connectivity=3)
1278
+ pos_mask = torch.from_numpy(img_).to(img.device) > 0
1279
+ # if num features less than max desired, nothing to do.
1280
+ features = torch.from_numpy(features).to(img.device)
1281
+ # generate a map with all pos points
1282
+ idx = []
1283
+ for p in coords:
1284
+ idx.append(features[round(p[0]), round(p[1]), round(p[2])].item())
1285
+ idx = list(set(idx))
1286
+ for i in idx:
1287
+ if i == 0:
1288
+ continue
1289
+ outs[0, c] += features == i
1290
+ outs = outs > 0
1291
+ # find negative mean value
1292
+ fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean()
1293
+ img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in
1294
+ return img
1295
+
1296
+
1297
+ def convert_points_to_disc(
1298
+ image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
1299
+ ):
1300
+ """
1301
+ Convert a 3D point coordinates into image mask. The returned mask has the same spatial
1302
+ size as `image_size` while the batch dimension is the same as 'point' batch dimension.
1303
+ The point is converted to a mask ball with radius defined by `radius`. The output
1304
+ contains two channels each for negative (first channel) and positive points.
1305
+
1306
+ Args:
1307
+ image_size: The output size of the converted mask. It should be a 3D tuple.
1308
+ point: [B, N, 3], 3D point coordinates.
1309
+ point_label: [B, N], 0 or 2 means negative points, 1 or 3 means postive points.
1310
+ radius: disc ball radius size.
1311
+ disc: If true, use regular disc, other use gaussian.
1312
+ """
1313
+ masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device)
1314
+ _array = [
1315
+ torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
1316
+ ]
1317
+ coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2])
1318
+ # [1, 3, h, w, d] -> [b, 2, 3, h, w, d]
1319
+ coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)
1320
+ coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)
1321
+ for b, n in np.ndindex(*point.shape[:2]):
1322
+ point_bn = unsqueeze_right(point[b, n], 4)
1323
+ if point_label[b, n] > -1:
1324
+ channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1
1325
+ pow_diff = torch.pow(coords[b, channel] - point_bn, 2)
1326
+ if disc:
1327
+ masks[b, channel] += pow_diff.sum(0) < radius**2
1328
+ else:
1329
+ masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2))
1330
+ return masks
1331
+
1332
+
1333
+ def sample_points_from_label(
1334
+ labels: Tensor,
1335
+ label_set: Sequence[int],
1336
+ max_ppoint: int = 1,
1337
+ max_npoint: int = 0,
1338
+ device: torch.device | str | None = "cpu",
1339
+ use_center: bool = False,
1340
+ ):
1341
+ """Sample points from labels.
1342
+
1343
+ Args:
1344
+ labels: [1, 1, H, W, D]
1345
+ label_set: local index, must match values in labels.
1346
+ max_ppoint: maximum positive point samples.
1347
+ max_npoint: maximum negative point samples.
1348
+ device: returned tensor device.
1349
+ use_center: whether to sample points from center.
1350
+
1351
+ Returns:
1352
+ point: point coordinates of [B, N, 3]. B equals to the length of label_set.
1353
+ point_label: [B, N], always 0 for negative, 1 for positive.
1354
+ """
1355
+ if not labels.shape[0] == 1:
1356
+ raise ValueError("labels must have batch size 1.")
1357
+
1358
+ if device is None:
1359
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1360
+
1361
+ labels = labels[0, 0]
1362
+ unique_labels = labels.unique().cpu().numpy().tolist()
1363
+ _point = []
1364
+ _point_label = []
1365
+ for id in label_set:
1366
+ if id in unique_labels:
1367
+ plabels = labels == int(id)
1368
+ nlabels = ~plabels
1369
+ _plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0])
1370
+ plabelpoints = torch.nonzero(_plabels).to(device)
1371
+ if len(plabelpoints) == 0:
1372
+ plabelpoints = torch.nonzero(plabels).to(device)
1373
+ nlabelpoints = torch.nonzero(nlabels).to(device)
1374
+ num_p = min(len(plabelpoints), max_ppoint)
1375
+ num_n = min(len(nlabelpoints), max_npoint)
1376
+ pad = max_ppoint + max_npoint - num_p - num_n
1377
+ if use_center:
1378
+ pmean = plabelpoints.float().mean(0)
1379
+ pdis = ((plabelpoints - pmean) ** 2).sum(-1)
1380
+ _, sorted_indices_tensor = torch.sort(pdis)
1381
+ sorted_indices = sorted_indices_tensor.cpu().tolist()
1382
+ else:
1383
+ sorted_indices = list(range(len(plabelpoints)))
1384
+ random.shuffle(sorted_indices)
1385
+ _point.append(
1386
+ torch.stack(
1387
+ [plabelpoints[sorted_indices[i]] for i in range(num_p)]
1388
+ + random.choices(nlabelpoints, k=num_n)
1389
+ + [torch.tensor([0, 0, 0], device=device)] * pad
1390
+ )
1391
+ )
1392
+ _point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device))
1393
+ else:
1394
+ # pad the background labels
1395
+ _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device))
1396
+ _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1)
1397
+ point = torch.stack(_point)
1398
+ point_label = torch.stack(_point_label)
1399
+
1400
+ return point, point_label
1401
+
1402
+
1175
1403
  def remove_small_objects(
1176
1404
  img: NdarrayTensor,
1177
1405
  min_size: int = 64,
@@ -2284,6 +2512,7 @@ def distance_transform_edt(
2284
2512
  block_params=block_params,
2285
2513
  float64_distances=float64_distances,
2286
2514
  )
2515
+ torch.cuda.synchronize()
2287
2516
  else:
2288
2517
  if not has_ndimage:
2289
2518
  raise RuntimeError("scipy.ndimage required if cupy is not available")
@@ -2317,7 +2546,7 @@ def distance_transform_edt(
2317
2546
 
2318
2547
  r_vals = []
2319
2548
  if return_distances and distances_original is None:
2320
- r_vals.append(distances)
2549
+ r_vals.append(distances_ if use_cp else distances)
2321
2550
  if return_indices and indices_original is None:
2322
2551
  r_vals.append(indices)
2323
2552
  if not r_vals:
@@ -20,6 +20,8 @@ from torch import Tensor
20
20
  from monai.config import NdarrayOrTensor
21
21
  from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep
22
22
 
23
+ __all__ = ["erode", "dilate"]
24
+
23
25
 
24
26
  def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
25
27
  """
@@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
480
480
  else:
481
481
  ret = torch.max(x, int(dim), **kwargs) # type: ignore
482
482
 
483
- return ret
483
+ return ret[0] if isinstance(ret, tuple) else ret
484
484
 
485
485
 
486
486
  def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
@@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
546
546
  else:
547
547
  ret = torch.min(x, int(dim), **kwargs) # type: ignore
548
548
 
549
- return ret
549
+ return ret[0] if isinstance(ret, tuple) else ret
550
550
 
551
551
 
552
552
  def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor:
monai/utils/enums.py CHANGED
@@ -543,6 +543,7 @@ class MetaKeys(StrEnum):
543
543
  SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
544
544
  SPACE = "space" # possible values of space type are defined in `SpaceKeys`
545
545
  ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan")
546
+ SAVED_TO = "saved_to"
546
547
 
547
548
 
548
549
  class ColorOrder(StrEnum):
monai/utils/module.py CHANGED
@@ -13,7 +13,6 @@ from __future__ import annotations
13
13
 
14
14
  import enum
15
15
  import functools
16
- import importlib.util
17
16
  import os
18
17
  import pdb
19
18
  import re
@@ -209,10 +208,11 @@ def load_submodules(
209
208
  ):
210
209
  if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None:
211
210
  try:
211
+ mod = import_module(name)
212
212
  mod_spec = importer.find_spec(name) # type: ignore
213
213
  if mod_spec and mod_spec.loader:
214
- mod = importlib.util.module_from_spec(mod_spec)
215
- mod_spec.loader.exec_module(mod)
214
+ loader = mod_spec.loader
215
+ loader.exec_module(mod)
216
216
  submodules.append(mod)
217
217
  except OptionalImportError:
218
218
  pass # could not import the optional deps., they are ignored
@@ -564,7 +564,7 @@ def version_leq(lhs: str, rhs: str) -> bool:
564
564
  """
565
565
 
566
566
  lhs, rhs = str(lhs), str(rhs)
567
- pkging, has_ver = optional_import("pkg_resources", name="packaging")
567
+ pkging, has_ver = optional_import("packaging.Version")
568
568
  if has_ver:
569
569
  try:
570
570
  return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
@@ -591,7 +591,8 @@ def version_geq(lhs: str, rhs: str) -> bool:
591
591
 
592
592
  """
593
593
  lhs, rhs = str(lhs), str(rhs)
594
- pkging, has_ver = optional_import("pkg_resources", name="packaging")
594
+ pkging, has_ver = optional_import("packaging.Version")
595
+
595
596
  if has_ver:
596
597
  try:
597
598
  return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs))
@@ -629,7 +630,7 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
629
630
  if current_ver_string is None:
630
631
  _env_var = os.environ.get("PYTORCH_VER", "")
631
632
  current_ver_string = _env_var if _env_var else torch.__version__
632
- ver, has_ver = optional_import("pkg_resources", name="parse_version")
633
+ ver, has_ver = optional_import("packaging.version", name="parse")
633
634
  if has_ver:
634
635
  return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore
635
636
  parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3)