monai-weekly 1.4.dev2435__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monai/networks/utils.py CHANGED
@@ -36,6 +36,8 @@ from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
 onnx, _ = optional_import("onnx")
 onnxreference, _ = optional_import("onnx.reference")
 onnxruntime, _ = optional_import("onnxruntime")
+polygraphy, polygraphy_imported = optional_import("polygraphy")
+torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
 
 __all__ = [
     "one_hot",
@@ -61,6 +63,7 @@ __all__ = [
     "look_up_named_module",
     "set_named_module",
     "has_nvfuser_instance_norm",
+    "get_profile_shapes",
 ]
 
 logger = get_logger(module_name=__name__)
@@ -68,6 +71,26 @@ logger = get_logger(module_name=__name__)
 _has_nvfuser = None
 
 
+def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):
+    """
+    Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
+    """
+
+    def scale_batch_size(input_shape: Sequence[int], scale_num: int):
+        scale_shape = [*input_shape]
+        scale_shape[0] = scale_num
+        return scale_shape
+
+    # Use the dynamic batchsize range to generate the min, opt and max model input shapes
+    if dynamic_batchsize:
+        min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
+        opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
+        max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
+    else:
+        min_input_shape = opt_input_shape = max_input_shape = input_shape
+    return min_input_shape, opt_input_shape, max_input_shape
+
+
 def has_nvfuser_instance_norm():
     """whether the current environment has InstanceNorm3dNVFuser
     https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
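
For reference, a quick sketch of the new helper's behavior (shapes are illustrative). `dynamic_batchsize` is a `[min, opt, max]` triple, and each value replaces the batch dimension in turn:

```python
from monai.networks.utils import get_profile_shapes

min_s, opt_s, max_s = get_profile_shapes([1, 3, 224, 224], [1, 4, 8])
# -> ([1, 3, 224, 224], [4, 3, 224, 224], [8, 3, 224, 224])

# Without a dynamic range, all three profiles equal the sample shape:
assert get_profile_shapes([2, 1, 96, 96], None) == ([2, 1, 96, 96], [2, 1, 96, 96], [2, 1, 96, 96])
```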
@@ -606,6 +629,9 @@ def convert_to_onnx(
     rtol: float = 1e-4,
     atol: float = 0.0,
     use_trace: bool = True,
+    do_constant_folding: bool = True,
+    constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
+    dynamo=False,
     **kwargs,
 ):
     """
@@ -632,7 +658,10 @@ def convert_to_onnx(
         rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
         atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
         use_trace: whether to use `torch.jit.trace` to export the torchscript model.
-        kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
+        do_constant_folding: passed to onnx.export(). If True, an extra polygraphy folding pass is also run.
+        constant_size_threshold: passed to polygraphy constant folding, default = 16 GiB.
+        kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export();
+            else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
             https://pytorch.org/docs/master/generated/torch.jit.script.html.
 
     """
@@ -642,6 +671,7 @@ def convert_to_onnx(
     if use_trace:
         # let torch.onnx.export to trace the model.
         mode_to_export = model
+        torch_versioned_kwargs = kwargs
     else:
         if not pytorch_after(1, 10):
             if "example_outputs" not in kwargs:
@@ -654,32 +684,37 @@ def convert_to_onnx(
             del kwargs["example_outputs"]
         mode_to_export = torch.jit.script(model, **kwargs)
 
+    if torch.is_tensor(inputs) or isinstance(inputs, dict):
+        onnx_inputs = (inputs,)
+    else:
+        onnx_inputs = tuple(inputs)
+
     if filename is None:
         f = io.BytesIO()
-        torch.onnx.export(
-            mode_to_export,
-            tuple(inputs),
-            f=f,
-            input_names=input_names,
-            output_names=output_names,
-            dynamic_axes=dynamic_axes,
-            opset_version=opset_version,
-            **torch_versioned_kwargs,
-        )
+    else:
+        f = filename
+
+    torch.onnx.export(
+        mode_to_export,
+        onnx_inputs,
+        f=f,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic_axes,
+        opset_version=opset_version,
+        do_constant_folding=do_constant_folding,
+        **torch_versioned_kwargs,
+    )
+    if filename is None:
         onnx_model = onnx.load_model_from_string(f.getvalue())
     else:
-        torch.onnx.export(
-            mode_to_export,
-            tuple(inputs),
-            f=filename,
-            input_names=input_names,
-            output_names=output_names,
-            dynamic_axes=dynamic_axes,
-            opset_version=opset_version,
-            **torch_versioned_kwargs,
-        )
         onnx_model = onnx.load(filename)
 
+    if do_constant_folding and polygraphy_imported:
+        from polygraphy.backend.onnx.loader import fold_constants
+
+        fold_constants(onnx_model, size_threshold=constant_size_threshold)
+
     if verify:
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
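
The restructured export now normalizes `inputs` into `onnx_inputs` once, writes either to an in-memory buffer or to `filename` through a single `torch.onnx.export` call, and optionally folds constants via polygraphy. A minimal sketch of calling the updated function; the toy model and shapes are illustrative, and the keyword names follow the signature shown above:

```python
import torch
from monai.networks.utils import convert_to_onnx

net = torch.nn.Sequential(torch.nn.Conv2d(1, 4, 3), torch.nn.ReLU())
onnx_model = convert_to_onnx(
    net,
    inputs=[torch.randn(2, 1, 16, 16)],
    input_names=["x"],
    output_names=["y"],
    use_trace=True,
    do_constant_folding=True,  # also triggers the polygraphy fold_constants pass when polygraphy is installed
)
```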
@@ -814,7 +849,6 @@ def _onnx_trt_compile(
 
     """
     trt, _ = optional_import("tensorrt", "8.5.3")
-    torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
 
     input_shapes = (min_shape, opt_shape, max_shape)
     # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
@@ -916,8 +950,6 @@ def convert_to_trt(
             to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
     """
 
-    torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0")
-
     if not torch.cuda.is_available():
         raise Exception("Cannot find any GPU devices.")
 
@@ -935,23 +967,9 @@ def convert_to_trt(
     convert_precision = torch.float32 if precision == "fp32" else torch.half
     inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
 
-    def scale_batch_size(input_shape: Sequence[int], scale_num: int):
-        scale_shape = [*input_shape]
-        scale_shape[0] *= scale_num
-        return scale_shape
-
-    # Use the dynamic batchsize range to generate the min, opt and max model input shape
-    if dynamic_batchsize:
-        min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
-        opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
-        max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
-    else:
-        min_input_shape = opt_input_shape = max_input_shape = input_shape
-
     # convert the torch model to a TorchScript model on target device
     model = model.eval().to(target_device)
-    ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
-    ir_model.eval()
+    min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
 
     if use_onnx:
         # set the batch dim as dynamic
@@ -960,7 +978,6 @@ def convert_to_trt(
         ir_model = convert_to_onnx(
             model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes
         )
-
         # convert the model through the ONNX-TensorRT way
         trt_model = _onnx_trt_compile(
             ir_model,
@@ -973,6 +990,8 @@ def convert_to_trt(
             output_names=onnx_output_names,
         )
     else:
+        ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
+        ir_model.eval()
         # convert the model through the Torch-TensorRT way
         ir_model.to(target_device)
         with torch.no_grad():
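
With the profile computation factored out into `get_profile_shapes`, the TorchScript conversion now happens only on the non-ONNX branch. A hedged sketch of a call exercising the ONNX route; the argument names follow the surrounding function, the values are illustrative, and a CUDA device is required:

```python
import torch
from monai.networks.utils import convert_to_trt

net = torch.nn.Sequential(torch.nn.Conv3d(1, 4, 3), torch.nn.ReLU())
trt_model = convert_to_trt(
    net,
    precision="fp16",
    input_shape=[1, 1, 96, 96, 96],
    dynamic_batchsize=[1, 4, 8],  # min/opt/max batch sizes handed to get_profile_shapes
    use_onnx=True,                # False takes the Torch-TensorRT branch shown above
    use_trace=True,
)
```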
@@ -1189,3 +1208,168 @@ class CastTempType(nn.Module):
         if dtype == self.initial_type:
             x = x.to(self.initial_type)
         return x
+
+
+def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    """
+    Utility function to cast a single tensor from from_dtype to to_dtype
+    """
+    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
+
+
+def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    """
+    Utility function to cast all tensors in a tuple from from_dtype to to_dtype
+    """
+    if isinstance(x, torch.Tensor):
+        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
+    else:
+        if isinstance(x, dict):
+            new_dict = {}
+            for k in x.keys():
+                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
+            return new_dict
+        elif isinstance(x, tuple):
+            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+
+
+class CastToFloat(torch.nn.Module):
+    """
+    Class used to add autocast protection for ONNX export
+    for forward methods with a single return value
+    """
+
+    def __init__(self, mod):
+        super().__init__()
+        self.mod = mod
+
+    def forward(self, x):
+        dtype = x.dtype
+        with torch.amp.autocast("cuda", enabled=False):
+            ret = self.mod.forward(x.to(torch.float32)).to(dtype)
+        return ret
+
+
+class CastToFloatAll(torch.nn.Module):
+    """
+    Class used to add autocast protection for ONNX export
+    for forward methods with multiple return values
+    """
+
+    def __init__(self, mod):
+        super().__init__()
+        self.mod = mod
+
+    def forward(self, *args):
+        from_dtype = args[0].dtype
+        with torch.amp.autocast("cuda", enabled=False):
+            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+
+
+def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+    """
+    Generic function generator to replace base_t module with dest_t wrapper.
+    Args:
+        base_t : module type to replace
+        dest_t : destination module type
+    Returns:
+        swap function to replace base_t module with dest_t
+    """
+
+    def expansion_fn(mod: nn.Module) -> nn.Module | None:
+        out = dest_t(mod)
+        return out
+
+    return expansion_fn
+
+
+def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+    """
+    Generic function generator to replace base_t module with dest_t.
+    base_t and dest_t should have the same attributes. No weights are copied.
+    Args:
+        base_t : module type to replace
+        dest_t : destination module type
+    Returns:
+        swap function to replace base_t module with dest_t
+    """
+
+    def expansion_fn(mod: nn.Module) -> nn.Module | None:
+        if not isinstance(mod, base_t):
+            return None
+        args = [getattr(mod, name, None) for name in mod.__constants__]
+        out = dest_t(*args)
+        return out
+
+    return expansion_fn
+
+
+def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module:
+    """
+    This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
+    for swapping nested modules through arbitrary levels of children.
+
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+
+    """
+    for path, new_mod in mapping.items():
+        expanded_path = path.split(".")
+        parent_mod = model
+        for sub_path in expanded_path[:-1]:
+            submod = parent_mod._modules[sub_path]
+            if submod is None:
+                break
+            else:
+                parent_mod = submod
+        parent_mod._modules[expanded_path[-1]] = new_mod
+
+    return model
+
+
+def replace_modules_by_type(
+    model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]
+) -> nn.Module:
+    """
+    Top-level function to replace modules in model, specified by class name with a desired replacement.
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+        expansions : replacement dictionary: module class name -> replacement function generator
+    Returns:
+        model, possibly modified in-place
+    """
+    mapping: dict[str, nn.Module] = {}
+    for name, m in model.named_modules():
+        m_type = type(m).__name__
+        if m_type in expansions:
+            # print (f"Found {m_type} in expansions ...")
+            swapped = expansions[m_type](m)
+            if swapped:
+                mapping[name] = swapped
+
+    print(f"Swapped {len(mapping)} modules")
+    _swap_modules(model, mapping)
+    return model
+
+
+def add_casts_around_norms(model: nn.Module) -> nn.Module:
+    """
+    Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export
+    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+    Returns:
+        model, possibly modified in-place
+    """
+    print("Adding casts around norms...")
+    cast_replacements = {
+        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+        "BatchNorm3d": wrap_module(nn.BatchNorm3d, CastToFloat),
+        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+        "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+    }
+    replace_modules_by_type(model, cast_replacements)
+    return model
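
Taken together, these helpers swap modules by class name and force the listed norm layers back to float32 during export. A minimal sketch of the intended use (toy network; `add_casts_around_norms` is the new module-level entry point shown above):

```python
import torch.nn as nn
from monai.networks.utils import add_casts_around_norms

# Toy network whose InstanceNorm3d would otherwise run in reduced precision under autocast:
net = nn.Sequential(nn.Conv3d(1, 8, 3), nn.InstanceNorm3d(8), nn.ReLU())
add_casts_around_norms(net)               # swaps net[1] for CastToFloat(InstanceNorm3d(8))
assert type(net[1]).__name__ == "CastToFloat"
```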
monai/transforms/__init__.py CHANGED
@@ -396,6 +396,8 @@ from .smooth_field.dictionary import (
 from .spatial.array import (
     Affine,
     AffineGrid,
+    ConvertBoxToPoints,
+    ConvertPointsToBoxes,
     Flip,
     GridDistortion,
     GridPatch,
@@ -427,6 +429,12 @@ from .spatial.dictionary import (
     Affined,
     AffineD,
     AffineDict,
+    ConvertBoxToPointsd,
+    ConvertBoxToPointsD,
+    ConvertBoxToPointsDict,
+    ConvertPointsToBoxesd,
+    ConvertPointsToBoxesD,
+    ConvertPointsToBoxesDict,
     Flipd,
     FlipD,
     FlipDict,
@@ -503,6 +511,7 @@ from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTr
 from .utility.array import (
     AddCoordinateChannels,
     AddExtremePointsChannel,
+    ApplyTransformToPoints,
     AsChannelLast,
     CastToType,
     ClassesToIndices,
@@ -542,6 +551,9 @@ from .utility.dictionary import (
     AddExtremePointsChanneld,
     AddExtremePointsChannelD,
     AddExtremePointsChannelDict,
+    ApplyTransformToPointsd,
+    ApplyTransformToPointsD,
+    ApplyTransformToPointsDict,
     AsChannelLastd,
     AsChannelLastD,
     AsChannelLastDict,
monai/transforms/spatial/array.py CHANGED
@@ -25,6 +25,7 @@ import torch
 
 from monai.config import USE_COMPILED, DtypeLike
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta, set_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
@@ -34,6 +35,8 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCro
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.functional import (
     affine_func,
+    convert_box_to_points,
+    convert_points_to_box,
     flip,
     orientation,
     resize,
@@ -3544,3 +3547,44 @@ class RandSimulateLowResolution(RandomizableTransform):
 
         else:
             return img
+
+
+class ConvertBoxToPoints(Transform):
+    """
+    Converts an axis-aligned bounding box to points. It can automatically convert the boxes to the points based on the box mode.
+    Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D for each box.
+    Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
+        """
+        Args:
+            mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.
+        """
+        super().__init__()
+        self.mode = StandardMode if mode is None else mode
+
+    def __call__(self, data: Any):
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        points = convert_box_to_points(data, mode=self.mode)
+        return convert_to_dst_type(points, data)[0]
+
+
+class ConvertPointsToBoxes(Transform):
+    """
+    Converts points to an axis-aligned bounding box.
+    Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of a 3D cuboid or
+    (N, 4, 2) for the 4 corners of a 2D rectangle.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __call__(self, data: Any):
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        box = convert_points_to_box(data)
+        return convert_to_dst_type(box, data)[0]
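
A sketch of the round trip between the two new array transforms, assuming the default StandardMode corner convention. Since `ConvertPointsToBoxes` takes the per-axis min/max over corners, converting a box to points and back reproduces the original axis-aligned box:

```python
import torch
from monai.transforms import ConvertBoxToPoints, ConvertPointsToBoxes

boxes = torch.tensor([[1.0, 2.0, 4.0, 6.0]])   # one 2D box, [x1, y1, x2, y2]
points = ConvertBoxToPoints()(boxes)           # -> shape (1, 4, 2): the four corners
boxes_back = ConvertPointsToBoxes()(points)    # -> shape (1, 4): min/max over corners
assert torch.allclose(boxes_back, boxes)
```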
monai/transforms/spatial/dictionary.py CHANGED
@@ -26,6 +26,7 @@ import torch
 
 from monai.config import DtypeLike, KeysCollection, SequenceStr
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.networks.layers.simplelayers import GaussianFilter
@@ -33,6 +34,8 @@ from monai.transforms.croppad.array import CenterSpatialCrop
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.array import (
     Affine,
+    ConvertBoxToPoints,
+    ConvertPointsToBoxes,
     Flip,
     GridDistortion,
     GridPatch,
@@ -2585,6 +2588,7 @@ class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
         self, seed: int | None = None, state: np.random.RandomState | None = None
     ) -> RandSimulateLowResolutiond:
         super().set_random_state(seed, state)
+        self.sim_lowres_tfm.set_random_state(seed, state)
         return self
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
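
The one-line fix above propagates the seed to the wrapped array transform, which previously kept its own unseeded random state. A sketch of the behavior it restores (the constructor arguments here are illustrative):

```python
from monai.transforms import RandSimulateLowResolutiond

# With the added line, two identically seeded instances now produce identical outputs:
t1 = RandSimulateLowResolutiond(keys="img", prob=1.0).set_random_state(seed=42)
t2 = RandSimulateLowResolutiond(keys="img", prob=1.0).set_random_state(seed=42)
# t1({"img": x})["img"] matches t2({"img": x})["img"] for the same input x
```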
@@ -2611,6 +2615,61 @@ class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
         return d
 
 
+class ConvertBoxToPointsd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.
+    """
+
+    backend = ConvertBoxToPoints.backend
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        point_key="points",
+        mode: str | BoxMode | type[BoxMode] | None = StandardMode,
+        allow_missing_keys: bool = False,
+    ):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            point_key: key to store the point data.
+            mode: the mode of the input boxes. Defaults to StandardMode.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.point_key = point_key
+        self.converter = ConvertBoxToPoints(mode=mode)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.point_key] = self.converter(d[key])
+        return d
+
+
+class ConvertPointsToBoxesd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.
+    """
+
+    def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            box_key: key to store the box data.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.box_key = box_key
+        self.converter = ConvertPointsToBoxes()
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.box_key] = self.converter(d[key])
+        return d
+
+
 SpatialResampleD = SpatialResampleDict = SpatialResampled
 ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
 SpacingD = SpacingDict = Spacingd
@@ -2635,3 +2694,5 @@ GridSplitD = GridSplitDict = GridSplitd
 GridPatchD = GridPatchDict = GridPatchd
 RandGridPatchD = RandGridPatchDict = RandGridPatchd
 RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
+ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
+ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
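
A short sketch chaining the new dictionary wrappers; the key names are illustrative. Note that each wrapper writes its result under a separate output key (`point_key` / `box_key`) rather than overwriting the input key:

```python
import torch
from monai.transforms import ConvertBoxToPointsd, ConvertPointsToBoxesd

sample = {"label_box": torch.tensor([[0.0, 0.0, 0.0, 10.0, 10.0, 10.0]])}   # one 3D box
sample = ConvertBoxToPointsd(keys="label_box", point_key="points")(sample)  # adds (1, 8, 3) corner points
sample = ConvertPointsToBoxesd(keys="points", box_key="box")(sample)        # rebuilds the (1, 6) box
```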
monai/transforms/spatial/functional.py CHANGED
@@ -24,6 +24,7 @@ import torch
 import monai
 from monai.config import USE_COMPILED
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import get_boxmode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
@@ -32,7 +33,7 @@ from monai.transforms.croppad.array import ResizeWithPadOrCrop
 from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.inverse import TraceableTransform
 from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
-from monai.transforms.utils_pytorch_numpy_unification import allclose
+from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
 from monai.utils import (
     LazyAttr,
     TraceKeys,
@@ -610,3 +611,71 @@ def affine_func(
         out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
     out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
     return out if image_only else (out, affine)
+
+
+def convert_box_to_points(bbox, mode):
+    """
+    Converts an axis-aligned bounding box to points.
+
+    Args:
+        mode: The mode specifying how to interpret the bounding box.
+        bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2]
+            for 3D for each box. Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+
+    Returns:
+        sequence of points representing the corners of the bounding box.
+    """
+
+    mode = get_boxmode(mode)
+
+    points_list = []
+    for _num in range(bbox.shape[0]):
+        corners = mode.boxes_to_corners(bbox[_num : _num + 1])
+        if len(corners) == 4:
+            points_list.append(
+                concatenate(
+                    [
+                        concatenate([corners[0], corners[1]], axis=1),
+                        concatenate([corners[2], corners[1]], axis=1),
+                        concatenate([corners[2], corners[3]], axis=1),
+                        concatenate([corners[0], corners[3]], axis=1),
+                    ],
+                    axis=0,
+                )
+            )
+        else:
+            points_list.append(
+                concatenate(
+                    [
+                        concatenate([corners[0], corners[1], corners[2]], axis=1),
+                        concatenate([corners[3], corners[1], corners[2]], axis=1),
+                        concatenate([corners[3], corners[4], corners[2]], axis=1),
+                        concatenate([corners[0], corners[4], corners[2]], axis=1),
+                        concatenate([corners[0], corners[1], corners[5]], axis=1),
+                        concatenate([corners[3], corners[1], corners[5]], axis=1),
+                        concatenate([corners[3], corners[4], corners[5]], axis=1),
+                        concatenate([corners[0], corners[4], corners[5]], axis=1),
+                    ],
+                    axis=0,
+                )
+            )
+
+    return stack(points_list, dim=0)
+
+
+def convert_points_to_box(points):
+    """
+    Converts points to an axis-aligned bounding box.
+
+    Args:
+        points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of
+            a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
+    """
+    from monai.transforms.utils_pytorch_numpy_unification import max, min
+
+    mins = min(points, dim=1)
+    maxs = max(points, dim=1)
+    # Concatenate the min and max values to get the bounding boxes
+    bboxes = concatenate([mins, maxs], axis=1)
+
+    return bboxes
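
For reference, the 2D branch above enumerates corners in the order (x1, y1), (x2, y1), (x2, y2), (x1, y2). A pure-torch sketch of that layout for a StandardMode box; the helper name is illustrative, not part of the module:

```python
import torch

def box_corners_2d(bbox: torch.Tensor) -> torch.Tensor:
    """bbox: (N, 4) as [x1, y1, x2, y2] -> (N, 4, 2) corners, ordered as in convert_box_to_points."""
    x1, y1, x2, y2 = bbox.unbind(dim=1)
    return torch.stack(
        [
            torch.stack([x1, y1], dim=1),  # bottom-left
            torch.stack([x2, y1], dim=1),  # bottom-right
            torch.stack([x2, y2], dim=1),  # top-right
            torch.stack([x1, y2], dim=1),  # top-left
        ],
        dim=1,
    )

pts = box_corners_2d(torch.tensor([[0.0, 0.0, 2.0, 3.0]]))  # -> torch.Size([1, 4, 2])
```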