monai-weekly 1.4.dev2434__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +44 -2
- monai/_version.py +3 -3
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/config_parser.py +5 -3
- monai/bundle/scripts.py +2 -2
- monai/bundle/utils.py +35 -1
- monai/handlers/__init__.py +1 -0
- monai/handlers/trt_handler.py +61 -0
- monai/inferers/utils.py +1 -0
- monai/metrics/generalized_dice.py +77 -48
- monai/networks/__init__.py +2 -0
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/swin_unetr.py +4 -4
- monai/networks/nets/vista3d.py +53 -11
- monai/networks/trt_compiler.py +569 -0
- monai/networks/utils.py +225 -41
- monai/transforms/__init__.py +24 -2
- monai/transforms/io/array.py +58 -2
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/array.py +44 -0
- monai/transforms/spatial/dictionary.py +61 -0
- monai/transforms/spatial/functional.py +70 -1
- monai/transforms/utility/array.py +153 -4
- monai/transforms/utility/dictionary.py +105 -3
- monai/transforms/utils.py +83 -10
- monai/utils/__init__.py +1 -0
- monai/utils/enums.py +1 -0
- monai/utils/type_conversion.py +8 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/METADATA +4 -1
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/RECORD +36 -31
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/top_level.txt +0 -0
monai/networks/utils.py
CHANGED
@@ -36,6 +36,8 @@ from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
|
|
36
36
|
onnx, _ = optional_import("onnx")
|
37
37
|
onnxreference, _ = optional_import("onnx.reference")
|
38
38
|
onnxruntime, _ = optional_import("onnxruntime")
|
39
|
+
polygraphy, polygraphy_imported = optional_import("polygraphy")
|
40
|
+
torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
|
39
41
|
|
40
42
|
__all__ = [
|
41
43
|
"one_hot",
|
@@ -61,6 +63,7 @@ __all__ = [
|
|
61
63
|
"look_up_named_module",
|
62
64
|
"set_named_module",
|
63
65
|
"has_nvfuser_instance_norm",
|
66
|
+
"get_profile_shapes",
|
64
67
|
]
|
65
68
|
|
66
69
|
# Module-level logger for this module (monai.networks.utils).
logger = get_logger(module_name=__name__)
|
@@ -68,6 +71,26 @@ logger = get_logger(module_name=__name__)
|
|
68
71
|
# Module-level cached flag, initialised to None ("not determined yet").
# NOTE(review): presumably populated lazily by has_nvfuser_instance_norm() — confirm against its body.
_has_nvfuser = None
|
69
72
|
|
70
73
|
|
74
|
+
def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):
    """
    Derive the (min, opt, max) input shapes of an optimization profile from one sample input shape.

    Args:
        input_shape: a sample input shape; its first entry is treated as the batch dimension.
        dynamic_batchsize: optional ``[min, opt, max]`` batch sizes. When it is given (and non-empty), each
            returned shape is a new list copied from ``input_shape`` with its first entry replaced by the
            corresponding batch size. When it is ``None`` or empty, ``input_shape`` itself is returned
            for all three shapes.

    Returns:
        a tuple ``(min_input_shape, opt_input_shape, max_input_shape)``.
    """
    if not dynamic_batchsize:
        # no dynamic batching requested: all three profile shapes are the sample shape
        return input_shape, input_shape, input_shape

    def _with_batch(batch: int) -> list[int]:
        # copy so the caller's shape is never mutated, then overwrite only the batch entry
        shape = list(input_shape)
        shape[0] = batch
        return shape

    min_batch, opt_batch, max_batch = dynamic_batchsize[0], dynamic_batchsize[1], dynamic_batchsize[2]
    return _with_batch(min_batch), _with_batch(opt_batch), _with_batch(max_batch)
|
92
|
+
|
93
|
+
|
71
94
|
def has_nvfuser_instance_norm():
|
72
95
|
"""whether the current environment has InstanceNorm3dNVFuser
|
73
96
|
https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
|
@@ -606,6 +629,9 @@ def convert_to_onnx(
|
|
606
629
|
rtol: float = 1e-4,
|
607
630
|
atol: float = 0.0,
|
608
631
|
use_trace: bool = True,
|
632
|
+
do_constant_folding: bool = True,
|
633
|
+
constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
|
634
|
+
dynamo=False,
|
609
635
|
**kwargs,
|
610
636
|
):
|
611
637
|
"""
|
@@ -632,7 +658,10 @@ def convert_to_onnx(
|
|
632
658
|
rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
|
633
659
|
atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
|
634
660
|
use_trace: whether to use `torch.jit.trace` to export the torchscript model.
|
635
|
-
|
661
|
+
do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done.
|
662
|
+
constant_size_threshold: size threshold passed to polygraphy constant folding, default = 16 * 1024**3 (16G)
|
663
|
+
kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export()
|
664
|
+
else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
|
636
665
|
https://pytorch.org/docs/master/generated/torch.jit.script.html.
|
637
666
|
|
638
667
|
"""
|
@@ -642,6 +671,7 @@ def convert_to_onnx(
|
|
642
671
|
if use_trace:
|
643
672
|
# let torch.onnx.export to trace the model.
|
644
673
|
mode_to_export = model
|
674
|
+
torch_versioned_kwargs = kwargs
|
645
675
|
else:
|
646
676
|
if not pytorch_after(1, 10):
|
647
677
|
if "example_outputs" not in kwargs:
|
@@ -654,32 +684,37 @@ def convert_to_onnx(
|
|
654
684
|
del kwargs["example_outputs"]
|
655
685
|
mode_to_export = torch.jit.script(model, **kwargs)
|
656
686
|
|
687
|
+
if torch.is_tensor(inputs) or isinstance(inputs, dict):
|
688
|
+
onnx_inputs = (inputs,)
|
689
|
+
else:
|
690
|
+
onnx_inputs = tuple(inputs)
|
691
|
+
|
657
692
|
if filename is None:
|
658
693
|
f = io.BytesIO()
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
694
|
+
else:
|
695
|
+
f = filename
|
696
|
+
|
697
|
+
torch.onnx.export(
|
698
|
+
mode_to_export,
|
699
|
+
onnx_inputs,
|
700
|
+
f=f,
|
701
|
+
input_names=input_names,
|
702
|
+
output_names=output_names,
|
703
|
+
dynamic_axes=dynamic_axes,
|
704
|
+
opset_version=opset_version,
|
705
|
+
do_constant_folding=do_constant_folding,
|
706
|
+
**torch_versioned_kwargs,
|
707
|
+
)
|
708
|
+
if filename is None:
|
669
709
|
onnx_model = onnx.load_model_from_string(f.getvalue())
|
670
710
|
else:
|
671
|
-
torch.onnx.export(
|
672
|
-
mode_to_export,
|
673
|
-
tuple(inputs),
|
674
|
-
f=filename,
|
675
|
-
input_names=input_names,
|
676
|
-
output_names=output_names,
|
677
|
-
dynamic_axes=dynamic_axes,
|
678
|
-
opset_version=opset_version,
|
679
|
-
**torch_versioned_kwargs,
|
680
|
-
)
|
681
711
|
onnx_model = onnx.load(filename)
|
682
712
|
|
713
|
+
if do_constant_folding and polygraphy_imported:
|
714
|
+
from polygraphy.backend.onnx.loader import fold_constants
|
715
|
+
|
716
|
+
fold_constants(onnx_model, size_threshold=constant_size_threshold)
|
717
|
+
|
683
718
|
if verify:
|
684
719
|
if device is None:
|
685
720
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -814,7 +849,6 @@ def _onnx_trt_compile(
|
|
814
849
|
|
815
850
|
"""
|
816
851
|
trt, _ = optional_import("tensorrt", "8.5.3")
|
817
|
-
torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
|
818
852
|
|
819
853
|
input_shapes = (min_shape, opt_shape, max_shape)
|
820
854
|
# default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
|
@@ -851,7 +885,7 @@ def _onnx_trt_compile(
|
|
851
885
|
# wrap the serialized TensorRT engine back to a TorchScript module.
|
852
886
|
trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
|
853
887
|
f.getvalue(),
|
854
|
-
device=
|
888
|
+
device=torch_tensorrt.Device(f"cuda:{device}"),
|
855
889
|
input_binding_names=input_names,
|
856
890
|
output_binding_names=output_names,
|
857
891
|
)
|
@@ -916,8 +950,6 @@ def convert_to_trt(
|
|
916
950
|
to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
|
917
951
|
"""
|
918
952
|
|
919
|
-
torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0")
|
920
|
-
|
921
953
|
if not torch.cuda.is_available():
|
922
954
|
raise Exception("Cannot find any GPU devices.")
|
923
955
|
|
@@ -935,23 +967,9 @@ def convert_to_trt(
|
|
935
967
|
convert_precision = torch.float32 if precision == "fp32" else torch.half
|
936
968
|
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
|
937
969
|
|
938
|
-
def scale_batch_size(input_shape: Sequence[int], scale_num: int):
|
939
|
-
scale_shape = [*input_shape]
|
940
|
-
scale_shape[0] *= scale_num
|
941
|
-
return scale_shape
|
942
|
-
|
943
|
-
# Use the dynamic batchsize range to generate the min, opt and max model input shape
|
944
|
-
if dynamic_batchsize:
|
945
|
-
min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
|
946
|
-
opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
|
947
|
-
max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
|
948
|
-
else:
|
949
|
-
min_input_shape = opt_input_shape = max_input_shape = input_shape
|
950
|
-
|
951
970
|
# convert the torch model to a TorchScript model on target device
|
952
971
|
model = model.eval().to(target_device)
|
953
|
-
|
954
|
-
ir_model.eval()
|
972
|
+
min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
|
955
973
|
|
956
974
|
if use_onnx:
|
957
975
|
# set the batch dim as dynamic
|
@@ -960,7 +978,6 @@ def convert_to_trt(
|
|
960
978
|
ir_model = convert_to_onnx(
|
961
979
|
model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes
|
962
980
|
)
|
963
|
-
|
964
981
|
# convert the model through the ONNX-TensorRT way
|
965
982
|
trt_model = _onnx_trt_compile(
|
966
983
|
ir_model,
|
@@ -973,6 +990,8 @@ def convert_to_trt(
|
|
973
990
|
output_names=onnx_output_names,
|
974
991
|
)
|
975
992
|
else:
|
993
|
+
ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
|
994
|
+
ir_model.eval()
|
976
995
|
# convert the model through the Torch-TensorRT way
|
977
996
|
ir_model.to(target_device)
|
978
997
|
with torch.no_grad():
|
@@ -1189,3 +1208,168 @@ class CastTempType(nn.Module):
|
|
1189
1208
|
if dtype == self.initial_type:
|
1190
1209
|
x = x.to(self.initial_type)
|
1191
1210
|
return x
|
1211
|
+
|
1212
|
+
|
1213
|
+
def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
    """
    Cast one tensor to ``to_dtype``, but only when its current dtype is exactly ``from_dtype``.

    Args:
        x: the tensor to (maybe) cast.
        from_dtype: dtype that triggers the cast.
        to_dtype: dtype to cast to.

    Returns:
        ``x.to(dtype=to_dtype)`` if ``x.dtype == from_dtype``, otherwise ``x`` itself (same object).
    """
    if x.dtype != from_dtype:
        return x
    return x.to(dtype=to_dtype)
|
1218
|
+
|
1219
|
+
|
1220
|
+
def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
    """
    Utility function to cast all tensors in a (possibly nested) structure from from_dtype to to_dtype.

    Tensors go through ``cast_tensor``, so only tensors whose dtype is ``from_dtype`` change. Dicts,
    tuples and lists are traversed recursively and rebuilt as a plain ``dict``, ``tuple`` or ``list``.
    Any other value (e.g. ``None``, Python numbers, strings) is returned unchanged.

    Args:
        x: a tensor, or a nested dict/tuple/list that may contain tensors.
        from_dtype: only tensors of this dtype are cast.
        to_dtype: dtype to cast matching tensors to.

    Returns:
        a structure mirroring ``x`` with the matching tensors cast.
    """
    if isinstance(x, torch.Tensor):
        return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
    if isinstance(x, dict):
        return {k: cast_all(v, from_dtype=from_dtype, to_dtype=to_dtype) for k, v in x.items()}
    if isinstance(x, tuple):
        return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
    if isinstance(x, list):
        return [cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x]
    # Non-container, non-tensor values pass through untouched. Falling through to an implicit ``None``
    # would silently replace e.g. scalar/None arguments in CastToFloatAll.forward.
    return x
|
1234
|
+
|
1235
|
+
|
1236
|
+
class CastToFloat(torch.nn.Module):
    """
    Class used to add autocast protection for ONNX export
    for forward methods with a single return value.

    The input is cast to float32, the wrapped module runs with CUDA autocast disabled, and the
    result is cast back to the input's original dtype.

    Args:
        mod: the module to wrap.
    """

    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        dtype = x.dtype
        with torch.amp.autocast("cuda", enabled=False):
            # Call the module instance rather than ``.forward`` so hooks registered on ``mod`` still run.
            ret = self.mod(x.to(torch.float32)).to(dtype)
        return ret
|
1251
|
+
|
1252
|
+
|
1253
|
+
class CastToFloatAll(torch.nn.Module):
    """
    Class used to add autocast protection for ONNX export
    for forward methods with multiple arguments and/or return values.

    Tensor arguments whose dtype equals the first argument's dtype are cast to float32, the wrapped
    module runs with CUDA autocast disabled, and float32 tensors in its (possibly nested) result are
    cast back to that dtype (via ``cast_all``).

    Args:
        mod: the module to wrap.
    """

    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, *args):
        # the first argument's dtype is the working precision that is restored on the outputs
        from_dtype = args[0].dtype
        with torch.amp.autocast("cuda", enabled=False):
            # Call the module instance rather than ``.forward`` so hooks registered on ``mod`` still run.
            ret = self.mod(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
        return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
|
1268
|
+
|
1269
|
+
|
1270
|
+
def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
    """
    Build a swap function that wraps a module in a ``dest_t`` wrapper.

    Args:
        base_t: module type the swap function is meant for. It is not checked here: the returned
            function wraps whatever module it is given.
        dest_t: wrapper type; it is constructed as ``dest_t(mod)``.

    Returns:
        swap function mapping a module to its ``dest_t`` wrapper.
    """

    def _wrap(mod: nn.Module) -> nn.Module | None:
        return dest_t(mod)

    return _wrap
|
1285
|
+
|
1286
|
+
|
1287
|
+
def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
    """
    Build a swap function that replaces ``base_t`` modules with a freshly constructed ``dest_t``.

    The new module is constructed positionally from the values of the attributes named in the old
    module's ``__constants__`` (a missing attribute becomes ``None``), so ``base_t`` and ``dest_t`` should
    have the same attributes. No weights are copied.

    Args:
        base_t: module type to replace.
        dest_t: destination module type.

    Returns:
        swap function returning the new ``dest_t`` module, or ``None`` for modules that are not ``base_t``.
    """

    def _replace(mod: nn.Module) -> nn.Module | None:
        if isinstance(mod, base_t):
            ctor_args = [getattr(mod, attr_name, None) for attr_name in mod.__constants__]
            return dest_t(*ctor_args)
        return None

    return _replace
|
1306
|
+
|
1307
|
+
|
1308
|
+
def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module:
    """
    This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
    for swapping nested modules through arbitrary levels of children.

    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.

    Paths whose intermediate submodule is ``None`` are skipped. The empty path ``""`` (the name
    ``named_modules`` gives the root) is also skipped, because the root cannot be replaced in place.
    A missing intermediate name raises ``KeyError``.

    Args:
        model: top level module, modified in place.
        mapping: dotted submodule path -> replacement module.

    Returns:
        ``model``.
    """
    for path, new_mod in mapping.items():
        if not path:
            # Writing the root's replacement as a child named "" would corrupt the model
            # (and create a cycle when the replacement wraps ``model`` itself).
            continue
        *parent_names, leaf_name = path.split(".")
        parent_mod = model
        for sub_path in parent_names:
            submod = parent_mod._modules[sub_path]
            if submod is None:
                # Cannot descend: skip this path rather than assigning onto the wrong (ancestor) module.
                break
            parent_mod = submod
        else:
            parent_mod._modules[leaf_name] = new_mod

    return model
|
1328
|
+
|
1329
|
+
|
1330
|
+
def replace_modules_by_type(
    model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]
) -> nn.Module:
    """
    Top-level function to replace modules in model, specified by class name with a desired replacement.
    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.

    Submodules are matched on ``type(m).__name__``. A replacement function returning ``None`` means
    "leave this module as is". The root module itself is never replaced, since it cannot be swapped in place.

    Args:
        model : top level module
        expansions : replacement dictionary: module class name -> replacement function generator
    Returns:
        model, possibly modified in-place
    """
    mapping: dict[str, nn.Module] = {}
    for name, m in model.named_modules():
        if not name:
            # "" is the root module; it has no parent to be swapped into
            continue
        m_type = type(m).__name__
        if m_type in expansions:
            swapped = expansions[m_type](m)
            # Compare with None explicitly: containers such as an empty nn.Sequential are falsy.
            if swapped is not None:
                mapping[name] = swapped

    logger.info("Swapped %d modules", len(mapping))
    _swap_modules(model, mapping)
    return model
|
1354
|
+
|
1355
|
+
|
1356
|
+
def add_casts_around_norms(model: nn.Module) -> nn.Module:
    """
    Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export
    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.

    Every submodule whose class name matches one of the norm types below is wrapped in ``CastToFloat``.

    Args:
        model : top level module
    Returns:
        model, possibly modified in-place
    """
    logger.info("Adding casts around norms...")
    norm_types = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.LayerNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm3d,
    )
    # Derive each key from its type so the class-name key and the wrapped type can never disagree
    # (previously "BatchNorm3d" was paired with nn.BatchNorm2d).
    cast_replacements = {norm_t.__name__: wrap_module(norm_t, CastToFloat) for norm_t in norm_types}
    replace_modules_by_type(model, cast_replacements)
    return model
|
monai/transforms/__init__.py
CHANGED
@@ -238,8 +238,18 @@ from .intensity.dictionary import (
|
|
238
238
|
)
|
239
239
|
from .inverse import InvertibleTransform, TraceableTransform
|
240
240
|
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
|
241
|
-
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
|
242
|
-
from .io.dictionary import
|
241
|
+
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
|
242
|
+
from .io.dictionary import (
|
243
|
+
LoadImaged,
|
244
|
+
LoadImageD,
|
245
|
+
LoadImageDict,
|
246
|
+
SaveImaged,
|
247
|
+
SaveImageD,
|
248
|
+
SaveImageDict,
|
249
|
+
WriteFileMappingd,
|
250
|
+
WriteFileMappingD,
|
251
|
+
WriteFileMappingDict,
|
252
|
+
)
|
243
253
|
from .lazy.array import ApplyPending
|
244
254
|
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
|
245
255
|
from .lazy.functional import apply_pending
|
@@ -386,6 +396,8 @@ from .smooth_field.dictionary import (
|
|
386
396
|
from .spatial.array import (
|
387
397
|
Affine,
|
388
398
|
AffineGrid,
|
399
|
+
ConvertBoxToPoints,
|
400
|
+
ConvertPointsToBoxes,
|
389
401
|
Flip,
|
390
402
|
GridDistortion,
|
391
403
|
GridPatch,
|
@@ -417,6 +429,12 @@ from .spatial.dictionary import (
|
|
417
429
|
Affined,
|
418
430
|
AffineD,
|
419
431
|
AffineDict,
|
432
|
+
ConvertBoxToPointsd,
|
433
|
+
ConvertBoxToPointsD,
|
434
|
+
ConvertBoxToPointsDict,
|
435
|
+
ConvertPointsToBoxesd,
|
436
|
+
ConvertPointsToBoxesD,
|
437
|
+
ConvertPointsToBoxesDict,
|
420
438
|
Flipd,
|
421
439
|
FlipD,
|
422
440
|
FlipDict,
|
@@ -493,6 +511,7 @@ from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTr
|
|
493
511
|
from .utility.array import (
|
494
512
|
AddCoordinateChannels,
|
495
513
|
AddExtremePointsChannel,
|
514
|
+
ApplyTransformToPoints,
|
496
515
|
AsChannelLast,
|
497
516
|
CastToType,
|
498
517
|
ClassesToIndices,
|
@@ -532,6 +551,9 @@ from .utility.dictionary import (
|
|
532
551
|
AddExtremePointsChanneld,
|
533
552
|
AddExtremePointsChannelD,
|
534
553
|
AddExtremePointsChannelDict,
|
554
|
+
ApplyTransformToPointsd,
|
555
|
+
ApplyTransformToPointsD,
|
556
|
+
ApplyTransformToPointsDict,
|
535
557
|
AsChannelLastd,
|
536
558
|
AsChannelLastD,
|
537
559
|
AsChannelLastDict,
|
monai/transforms/io/array.py
CHANGED
@@ -15,6 +15,7 @@ A collection of "vanilla" transforms for IO functions.
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
import inspect
|
18
|
+
import json
|
18
19
|
import logging
|
19
20
|
import sys
|
20
21
|
import traceback
|
@@ -45,11 +46,19 @@ from monai.transforms.transform import Transform
|
|
45
46
|
from monai.transforms.utility.array import EnsureChannelFirst
|
46
47
|
from monai.utils import GridSamplePadMode
|
47
48
|
from monai.utils import ImageMetaKey as Key
|
48
|
-
from monai.utils import
|
49
|
+
from monai.utils import (
|
50
|
+
MetaKeys,
|
51
|
+
OptionalImportError,
|
52
|
+
convert_to_dst_type,
|
53
|
+
ensure_tuple,
|
54
|
+
look_up_option,
|
55
|
+
optional_import,
|
56
|
+
)
|
49
57
|
|
50
58
|
nib, _ = optional_import("nibabel")
|
51
59
|
Image, _ = optional_import("PIL.Image")
|
52
60
|
nrrd, _ = optional_import("nrrd")
|
61
|
+
FileLock, has_filelock = optional_import("filelock", name="FileLock")
|
53
62
|
|
54
63
|
# Public API of this module; WriteFileMapping is exported from monai.transforms as well.
__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS", "WriteFileMapping"]
|
55
64
|
|
@@ -505,7 +514,7 @@ class SaveImage(Transform):
|
|
505
514
|
else:
|
506
515
|
self._data_index += 1
|
507
516
|
if self.savepath_in_metadict and meta_data is not None:
|
508
|
-
meta_data[
|
517
|
+
meta_data[MetaKeys.SAVED_TO] = filename
|
509
518
|
return img
|
510
519
|
msg = "\n".join([f"{e}" for e in err])
|
511
520
|
raise RuntimeError(
|
@@ -514,3 +523,50 @@ class SaveImage(Transform):
|
|
514
523
|
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
|
515
524
|
f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
|
516
525
|
)
|
526
|
+
|
527
|
+
|
528
|
+
class WriteFileMapping(Transform):
    """
    Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.

    Each call appends one ``{"input": <source path>, "output": <saved path>}`` record to a JSON list stored at
    ``mapping_file_path``. When the optional ``filelock`` package is installed, the read-modify-write of the file
    is guarded by a ``<mapping_file_path>.lock`` file lock so several processes can share one mapping file.
    Without ``filelock``, no locking is performed.

    Args:
        mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
    """

    def __init__(self, mapping_file_path: Path | str = "mapping.json"):
        self.mapping_file_path = Path(mapping_file_path)

    def __call__(self, img: NdarrayOrTensor):
        """
        Append the input/output path pair of ``img`` to the mapping file.

        Args:
            img: The input image with metadata. It must be a ``MetaTensor`` whose metadata contains
                ``MetaKeys.SAVED_TO`` (written by ``SaveImage`` when ``savepath_in_metadict=True``).

        Returns:
            ``img``, unchanged.

        Raises:
            TypeError: if ``img`` is not a ``MetaTensor`` (there is no metadata to read).
            KeyError: if the metadata has no ``MetaKeys.SAVED_TO`` entry.
        """
        if not isinstance(img, MetaTensor):
            # Without this check ``meta_data`` below would be unbound for plain arrays/tensors.
            raise TypeError(
                f"WriteFileMapping expects a MetaTensor carrying metadata, got {type(img).__name__}."
            )
        meta_data = img.meta

        if MetaKeys.SAVED_TO not in meta_data:
            raise KeyError(
                "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
            )

        input_path = meta_data[Key.FILENAME_OR_OBJ]
        output_path = meta_data[MetaKeys.SAVED_TO]
        log_data = {"input": input_path, "output": output_path}

        if has_filelock:
            with FileLock(str(self.mapping_file_path) + ".lock"):
                self._write_to_file(log_data)
        else:
            self._write_to_file(log_data)
        return img

    def _write_to_file(self, log_data):
        """Append ``log_data`` to the JSON list stored in ``mapping_file_path``, creating the file if needed."""
        try:
            with self.mapping_file_path.open("r", encoding="utf-8") as f:
                existing_log_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # A missing or unparsable file starts a fresh list (unparsable content gets overwritten).
            existing_log_data = []
        existing_log_data.append(log_data)
        with self.mapping_file_path.open("w", encoding="utf-8") as f:
            json.dump(existing_log_data, f, indent=4)
|
@@ -17,16 +17,17 @@ Class names are ended with 'd' to denote dictionary-based transforms.
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
from collections.abc import Hashable, Mapping
|
20
21
|
from pathlib import Path
|
21
22
|
from typing import Callable
|
22
23
|
|
23
24
|
import numpy as np
|
24
25
|
|
25
26
|
import monai
|
26
|
-
from monai.config import DtypeLike, KeysCollection
|
27
|
+
from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
|
27
28
|
from monai.data import image_writer
|
28
29
|
from monai.data.image_reader import ImageReader
|
29
|
-
from monai.transforms.io.array import LoadImage, SaveImage
|
30
|
+
from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
|
30
31
|
from monai.transforms.transform import MapTransform, Transform
|
31
32
|
from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
|
32
33
|
from monai.utils.enums import PostFix
|
@@ -320,5 +321,31 @@ class SaveImaged(MapTransform):
|
|
320
321
|
return d
|
321
322
|
|
322
323
|
|
324
|
+
class WriteFileMappingd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.

    Each selected item is passed through a single shared ``WriteFileMapping`` instance and stored back
    under the same key in a shallow copy of the input dictionary.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        mapping_file_path: Path to the JSON file where the mappings will be saved.
            Defaults to "mapping.json".
        allow_missing_keys: don't raise exception if key is missing.
    """

    def __init__(
        self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mapping = WriteFileMapping(mapping_file_path)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        result = dict(data)
        writer = self.mapping
        for name in self.key_iterator(result):
            result[name] = writer(result[name])
        return result
|
347
|
+
|
348
|
+
|
323
349
|
# Alternative spellings of the dictionary-based IO transforms (each name is bound to the same class).
LoadImageDict = LoadImageD = LoadImaged
SaveImageDict = SaveImageD = SaveImaged
WriteFileMappingDict = WriteFileMappingD = WriteFileMappingd
|
@@ -25,6 +25,7 @@ import torch
|
|
25
25
|
|
26
26
|
from monai.config import USE_COMPILED, DtypeLike
|
27
27
|
from monai.config.type_definitions import NdarrayOrTensor
|
28
|
+
from monai.data.box_utils import BoxMode, StandardMode
|
28
29
|
from monai.data.meta_obj import get_track_meta, set_track_meta
|
29
30
|
from monai.data.meta_tensor import MetaTensor
|
30
31
|
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
|
@@ -34,6 +35,8 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCro
|
|
34
35
|
from monai.transforms.inverse import InvertibleTransform
|
35
36
|
from monai.transforms.spatial.functional import (
|
36
37
|
affine_func,
|
38
|
+
convert_box_to_points,
|
39
|
+
convert_points_to_box,
|
37
40
|
flip,
|
38
41
|
orientation,
|
39
42
|
resize,
|
@@ -3544,3 +3547,44 @@ class RandSimulateLowResolution(RandomizableTransform):
|
|
3544
3547
|
|
3545
3548
|
else:
|
3546
3549
|
return img
|
3550
|
+
|
3551
|
+
|
3552
|
+
class ConvertBoxToPoints(Transform):
    """
    Convert axis-aligned bounding boxes into their corner points, reading box coordinates according to the box mode.

    Input boxes have shape (N, C) for N boxes. In the standard mode, C is [x1, y1, x2, y2] for 2D
    or [x1, y1, z1, x2, y2, z2] for 3D.
    The output has shape (N, 4, 2) for 2D boxes or (N, 8, 3) for 3D boxes.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
        """
        Args:
            mode: the mode of the box: a string, a BoxMode instance or a BoxMode class.
                ``None`` selects StandardMode.
        """
        super().__init__()
        self.mode = mode if mode is not None else StandardMode

    def __call__(self, data: Any):
        """
        Args:
            data: the boxes to convert; they are first converted with ``convert_to_tensor``.

        Returns:
            the corner points, converted with ``convert_to_dst_type`` to match the tensor-converted input.
        """
        boxes = convert_to_tensor(data, track_meta=get_track_meta())
        corners = convert_box_to_points(boxes, mode=self.mode)
        converted, *_ = convert_to_dst_type(corners, boxes)
        return converted
|
3573
|
+
|
3574
|
+
|
3575
|
+
class ConvertPointsToBoxes(Transform):
    """
    Convert corner points back into axis-aligned bounding boxes.

    The points are the corners of each box: shape (N, 8, 3) for the 8 corners of a 3D cuboid,
    or (N, 4, 2) for the 4 corners of a 2D rectangle.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, data: Any):
        """
        Args:
            data: the corner points to convert; they are first converted with ``convert_to_tensor``.

        Returns:
            the boxes, converted with ``convert_to_dst_type`` to match the tensor-converted input.
        """
        corners = convert_to_tensor(data, track_meta=get_track_meta())
        converted, *_ = convert_to_dst_type(convert_points_to_box(corners), corners)
        return converted
|