monai-weekly 1.4.dev2435__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +44 -2
- monai/_version.py +3 -3
- monai/apps/vista3d/inferer.py +1 -1
- monai/bundle/config_parser.py +5 -3
- monai/bundle/scripts.py +2 -2
- monai/bundle/utils.py +35 -1
- monai/handlers/__init__.py +1 -0
- monai/handlers/trt_handler.py +61 -0
- monai/metrics/generalized_dice.py +77 -48
- monai/networks/__init__.py +2 -0
- monai/networks/nets/swin_unetr.py +4 -4
- monai/networks/nets/vista3d.py +10 -6
- monai/networks/trt_compiler.py +569 -0
- monai/networks/utils.py +224 -40
- monai/transforms/__init__.py +12 -0
- monai/transforms/spatial/array.py +44 -0
- monai/transforms/spatial/dictionary.py +61 -0
- monai/transforms/spatial/functional.py +70 -1
- monai/transforms/utility/array.py +153 -4
- monai/transforms/utility/dictionary.py +101 -3
- monai/transforms/utils.py +31 -4
- monai/utils/__init__.py +1 -0
- monai/utils/type_conversion.py +8 -0
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/METADATA +3 -1
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/RECORD +28 -26
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/top_level.txt +0 -0
monai/networks/utils.py
CHANGED
@@ -36,6 +36,8 @@ from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
 onnx, _ = optional_import("onnx")
 onnxreference, _ = optional_import("onnx.reference")
 onnxruntime, _ = optional_import("onnxruntime")
+polygraphy, polygraphy_imported = optional_import("polygraphy")
+torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
 
 __all__ = [
     "one_hot",
@@ -61,6 +63,7 @@ __all__ = [
     "look_up_named_module",
     "set_named_module",
     "has_nvfuser_instance_norm",
+    "get_profile_shapes",
 ]
 
 logger = get_logger(module_name=__name__)
@@ -68,6 +71,26 @@ logger = get_logger(module_name=__name__)
 _has_nvfuser = None
 
 
+def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):
+    """
+    Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
+    """
+
+    def scale_batch_size(input_shape: Sequence[int], scale_num: int):
+        scale_shape = [*input_shape]
+        scale_shape[0] = scale_num
+        return scale_shape
+
+    # Use the dynamic batchsize range to generate the min, opt and max model input shape
+    if dynamic_batchsize:
+        min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
+        opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
+        max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
+    else:
+        min_input_shape = opt_input_shape = max_input_shape = input_shape
+    return min_input_shape, opt_input_shape, max_input_shape
+
+
 def has_nvfuser_instance_norm():
     """whether the current environment has InstanceNorm3dNVFuser
     https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
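Note that the new helper sets the batch dimension to each value in dynamic_batchsize, whereas the inline logic it replaces in convert_to_trt (removed further below) multiplied the existing batch dimension by it. A minimal sketch of the new behavior, with hypothetical shapes:

    from monai.networks.utils import get_profile_shapes

    min_s, opt_s, max_s = get_profile_shapes([1, 2, 96, 96, 96], [1, 4, 8])
    # min_s == [1, 2, 96, 96, 96], opt_s == [4, 2, 96, 96, 96], max_s == [8, 2, 96, 96, 96]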
@@ -606,6 +629,9 @@ def convert_to_onnx(
     rtol: float = 1e-4,
     atol: float = 0.0,
     use_trace: bool = True,
+    do_constant_folding: bool = True,
+    constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
+    dynamo=False,
     **kwargs,
 ):
     """
@@ -632,7 +658,10 @@
         rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
         atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
         use_trace: whether to use `torch.jit.trace` to export the torchscript model.
-        kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
+        do_constant_folding: passed to torch.onnx.export(); if True, an extra polygraphy constant-folding pass is also run.
+        constant_size_threshold: passed to polygraphy constant folding; defaults to 16 GiB.
+        kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export()
+            else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
             https://pytorch.org/docs/master/generated/torch.jit.script.html.
 
     """
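For context, a hedged sketch of calling the extended exporter (the model and shapes are hypothetical; the polygraphy pass only runs when that optional dependency is installed):

    import torch
    from monai.networks.nets import UNet
    from monai.networks.utils import convert_to_onnx

    model = UNet(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16), strides=(2,))
    onnx_model = convert_to_onnx(
        model,
        inputs=[torch.randn(1, 1, 64, 64)],
        input_names=["x"],
        output_names=["y"],
        do_constant_folding=True,  # forwarded to torch.onnx.export; also enables the extra folding pass
    )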
@@ -642,6 +671,7 @@ def convert_to_onnx(
     if use_trace:
         # let torch.onnx.export trace the model.
         mode_to_export = model
+        torch_versioned_kwargs = kwargs
     else:
         if not pytorch_after(1, 10):
             if "example_outputs" not in kwargs:
@@ -654,32 +684,37 @@
             del kwargs["example_outputs"]
         mode_to_export = torch.jit.script(model, **kwargs)
 
+    if torch.is_tensor(inputs) or isinstance(inputs, dict):
+        onnx_inputs = (inputs,)
+    else:
+        onnx_inputs = tuple(inputs)
+
     if filename is None:
         f = io.BytesIO()
-        torch.onnx.export(
-            mode_to_export,
-            tuple(inputs),
-            f=f,
-            input_names=input_names,
-            output_names=output_names,
-            dynamic_axes=dynamic_axes,
-            opset_version=opset_version,
-            **torch_versioned_kwargs,
-        )
+    else:
+        f = filename
+
+    torch.onnx.export(
+        mode_to_export,
+        onnx_inputs,
+        f=f,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic_axes,
+        opset_version=opset_version,
+        do_constant_folding=do_constant_folding,
+        **torch_versioned_kwargs,
+    )
+    if filename is None:
         onnx_model = onnx.load_model_from_string(f.getvalue())
     else:
-        torch.onnx.export(
-            mode_to_export,
-            tuple(inputs),
-            f=filename,
-            input_names=input_names,
-            output_names=output_names,
-            dynamic_axes=dynamic_axes,
-            opset_version=opset_version,
-            **torch_versioned_kwargs,
-        )
         onnx_model = onnx.load(filename)
 
+    if do_constant_folding and polygraphy_imported:
+        from polygraphy.backend.onnx.loader import fold_constants
+
+        fold_constants(onnx_model, size_threshold=constant_size_threshold)
+
     if verify:
         if device is None:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
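The new tensor/dict check above matters because `tuple(inputs)` iterates its argument: applied to a single tensor it would split the tensor along its first dimension instead of passing it as one input. A minimal sketch of the distinction (hypothetical shapes):

    import torch

    x = torch.randn(2, 3)
    len(tuple(x))  # 2: iterating the tensor yields two rows of shape (3,)
    len((x,))      # 1: what the exporter now builds for a single tensor or dict input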
@@ -814,7 +849,6 @@ def _onnx_trt_compile(
 
     """
     trt, _ = optional_import("tensorrt", "8.5.3")
-    torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
 
     input_shapes = (min_shape, opt_shape, max_shape)
     # default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
@@ -916,8 +950,6 @@ def convert_to_trt(
         to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
     """
 
-    torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0")
-
     if not torch.cuda.is_available():
         raise Exception("Cannot find any GPU devices.")
 
@@ -935,23 +967,9 @@
     convert_precision = torch.float32 if precision == "fp32" else torch.half
     inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
 
-    def scale_batch_size(input_shape: Sequence[int], scale_num: int):
-        scale_shape = [*input_shape]
-        scale_shape[0] *= scale_num
-        return scale_shape
-
-    # Use the dynamic batchsize range to generate the min, opt and max model input shape
-    if dynamic_batchsize:
-        min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
-        opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
-        max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
-    else:
-        min_input_shape = opt_input_shape = max_input_shape = input_shape
-
     # convert the torch model to a TorchScript model on target device
     model = model.eval().to(target_device)
-    ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
-    ir_model.eval()
+    min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
 
     if use_onnx:
         # set the batch dim as dynamic
@@ -960,7 +978,6 @@
         ir_model = convert_to_onnx(
             model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes
         )
-
         # convert the model through the ONNX-TensorRT way
         trt_model = _onnx_trt_compile(
             ir_model,
@@ -973,6 +990,8 @@
             output_names=onnx_output_names,
         )
     else:
+        ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
+        ir_model.eval()
         # convert the model through the Torch-TensorRT way
         ir_model.to(target_device)
         with torch.no_grad():
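A hedged sketch of the conversion entry point with a dynamic batch range (hypothetical model and shapes; this requires a CUDA device plus the optional tensorrt/torch_tensorrt packages):

    from monai.networks.nets import UNet
    from monai.networks.utils import convert_to_trt

    model = UNet(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16), strides=(2,))
    trt_model = convert_to_trt(
        model,
        precision="fp16",
        input_shape=[1, 1, 96, 96],
        dynamic_batchsize=[1, 4, 8],  # min/opt/max, handed to get_profile_shapes above
        use_onnx=False,               # take the Torch-TensorRT branch shown in this hunk
    )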
@@ -1189,3 +1208,168 @@ class CastTempType(nn.Module):
         if dtype == self.initial_type:
             x = x.to(self.initial_type)
         return x
+
+
+def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    """
+    Utility function to cast a single tensor from from_dtype to to_dtype
+    """
+    return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
+
+
+def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
+    """
+    Utility function to cast all tensors in a tuple from from_dtype to to_dtype
+    """
+    if isinstance(x, torch.Tensor):
+        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
+    else:
+        if isinstance(x, dict):
+            new_dict = {}
+            for k in x.keys():
+                new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
+            return new_dict
+        elif isinstance(x, tuple):
+            return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+
+
+class CastToFloat(torch.nn.Module):
+    """
+    Class used to add autocast protection for ONNX export
+    for forward methods with a single return value
+    """
+
+    def __init__(self, mod):
+        super().__init__()
+        self.mod = mod
+
+    def forward(self, x):
+        dtype = x.dtype
+        with torch.amp.autocast("cuda", enabled=False):
+            ret = self.mod.forward(x.to(torch.float32)).to(dtype)
+        return ret
+
+
+class CastToFloatAll(torch.nn.Module):
+    """
+    Class used to add autocast protection for ONNX export
+    for forward methods with multiple return values
+    """
+
+    def __init__(self, mod):
+        super().__init__()
+        self.mod = mod
+
+    def forward(self, *args):
+        from_dtype = args[0].dtype
+        with torch.amp.autocast("cuda", enabled=False):
+            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+
+
+def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+    """
+    Generic function generator to replace base_t module with dest_t wrapper.
+    Args:
+        base_t : module type to replace
+        dest_t : destination module type
+    Returns:
+        swap function to replace base_t module with dest_t
+    """
+
+    def expansion_fn(mod: nn.Module) -> nn.Module | None:
+        out = dest_t(mod)
+        return out
+
+    return expansion_fn
+
+
+def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+    """
+    Generic function generator to replace base_t module with dest_t.
+    base_t and dest_t should have the same attributes. No weights are copied.
+    Args:
+        base_t : module type to replace
+        dest_t : destination module type
+    Returns:
+        swap function to replace base_t module with dest_t
+    """
+
+    def expansion_fn(mod: nn.Module) -> nn.Module | None:
+        if not isinstance(mod, base_t):
+            return None
+        args = [getattr(mod, name, None) for name in mod.__constants__]
+        out = dest_t(*args)
+        return out
+
+    return expansion_fn
+
+
+def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module:
+    """
+    This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
+    for swapping nested modules through arbitrary levels of children.
+
+    NOTE: This occurs in place; if you want to preserve model then make sure to copy it first.
+
+    """
+    for path, new_mod in mapping.items():
+        expanded_path = path.split(".")
+        parent_mod = model
+        for sub_path in expanded_path[:-1]:
+            submod = parent_mod._modules[sub_path]
+            if submod is None:
+                break
+            else:
+                parent_mod = submod
+        parent_mod._modules[expanded_path[-1]] = new_mod
+
+    return model
+
+
+def replace_modules_by_type(
+    model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]
+) -> nn.Module:
+    """
+    Top-level function to replace modules in model, specified by class name with a desired replacement.
+    NOTE: This occurs in place; if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+        expansions : replacement dictionary: module class name -> replacement function generator
+    Returns:
+        model, possibly modified in-place
+    """
+    mapping: dict[str, nn.Module] = {}
+    for name, m in model.named_modules():
+        m_type = type(m).__name__
+        if m_type in expansions:
+            # print (f"Found {m_type} in expansions ...")
+            swapped = expansions[m_type](m)
+            if swapped:
+                mapping[name] = swapped
+
+    print(f"Swapped {len(mapping)} modules")
+    _swap_modules(model, mapping)
+    return model
+
+
+def add_casts_around_norms(model: nn.Module) -> nn.Module:
+    """
+    Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export
+    NOTE: This occurs in place; if you want to preserve model then make sure to copy it first.
+    Args:
+        model : top level module
+    Returns:
+        model, possibly modified in-place
+    """
+    print("Adding casts around norms...")
+    cast_replacements = {
+        "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+        "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+        "BatchNorm3d": wrap_module(nn.BatchNorm3d, CastToFloat),
+        "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+        "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+        "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+    }
+    replace_modules_by_type(model, cast_replacements)
+    return model
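The intended use of these helpers is to patch a model before FP16/autocast ONNX export so that numerically sensitive normalization layers run in float32. A minimal sketch (hypothetical toy network):

    import torch.nn as nn
    from monai.networks.utils import add_casts_around_norms

    net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    add_casts_around_norms(net)  # in-place: the BatchNorm2d is now wrapped in CastToFloat
    # inside the wrapper, inputs are cast to float32 with autocast disabled, then cast back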
monai/transforms/__init__.py
CHANGED
@@ -396,6 +396,8 @@ from .smooth_field.dictionary import (
 from .spatial.array import (
     Affine,
     AffineGrid,
+    ConvertBoxToPoints,
+    ConvertPointsToBoxes,
     Flip,
     GridDistortion,
     GridPatch,
@@ -427,6 +429,12 @@ from .spatial.dictionary import (
     Affined,
     AffineD,
     AffineDict,
+    ConvertBoxToPointsd,
+    ConvertBoxToPointsD,
+    ConvertBoxToPointsDict,
+    ConvertPointsToBoxesd,
+    ConvertPointsToBoxesD,
+    ConvertPointsToBoxesDict,
     Flipd,
     FlipD,
     FlipDict,
@@ -503,6 +511,7 @@ from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTr
 from .utility.array import (
     AddCoordinateChannels,
     AddExtremePointsChannel,
+    ApplyTransformToPoints,
     AsChannelLast,
     CastToType,
     ClassesToIndices,
@@ -542,6 +551,9 @@ from .utility.dictionary import (
     AddExtremePointsChanneld,
     AddExtremePointsChannelD,
     AddExtremePointsChannelDict,
+    ApplyTransformToPointsd,
+    ApplyTransformToPointsD,
+    ApplyTransformToPointsDict,
     AsChannelLastd,
     AsChannelLastD,
     AsChannelLastDict,
monai/transforms/spatial/array.py
CHANGED
@@ -25,6 +25,7 @@ import torch
 
 from monai.config import USE_COMPILED, DtypeLike
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta, set_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
@@ -34,6 +35,8 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCro
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.functional import (
     affine_func,
+    convert_box_to_points,
+    convert_points_to_box,
     flip,
     orientation,
     resize,
@@ -3544,3 +3547,44 @@ class RandSimulateLowResolution(RandomizableTransform):
 
         else:
             return img
+
+
+class ConvertBoxToPoints(Transform):
+    """
+    Converts an axis-aligned bounding box to points; the box mode determines how the input coordinates
+    are interpreted. Input is bounding boxes of shape (N, C) for N boxes, where C is [x1, y1, x2, y2]
+    for 2D or [x1, y1, z1, x2, y2, z2] for 3D. The return shape is (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
+        """
+        Args:
+            mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.
+        """
+        super().__init__()
+        self.mode = StandardMode if mode is None else mode
+
+    def __call__(self, data: Any):
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        points = convert_box_to_points(data, mode=self.mode)
+        return convert_to_dst_type(points, data)[0]
+
+
+class ConvertPointsToBoxes(Transform):
+    """
+    Converts points to an axis-aligned bounding box. Input is points representing the corners of the
+    bounding box, of shape (N, 8, 3) for the 8 corners of a 3D cuboid or (N, 4, 2) for the 4 corners
+    of a 2D rectangle.
+    """
+
+    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __call__(self, data: Any):
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        box = convert_points_to_box(data)
+        return convert_to_dst_type(box, data)[0]
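A short sketch of the new array transforms (coordinates are hypothetical; StandardMode reads a 2D box as [x1, y1, x2, y2]):

    import torch
    from monai.transforms import ConvertBoxToPoints, ConvertPointsToBoxes

    boxes = torch.tensor([[1.0, 1.0, 3.0, 4.0]])  # one 2D box
    points = ConvertBoxToPoints()(boxes)          # shape (1, 4, 2): the four corners
    boxes_back = ConvertPointsToBoxes()(points)   # shape (1, 4): recovers the original box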
monai/transforms/spatial/dictionary.py
CHANGED
@@ -26,6 +26,7 @@ import torch
 
 from monai.config import DtypeLike, KeysCollection, SequenceStr
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.networks.layers.simplelayers import GaussianFilter
@@ -33,6 +34,8 @@ from monai.transforms.croppad.array import CenterSpatialCrop
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.array import (
     Affine,
+    ConvertBoxToPoints,
+    ConvertPointsToBoxes,
     Flip,
     GridDistortion,
     GridPatch,
@@ -2585,6 +2588,7 @@ class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
         self, seed: int | None = None, state: np.random.RandomState | None = None
     ) -> RandSimulateLowResolutiond:
         super().set_random_state(seed, state)
+        self.sim_lowres_tfm.set_random_state(seed, state)
         return self
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
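This one-line fix forwards the seed to the wrapped array transform, so identically seeded instances now behave identically; a quick sketch (the key name and prob value are hypothetical):

    from monai.transforms import RandSimulateLowResolutiond

    t1 = RandSimulateLowResolutiond(keys="img", prob=1.0).set_random_state(seed=42)
    t2 = RandSimulateLowResolutiond(keys="img", prob=1.0).set_random_state(seed=42)
    # previously the inner sim_lowres_tfm kept its own unseeded state, so t1 and t2
    # could produce different outputs for the same input despite identical seeds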
@@ -2611,6 +2615,61 @@
         return d
 
 
+class ConvertBoxToPointsd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.
+    """
+
+    backend = ConvertBoxToPoints.backend
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        point_key="points",
+        mode: str | BoxMode | type[BoxMode] | None = StandardMode,
+        allow_missing_keys: bool = False,
+    ):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            point_key: key under which to store the point data.
+            mode: the mode of the input boxes. Defaults to StandardMode.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.point_key = point_key
+        self.converter = ConvertBoxToPoints(mode=mode)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.point_key] = self.converter(d[key])
+        return d
+
+
+class ConvertPointsToBoxesd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.
+    """
+
+    def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            box_key: key under which to store the box data.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.box_key = box_key
+        self.converter = ConvertPointsToBoxes()
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[self.box_key] = self.converter(d[key])
+        return d
+
+
 SpatialResampleD = SpatialResampleDict = SpatialResampled
 ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
 SpacingD = SpacingDict = Spacingd
@@ -2635,3 +2694,5 @@ GridSplitD = GridSplitDict = GridSplitd
 GridPatchD = GridPatchDict = GridPatchd
 RandGridPatchD = RandGridPatchDict = RandGridPatchd
 RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
+ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
+ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
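A short sketch of the dictionary-based pair (the key names are hypothetical; point_key/box_key control where the result is stored):

    import torch
    from monai.transforms import ConvertBoxToPointsd, ConvertPointsToBoxesd

    data = {"box": torch.tensor([[1.0, 1.0, 3.0, 4.0]])}
    data = ConvertBoxToPointsd(keys="box", point_key="points")(data)   # adds data["points"], shape (1, 4, 2)
    data = ConvertPointsToBoxesd(keys="points", box_key="box2")(data)  # adds data["box2"], shape (1, 4)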
monai/transforms/spatial/functional.py
CHANGED
@@ -24,6 +24,7 @@ import torch
 import monai
 from monai.config import USE_COMPILED
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import get_boxmode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
@@ -32,7 +33,7 @@ from monai.transforms.croppad.array import ResizeWithPadOrCrop
 from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.inverse import TraceableTransform
 from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
-from monai.transforms.utils_pytorch_numpy_unification import allclose
+from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
 from monai.utils import (
     LazyAttr,
     TraceKeys,
@@ -610,3 +611,71 @@ def affine_func(
         out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
     out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
     return out if image_only else (out, affine)
+
+
+def convert_box_to_points(bbox, mode):
+    """
+    Converts an axis-aligned bounding box to points.
+
+    Args:
+        bbox: bounding boxes of shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or
+            [x1, y1, z1, x2, y2, z2] for 3D for each box.
+        mode: the mode specifying how to interpret the bounding box.
+
+    Returns:
+        sequence of points representing the corners of the bounding box; shape (N, 4, 2)
+        for 2D or (N, 8, 3) for 3D.
+    """
+
+    mode = get_boxmode(mode)
+
+    points_list = []
+    for _num in range(bbox.shape[0]):
+        corners = mode.boxes_to_corners(bbox[_num : _num + 1])
+        if len(corners) == 4:
+            points_list.append(
+                concatenate(
+                    [
+                        concatenate([corners[0], corners[1]], axis=1),
+                        concatenate([corners[2], corners[1]], axis=1),
+                        concatenate([corners[2], corners[3]], axis=1),
+                        concatenate([corners[0], corners[3]], axis=1),
+                    ],
+                    axis=0,
+                )
+            )
+        else:
+            points_list.append(
+                concatenate(
+                    [
+                        concatenate([corners[0], corners[1], corners[2]], axis=1),
+                        concatenate([corners[3], corners[1], corners[2]], axis=1),
+                        concatenate([corners[3], corners[4], corners[2]], axis=1),
+                        concatenate([corners[0], corners[4], corners[2]], axis=1),
+                        concatenate([corners[0], corners[1], corners[5]], axis=1),
+                        concatenate([corners[3], corners[1], corners[5]], axis=1),
+                        concatenate([corners[3], corners[4], corners[5]], axis=1),
+                        concatenate([corners[0], corners[4], corners[5]], axis=1),
+                    ],
+                    axis=0,
+                )
+            )
+
+    return stack(points_list, dim=0)
+
+
+def convert_points_to_box(points):
+    """
+    Converts points to an axis-aligned bounding box.
+
+    Args:
+        points: points representing the corners of the bounding box. Shape (N, 8, 3) for the
+            8 corners of a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
+    """
+    from monai.transforms.utils_pytorch_numpy_unification import max, min
+
+    mins = min(points, dim=1)
+    maxs = max(points, dim=1)
+    # concatenate the min and max values to get the bounding boxes
+    bboxes = concatenate([mins, maxs], axis=1)
+
+    return bboxes
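Because convert_points_to_box takes per-axis minima and maxima over the corner dimension, a box converted to corner points converts back losslessly; a minimal round-trip check (hypothetical values):

    import torch
    from monai.data.box_utils import StandardMode
    from monai.transforms.spatial.functional import convert_box_to_points, convert_points_to_box

    boxes = torch.tensor([[0.0, 0.0, 0.0, 2.0, 3.0, 4.0]])    # one 3D box: [x1, y1, z1, x2, y2, z2]
    corners = convert_box_to_points(boxes, mode=StandardMode)  # shape (1, 8, 3)
    assert torch.equal(convert_points_to_box(corners), boxes)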