monai-weekly 1.4.dev2434__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. monai/__init__.py +44 -2
  2. monai/_version.py +3 -3
  3. monai/apps/vista3d/inferer.py +177 -0
  4. monai/apps/vista3d/sampler.py +179 -0
  5. monai/apps/vista3d/transforms.py +224 -0
  6. monai/bundle/config_parser.py +5 -3
  7. monai/bundle/scripts.py +2 -2
  8. monai/bundle/utils.py +35 -1
  9. monai/handlers/__init__.py +1 -0
  10. monai/handlers/trt_handler.py +61 -0
  11. monai/inferers/utils.py +1 -0
  12. monai/metrics/generalized_dice.py +77 -48
  13. monai/networks/__init__.py +2 -0
  14. monai/networks/layers/filtering.py +6 -2
  15. monai/networks/nets/swin_unetr.py +4 -4
  16. monai/networks/nets/vista3d.py +53 -11
  17. monai/networks/trt_compiler.py +569 -0
  18. monai/networks/utils.py +225 -41
  19. monai/transforms/__init__.py +24 -2
  20. monai/transforms/io/array.py +58 -2
  21. monai/transforms/io/dictionary.py +29 -2
  22. monai/transforms/spatial/array.py +44 -0
  23. monai/transforms/spatial/dictionary.py +61 -0
  24. monai/transforms/spatial/functional.py +70 -1
  25. monai/transforms/utility/array.py +153 -4
  26. monai/transforms/utility/dictionary.py +105 -3
  27. monai/transforms/utils.py +83 -10
  28. monai/utils/__init__.py +1 -0
  29. monai/utils/enums.py +1 -0
  30. monai/utils/type_conversion.py +8 -0
  31. {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/METADATA +4 -1
  32. {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/RECORD +36 -31
  33. {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/WHEEL +1 -1
  34. /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
  35. {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/LICENSE +0 -0
  36. {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/top_level.txt +0 -0
monai/networks/utils.py CHANGED
@@ -36,6 +36,8 @@ from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
  onnx, _ = optional_import("onnx")
  onnxreference, _ = optional_import("onnx.reference")
  onnxruntime, _ = optional_import("onnxruntime")
+ polygraphy, polygraphy_imported = optional_import("polygraphy")
+ torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")

  __all__ = [
      "one_hot",
@@ -61,6 +63,7 @@ __all__ = [
      "look_up_named_module",
      "set_named_module",
      "has_nvfuser_instance_norm",
+     "get_profile_shapes",
  ]

  logger = get_logger(module_name=__name__)
@@ -68,6 +71,26 @@ logger = get_logger(module_name=__name__)
  _has_nvfuser = None


+ def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):
+     """
+     Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
+     """
+
+     def scale_batch_size(input_shape: Sequence[int], scale_num: int):
+         scale_shape = [*input_shape]
+         scale_shape[0] = scale_num
+         return scale_shape
+
+     # Use the dynamic batchsize range to generate the min, opt and max model input shape
+     if dynamic_batchsize:
+         min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
+         opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
+         max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
+     else:
+         min_input_shape = opt_input_shape = max_input_shape = input_shape
+     return min_input_shape, opt_input_shape, max_input_shape
+
+
  def has_nvfuser_instance_norm():
      """whether the current environment has InstanceNorm3dNVFuser
      https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
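For context, a minimal usage sketch of the new helper (the shapes below are illustrative, not taken from the diff):

    from monai.networks.utils import get_profile_shapes

    # dynamic_batchsize is a [min, opt, max] batch range; only dim 0 is varied
    min_s, opt_s, max_s = get_profile_shapes([1, 1, 96, 96, 96], [1, 4, 8])
    # min_s == [1, 1, 96, 96, 96], opt_s == [4, 1, 96, 96, 96], max_s == [8, 1, 96, 96, 96]

Note that the helper assigns the batch dimension (scale_shape[0] = scale_num), whereas the inline code it replaces in convert_to_trt multiplied it (scale_shape[0] *= scale_num); for a sample shape with batch size 1 the two behaviors coincide.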
@@ -606,6 +629,9 @@ def convert_to_onnx(
      rtol: float = 1e-4,
      atol: float = 0.0,
      use_trace: bool = True,
+     do_constant_folding: bool = True,
+     constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
+     dynamo=False,
      **kwargs,
  ):
      """
@@ -632,7 +658,10 @@ convert_to_onnx(
          rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
          atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
          use_trace: whether to use `torch.jit.trace` to export the torchscript model.
-         kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
+         do_constant_folding: passed to onnx.export(). If True, an extra polygraphy folding pass is run.
+         constant_size_threshold: passed to the polygraphy constant folding pass, default = 16 GiB.
+         kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export();
+             else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
              https://pytorch.org/docs/master/generated/torch.jit.script.html.

      """
@@ -642,6 +671,7 @@ def convert_to_onnx(
      if use_trace:
          # let torch.onnx.export to trace the model.
          mode_to_export = model
+         torch_versioned_kwargs = kwargs
      else:
          if not pytorch_after(1, 10):
              if "example_outputs" not in kwargs:
@@ -654,32 +684,37 @@
              del kwargs["example_outputs"]
          mode_to_export = torch.jit.script(model, **kwargs)

+     if torch.is_tensor(inputs) or isinstance(inputs, dict):
+         onnx_inputs = (inputs,)
+     else:
+         onnx_inputs = tuple(inputs)
+
      if filename is None:
          f = io.BytesIO()
-         torch.onnx.export(
-             mode_to_export,
-             tuple(inputs),
-             f=f,
-             input_names=input_names,
-             output_names=output_names,
-             dynamic_axes=dynamic_axes,
-             opset_version=opset_version,
-             **torch_versioned_kwargs,
-         )
+     else:
+         f = filename
+
+     torch.onnx.export(
+         mode_to_export,
+         onnx_inputs,
+         f=f,
+         input_names=input_names,
+         output_names=output_names,
+         dynamic_axes=dynamic_axes,
+         opset_version=opset_version,
+         do_constant_folding=do_constant_folding,
+         **torch_versioned_kwargs,
+     )
+     if filename is None:
          onnx_model = onnx.load_model_from_string(f.getvalue())
      else:
-         torch.onnx.export(
-             mode_to_export,
-             tuple(inputs),
-             f=filename,
-             input_names=input_names,
-             output_names=output_names,
-             dynamic_axes=dynamic_axes,
-             opset_version=opset_version,
-             **torch_versioned_kwargs,
-         )
          onnx_model = onnx.load(filename)

+     if do_constant_folding and polygraphy_imported:
+         from polygraphy.backend.onnx.loader import fold_constants
+
+         fold_constants(onnx_model, size_threshold=constant_size_threshold)
+
      if verify:
          if device is None:
              device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
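The extra folding pass uses Polygraphy's ONNX loader; a rough standalone equivalent (assuming polygraphy and onnx are installed, with "model.onnx" as a placeholder path):

    import onnx
    from polygraphy.backend.onnx.loader import fold_constants

    model_proto = onnx.load("model.onnx")
    folded = fold_constants(model_proto, size_threshold=16 * 1024 * 1024 * 1024)
    onnx.save(folded, "model_folded.onnx")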
@@ -814,7 +849,6 @@ def _onnx_trt_compile(

      """
      trt, _ = optional_import("tensorrt", "8.5.3")
-     torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")

      input_shapes = (min_shape, opt_shape, max_shape)
      # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
@@ -851,7 +885,7 @@
      # wrap the serialized TensorRT engine back to a TorchScript module.
      trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
          f.getvalue(),
-         device=torch.device(f"cuda:{device}"),
+         device=torch_tensorrt.Device(f"cuda:{device}"),
          input_binding_names=input_names,
          output_binding_names=output_names,
      )
@@ -916,8 +950,6 @@ def convert_to_trt(
          to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
      """

-     torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0")
-
      if not torch.cuda.is_available():
          raise Exception("Cannot find any GPU devices.")

@@ -935,23 +967,9 @@
      convert_precision = torch.float32 if precision == "fp32" else torch.half
      inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]

-     def scale_batch_size(input_shape: Sequence[int], scale_num: int):
-         scale_shape = [*input_shape]
-         scale_shape[0] *= scale_num
-         return scale_shape
-
-     # Use the dynamic batchsize range to generate the min, opt and max model input shape
-     if dynamic_batchsize:
-         min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
-         opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
-         max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
-     else:
-         min_input_shape = opt_input_shape = max_input_shape = input_shape
-
      # convert the torch model to a TorchScript model on target device
      model = model.eval().to(target_device)
-     ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
-     ir_model.eval()
+     min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)

      if use_onnx:
          # set the batch dim as dynamic
@@ -960,7 +978,6 @@
          ir_model = convert_to_onnx(
              model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes
          )
-
          # convert the model through the ONNX-TensorRT way
          trt_model = _onnx_trt_compile(
              ir_model,
@@ -973,6 +990,8 @@
              output_names=onnx_output_names,
          )
      else:
+         ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
+         ir_model.eval()
          # convert the model through the Torch-TensorRT way
          ir_model.to(target_device)
          with torch.no_grad():
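Taken together, these hunks defer TorchScript conversion to the Torch-TensorRT branch only. A usage sketch (requires a CUDA GPU with TensorRT and torch_tensorrt installed; the network is a placeholder):

    import torch
    from monai.networks.utils import convert_to_trt

    model = torch.nn.Sequential(torch.nn.Conv2d(1, 2, 3), torch.nn.ReLU())
    trt_model = convert_to_trt(
        model,
        precision="fp16",
        input_shape=[1, 1, 96, 96],
        dynamic_batchsize=[1, 4, 8],  # becomes the min/opt/max TRT profile via get_profile_shapes
        use_onnx=True,                # ONNX->TensorRT path; TorchScript is now built only in the else branch
    )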
@@ -1189,3 +1208,168 @@ class CastTempType(nn.Module):
          if dtype == self.initial_type:
              x = x.to(self.initial_type)
          return x
+
+
+ def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
+     """
+     Utility function to cast a single tensor from from_dtype to to_dtype.
+     """
+     return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
+
+
+ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
+     """
+     Utility function to cast all tensors in a tuple from from_dtype to to_dtype.
+     """
+     if isinstance(x, torch.Tensor):
+         return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
+     else:
+         if isinstance(x, dict):
+             new_dict = {}
+             for k in x.keys():
+                 new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
+             return new_dict
+         elif isinstance(x, tuple):
+             return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+
+
+ class CastToFloat(torch.nn.Module):
+     """
+     Class used to add autocast protection for ONNX export
+     for forward methods with a single return value.
+     """
+
+     def __init__(self, mod):
+         super().__init__()
+         self.mod = mod
+
+     def forward(self, x):
+         dtype = x.dtype
+         with torch.amp.autocast("cuda", enabled=False):
+             ret = self.mod.forward(x.to(torch.float32)).to(dtype)
+         return ret
+
+
+ class CastToFloatAll(torch.nn.Module):
+     """
+     Class used to add autocast protection for ONNX export
+     for forward methods with multiple return values.
+     """
+
+     def __init__(self, mod):
+         super().__init__()
+         self.mod = mod
+
+     def forward(self, *args):
+         from_dtype = args[0].dtype
+         with torch.amp.autocast("cuda", enabled=False):
+             ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+         return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+
+
+ def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+     """
+     Generic function generator to replace base_t module with dest_t wrapper.
+     Args:
+         base_t : module type to replace
+         dest_t : destination module type
+     Returns:
+         swap function to replace base_t module with dest_t
+     """
+
+     def expansion_fn(mod: nn.Module) -> nn.Module | None:
+         out = dest_t(mod)
+         return out
+
+     return expansion_fn
+
+
+ def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+     """
+     Generic function generator to replace base_t module with dest_t.
+     base_t and dest_t should have the same attributes. No weights are copied.
+     Args:
+         base_t : module type to replace
+         dest_t : destination module type
+     Returns:
+         swap function to replace base_t module with dest_t
+     """
+
+     def expansion_fn(mod: nn.Module) -> nn.Module | None:
+         if not isinstance(mod, base_t):
+             return None
+         args = [getattr(mod, name, None) for name in mod.__constants__]
+         out = dest_t(*args)
+         return out
+
+     return expansion_fn
+
+
+ def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module:
+     """
+     This function swaps nested modules as specified by "dot paths" in mapping with a desired replacement. This allows
+     for swapping nested modules through arbitrary levels of children.
+
+     NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+
+     """
+     for path, new_mod in mapping.items():
+         expanded_path = path.split(".")
+         parent_mod = model
+         for sub_path in expanded_path[:-1]:
+             submod = parent_mod._modules[sub_path]
+             if submod is None:
+                 break
+             else:
+                 parent_mod = submod
+         parent_mod._modules[expanded_path[-1]] = new_mod
+
+     return model
+
+
+ def replace_modules_by_type(
+     model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]
+ ) -> nn.Module:
+     """
+     Top-level function to replace modules in a model, specified by class name, with a desired replacement.
+     NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+     Args:
+         model : top level module
+         expansions : replacement dictionary: module class name -> replacement function generator
+     Returns:
+         model, possibly modified in-place
+     """
+     mapping: dict[str, nn.Module] = {}
+     for name, m in model.named_modules():
+         m_type = type(m).__name__
+         if m_type in expansions:
+             # print (f"Found {m_type} in expansions ...")
+             swapped = expansions[m_type](m)
+             if swapped:
+                 mapping[name] = swapped
+
+     print(f"Swapped {len(mapping)} modules")
+     _swap_modules(model, mapping)
+     return model
+
+
+ def add_casts_around_norms(model: nn.Module) -> nn.Module:
+     """
+     Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export.
+     NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+     Args:
+         model : top level module
+     Returns:
+         model, possibly modified in-place
+     """
+     print("Adding casts around norms...")
+     cast_replacements = {
+         "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+         "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+         "BatchNorm3d": wrap_module(nn.BatchNorm3d, CastToFloat),
+         "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+         "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+         "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+     }
+     replace_modules_by_type(model, cast_replacements)
+     return model
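A short sketch of how the new helpers compose (the toy network is illustrative):

    import torch
    import torch.nn as nn
    from monai.networks.utils import add_casts_around_norms, cast_all

    net = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
    add_casts_around_norms(net)  # BatchNorm2d is now wrapped in CastToFloat before ONNX export

    # cast_all recurses through tuples and dicts of tensors
    mixed = (torch.ones(2, dtype=torch.float16), {"y": torch.ones(2, dtype=torch.float16)})
    all_fp32 = cast_all(mixed)  # every float16 tensor becomes float32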
monai/transforms/__init__.py CHANGED
@@ -238,8 +238,18 @@ from .intensity.dictionary import (
  )
  from .inverse import InvertibleTransform, TraceableTransform
  from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
- from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
- from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
+ from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
+ from .io.dictionary import (
+     LoadImaged,
+     LoadImageD,
+     LoadImageDict,
+     SaveImaged,
+     SaveImageD,
+     SaveImageDict,
+     WriteFileMappingd,
+     WriteFileMappingD,
+     WriteFileMappingDict,
+ )
  from .lazy.array import ApplyPending
  from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
  from .lazy.functional import apply_pending
@@ -386,6 +396,8 @@ from .smooth_field.dictionary import (
  from .spatial.array import (
      Affine,
      AffineGrid,
+     ConvertBoxToPoints,
+     ConvertPointsToBoxes,
      Flip,
      GridDistortion,
      GridPatch,
@@ -417,6 +429,12 @@ from .spatial.dictionary import (
      Affined,
      AffineD,
      AffineDict,
+     ConvertBoxToPointsd,
+     ConvertBoxToPointsD,
+     ConvertBoxToPointsDict,
+     ConvertPointsToBoxesd,
+     ConvertPointsToBoxesD,
+     ConvertPointsToBoxesDict,
      Flipd,
      FlipD,
      FlipDict,
@@ -493,6 +511,7 @@ from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTr
  from .utility.array import (
      AddCoordinateChannels,
      AddExtremePointsChannel,
+     ApplyTransformToPoints,
      AsChannelLast,
      CastToType,
      ClassesToIndices,
@@ -532,6 +551,9 @@ from .utility.dictionary import (
      AddExtremePointsChanneld,
      AddExtremePointsChannelD,
      AddExtremePointsChannelDict,
+     ApplyTransformToPointsd,
+     ApplyTransformToPointsD,
+     ApplyTransformToPointsDict,
      AsChannelLastd,
      AsChannelLastD,
      AsChannelLastDict,
monai/transforms/io/array.py CHANGED
@@ -15,6 +15,7 @@ A collection of "vanilla" transforms for IO functions.
  from __future__ import annotations

  import inspect
+ import json
  import logging
  import sys
  import traceback
@@ -45,11 +46,19 @@ from monai.transforms.transform import Transform
  from monai.transforms.utility.array import EnsureChannelFirst
  from monai.utils import GridSamplePadMode
  from monai.utils import ImageMetaKey as Key
- from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
+ from monai.utils import (
+     MetaKeys,
+     OptionalImportError,
+     convert_to_dst_type,
+     ensure_tuple,
+     look_up_option,
+     optional_import,
+ )

  nib, _ = optional_import("nibabel")
  Image, _ = optional_import("PIL.Image")
  nrrd, _ = optional_import("nrrd")
+ FileLock, has_filelock = optional_import("filelock", name="FileLock")

  __all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]

@@ -505,7 +514,7 @@ class SaveImage(Transform):
              else:
                  self._data_index += 1
                  if self.savepath_in_metadict and meta_data is not None:
-                     meta_data["saved_to"] = filename
+                     meta_data[MetaKeys.SAVED_TO] = filename
                  return img
          msg = "\n".join([f"{e}" for e in err])
          raise RuntimeError(
@@ -514,3 +523,50 @@ class SaveImage(Transform):
              " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
              f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
          )
+
+
+ class WriteFileMapping(Transform):
+     """
+     Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
+     This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
+
+     Args:
+         mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
+     """
+
+     def __init__(self, mapping_file_path: Path | str = "mapping.json"):
+         self.mapping_file_path = Path(mapping_file_path)
+
+     def __call__(self, img: NdarrayOrTensor):
+         """
+         Args:
+             img: The input image with metadata.
+         """
+         if isinstance(img, MetaTensor):
+             meta_data = img.meta
+
+         if MetaKeys.SAVED_TO not in meta_data:
+             raise KeyError(
+                 "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
+             )
+
+         input_path = meta_data[Key.FILENAME_OR_OBJ]
+         output_path = meta_data[MetaKeys.SAVED_TO]
+         log_data = {"input": input_path, "output": output_path}
+
+         if has_filelock:
+             with FileLock(str(self.mapping_file_path) + ".lock"):
+                 self._write_to_file(log_data)
+         else:
+             self._write_to_file(log_data)
+         return img
+
+     def _write_to_file(self, log_data):
+         try:
+             with self.mapping_file_path.open("r") as f:
+                 existing_log_data = json.load(f)
+         except (FileNotFoundError, json.JSONDecodeError):
+             existing_log_data = []
+         existing_log_data.append(log_data)
+         with self.mapping_file_path.open("w") as f:
+             json.dump(existing_log_data, f, indent=4)
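An end-to-end sketch of the new transform (file paths are placeholders; the filelock dependency is optional):

    from monai.transforms import LoadImage, SaveImage, WriteFileMapping

    img = LoadImage(image_only=True)("input.nii.gz")
    img = SaveImage(output_dir="out", savepath_in_metadict=True)(img)  # records MetaKeys.SAVED_TO
    WriteFileMapping("out/mapping.json")(img)  # appends {"input": ..., "output": ...} to the JSON log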
monai/transforms/io/dictionary.py CHANGED
@@ -17,16 +17,17 @@ Class names are ended with 'd' to denote dictionary-based transforms.

  from __future__ import annotations

+ from collections.abc import Hashable, Mapping
  from pathlib import Path
  from typing import Callable

  import numpy as np

  import monai
- from monai.config import DtypeLike, KeysCollection
+ from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
  from monai.data import image_writer
  from monai.data.image_reader import ImageReader
- from monai.transforms.io.array import LoadImage, SaveImage
+ from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
  from monai.transforms.transform import MapTransform, Transform
  from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
  from monai.utils.enums import PostFix
@@ -320,5 +321,31 @@ class SaveImaged(MapTransform):
          return d


+ class WriteFileMappingd(MapTransform):
+     """
+     Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
+
+     Args:
+         keys: keys of the corresponding items to be transformed.
+             See also: :py:class:`monai.transforms.compose.MapTransform`
+         mapping_file_path: Path to the JSON file where the mappings will be saved.
+             Defaults to "mapping.json".
+         allow_missing_keys: don't raise exception if key is missing.
+     """
+
+     def __init__(
+         self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
+     ) -> None:
+         super().__init__(keys, allow_missing_keys)
+         self.mapping = WriteFileMapping(mapping_file_path)
+
+     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+         d = dict(data)
+         for key in self.key_iterator(d):
+             d[key] = self.mapping(d[key])
+         return d
+
+
  LoadImageD = LoadImageDict = LoadImaged
  SaveImageD = SaveImageDict = SaveImaged
+ WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
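A sketch of the dictionary version inside a Compose pipeline (paths are placeholders):

    from monai.transforms import Compose, LoadImaged, SaveImaged, WriteFileMappingd

    pipeline = Compose([
        LoadImaged(keys="image"),
        SaveImaged(keys="image", output_dir="out", savepath_in_metadict=True),
        WriteFileMappingd(keys="image", mapping_file_path="out/mapping.json"),
    ])
    result = pipeline({"image": "input.nii.gz"})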
monai/transforms/spatial/array.py CHANGED
@@ -25,6 +25,7 @@ import torch

  from monai.config import USE_COMPILED, DtypeLike
  from monai.config.type_definitions import NdarrayOrTensor
+ from monai.data.box_utils import BoxMode, StandardMode
  from monai.data.meta_obj import get_track_meta, set_track_meta
  from monai.data.meta_tensor import MetaTensor
  from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
@@ -34,6 +35,8 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCro
  from monai.transforms.inverse import InvertibleTransform
  from monai.transforms.spatial.functional import (
      affine_func,
+     convert_box_to_points,
+     convert_points_to_box,
      flip,
      orientation,
      resize,
@@ -3544,3 +3547,44 @@ class RandSimulateLowResolution(RandomizableTransform):

          else:
              return img
+
+
+ class ConvertBoxToPoints(Transform):
+     """
+     Converts an axis-aligned bounding box to its corner points, interpreting coordinates according to the box mode.
+     Bounding boxes have shape (N, C) for N boxes, where C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D.
+     The return shape is (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+     """
+
+     backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+     def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
+         """
+         Args:
+             mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.
+         """
+         super().__init__()
+         self.mode = StandardMode if mode is None else mode
+
+     def __call__(self, data: Any):
+         data = convert_to_tensor(data, track_meta=get_track_meta())
+         points = convert_box_to_points(data, mode=self.mode)
+         return convert_to_dst_type(points, data)[0]
+
+
+ class ConvertPointsToBoxes(Transform):
+     """
+     Converts points to an axis-aligned bounding box. The points represent the corners of the bounding box,
+     with shape (N, 8, 3) for the 8 corners of a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
+     """
+
+     backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+     def __init__(self) -> None:
+         super().__init__()
+
+     def __call__(self, data: Any):
+         data = convert_to_tensor(data, track_meta=get_track_meta())
+         box = convert_points_to_box(data)
+         return convert_to_dst_type(box, data)[0]
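A round-trip sketch of the two new transforms (the box coordinates are illustrative, in the default StandardMode xyzxyz layout):

    import torch
    from monai.transforms import ConvertBoxToPoints, ConvertPointsToBoxes

    boxes = torch.tensor([[10.0, 20.0, 30.0, 50.0, 60.0, 70.0]])  # one 3D box
    points = ConvertBoxToPoints()(boxes)         # shape (1, 8, 3): corners of the cuboid
    boxes_back = ConvertPointsToBoxes()(points)  # shape (1, 6): recovers the original box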