monai-weekly 1.4.dev2434__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +44 -2
- monai/_version.py +3 -3
- monai/apps/vista3d/inferer.py +177 -0
- monai/apps/vista3d/sampler.py +179 -0
- monai/apps/vista3d/transforms.py +224 -0
- monai/bundle/config_parser.py +5 -3
- monai/bundle/scripts.py +2 -2
- monai/bundle/utils.py +35 -1
- monai/handlers/__init__.py +1 -0
- monai/handlers/trt_handler.py +61 -0
- monai/inferers/utils.py +1 -0
- monai/metrics/generalized_dice.py +77 -48
- monai/networks/__init__.py +2 -0
- monai/networks/layers/filtering.py +6 -2
- monai/networks/nets/swin_unetr.py +4 -4
- monai/networks/nets/vista3d.py +53 -11
- monai/networks/trt_compiler.py +569 -0
- monai/networks/utils.py +225 -41
- monai/transforms/__init__.py +24 -2
- monai/transforms/io/array.py +58 -2
- monai/transforms/io/dictionary.py +29 -2
- monai/transforms/spatial/array.py +44 -0
- monai/transforms/spatial/dictionary.py +61 -0
- monai/transforms/spatial/functional.py +70 -1
- monai/transforms/utility/array.py +153 -4
- monai/transforms/utility/dictionary.py +105 -3
- monai/transforms/utils.py +83 -10
- monai/utils/__init__.py +1 -0
- monai/utils/enums.py +1 -0
- monai/utils/type_conversion.py +8 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/METADATA +4 -1
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/RECORD +36 -31
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/WHEEL +1 -1
- /monai/apps/{generation/maisi/utils → vista3d}/__init__.py +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2434.dist-info → monai_weekly-1.4.dev2436.dist-info}/top_level.txt +0 -0
monai/transforms/spatial/dictionary.py

@@ -26,6 +26,7 @@ import torch
 
 from monai.config import DtypeLike, KeysCollection, SequenceStr
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.networks.layers.simplelayers import GaussianFilter
@@ -33,6 +34,8 @@ from monai.transforms.croppad.array import CenterSpatialCrop
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.spatial.array import (
     Affine,
+    ConvertBoxToPoints,
+    ConvertPointsToBoxes,
     Flip,
     GridDistortion,
     GridPatch,
@@ -2585,6 +2588,7 @@ class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
         self, seed: int | None = None, state: np.random.RandomState | None = None
     ) -> RandSimulateLowResolutiond:
         super().set_random_state(seed, state)
+        self.sim_lowres_tfm.set_random_state(seed, state)
         return self
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
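This one-line fix makes seeding reproducible: previously, calling `set_random_state` on the wrapper did not propagate to the internal `sim_lowres_tfm`. A minimal sketch of the now-deterministic behaviour (key name illustrative):

import torch
from monai.transforms import RandSimulateLowResolutiond

sample = {"img": torch.rand(1, 16, 16, 16)}
tfm = RandSimulateLowResolutiond(keys="img", prob=1.0)
tfm.set_random_state(seed=42)
out1 = tfm(dict(sample))["img"]
tfm.set_random_state(seed=42)
out2 = tfm(dict(sample))["img"]
assert torch.allclose(out1, out2)  # identical seeds now yield identical outputs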
@@ -2611,6 +2615,61 @@ class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
         return d
 
 
+class ConvertBoxToPointsd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.
+    """
+
+    backend = ConvertBoxToPoints.backend
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        point_key="points",
+        mode: str | BoxMode | type[BoxMode] | None = StandardMode,
+        allow_missing_keys: bool = False,
+    ):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            point_key: key to store the point data.
+            mode: the mode of the input boxes. Defaults to StandardMode.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.point_key = point_key
+        self.converter = ConvertBoxToPoints(mode=mode)
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            data[self.point_key] = self.converter(d[key])
+        return data
+
+
+class ConvertPointsToBoxesd(MapTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.
+    """
+
+    def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            box_key: key to store the box data.
+            allow_missing_keys: don't raise exception if key is missing.
+        """
+        super().__init__(keys, allow_missing_keys)
+        self.box_key = box_key
+        self.converter = ConvertPointsToBoxes()
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            data[self.box_key] = self.converter(d[key])
+        return data
+
+
 SpatialResampleD = SpatialResampleDict = SpatialResampled
 ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
 SpacingD = SpacingDict = Spacingd
@@ -2635,3 +2694,5 @@ GridSplitD = GridSplitDict = GridSplitd
 GridPatchD = GridPatchDict = GridPatchd
 RandGridPatchD = RandGridPatchDict = RandGridPatchd
 RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
+ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
+ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
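A minimal usage sketch of the two new dictionary wrappers, assuming boxes in the default `StandardMode` layout `[x1, y1, z1, x2, y2, z2]` and that both wrappers are re-exported from `monai.transforms` per the `__init__.py` changes listed above (key names illustrative):

import torch
from monai.transforms import ConvertBoxToPointsd, ConvertPointsToBoxesd

data = {"boxes": torch.tensor([[0.0, 0.0, 0.0, 2.0, 3.0, 4.0]])}  # one 3D box
data = ConvertBoxToPointsd(keys="boxes", point_key="points")(data)  # adds (1, 8, 3) corner points
data = ConvertPointsToBoxesd(keys="points", box_key="box")(data)    # recovers the (1, 6) box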
monai/transforms/spatial/functional.py

@@ -24,6 +24,7 @@ import torch
 import monai
 from monai.config import USE_COMPILED
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import get_boxmode
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
@@ -32,7 +33,7 @@ from monai.transforms.croppad.array import ResizeWithPadOrCrop
 from monai.transforms.intensity.array import GaussianSmooth
 from monai.transforms.inverse import TraceableTransform
 from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
-from monai.transforms.utils_pytorch_numpy_unification import allclose
+from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
 from monai.utils import (
     LazyAttr,
     TraceKeys,
@@ -610,3 +611,71 @@ def affine_func(
     out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
     out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
     return out if image_only else (out, affine)
+
+
+def convert_box_to_points(bbox, mode):
+    """
+    Converts an axis-aligned bounding box to points.
+
+    Args:
+        mode: The mode specifying how to interpret the bounding box.
+        bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2]
+            for 3D for each box. Return shape will be (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+
+    Returns:
+        sequence of points representing the corners of the bounding box.
+    """
+
+    mode = get_boxmode(mode)
+
+    points_list = []
+    for _num in range(bbox.shape[0]):
+        corners = mode.boxes_to_corners(bbox[_num : _num + 1])
+        if len(corners) == 4:
+            points_list.append(
+                concatenate(
+                    [
+                        concatenate([corners[0], corners[1]], axis=1),
+                        concatenate([corners[2], corners[1]], axis=1),
+                        concatenate([corners[2], corners[3]], axis=1),
+                        concatenate([corners[0], corners[3]], axis=1),
+                    ],
+                    axis=0,
+                )
+            )
+        else:
+            points_list.append(
+                concatenate(
+                    [
+                        concatenate([corners[0], corners[1], corners[2]], axis=1),
+                        concatenate([corners[3], corners[1], corners[2]], axis=1),
+                        concatenate([corners[3], corners[4], corners[2]], axis=1),
+                        concatenate([corners[0], corners[4], corners[2]], axis=1),
+                        concatenate([corners[0], corners[1], corners[5]], axis=1),
+                        concatenate([corners[3], corners[1], corners[5]], axis=1),
+                        concatenate([corners[3], corners[4], corners[5]], axis=1),
+                        concatenate([corners[0], corners[4], corners[5]], axis=1),
+                    ],
+                    axis=0,
+                )
+            )
+
+    return stack(points_list, dim=0)
+
+
+def convert_points_to_box(points):
+    """
+    Converts points to an axis-aligned bounding box.
+
+    Args:
+        points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of
+            a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
+    """
+    from monai.transforms.utils_pytorch_numpy_unification import max, min
+
+    mins = min(points, dim=1)
+    maxs = max(points, dim=1)
+    # Concatenate the min and max values to get the bounding boxes
+    bboxes = concatenate([mins, maxs], axis=1)
+
+    return bboxes
monai/transforms/utility/array.py

@@ -31,7 +31,7 @@ from monai.config import DtypeLike
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.meta_obj import get_track_meta
 from monai.data.meta_tensor import MetaTensor
-from monai.data.utils import is_no_channel, no_collation
+from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps
 from monai.networks.layers.simplelayers import (
     ApplyFilter,
     EllipticalFilter,
@@ -42,16 +42,17 @@ from monai.networks.layers.simplelayers import (
     SharpenFilter,
     median_filter,
 )
-from monai.transforms.inverse import InvertibleTransform
+from monai.transforms.inverse import InvertibleTransform, TraceableTransform
 from monai.transforms.traits import MultiSampleTrait
 from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
 from monai.transforms.utils import (
+    apply_affine_to_points,
     extreme_points_to_image,
     get_extreme_points,
     map_binary_to_indices,
     map_classes_to_indices,
 )
-from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
+from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices
 from monai.utils import (
     MetaKeys,
     TraceKeys,
@@ -66,7 +67,7 @@ from monai.utils import (
 )
 from monai.utils.enums import TransformBackends
 from monai.utils.misc import is_module_ver_at_least
-from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype
+from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
 
 PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
 pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
@@ -106,6 +107,7 @@ __all__ = [
     "ToCupy",
     "ImageFilter",
     "RandImageFilter",
+    "ApplyTransformToPoints",
 ]
 
 
@@ -654,6 +656,7 @@ class DataStats(Transform):
         data_shape: bool = True,
         value_range: bool = True,
         data_value: bool = False,
+        meta_info: bool = False,
         additional_info: Callable | None = None,
         name: str = "DataStats",
     ) -> None:
@@ -665,6 +668,7 @@ class DataStats(Transform):
             value_range: whether to show the value range of input data.
             data_value: whether to show the raw value of input data.
                 a typical example is to print some properties of Nifti image: affine, pixdim, etc.
+            meta_info: whether to show the data of MetaTensor.
             additional_info: user can define callable function to extract additional info from input data.
             name: identifier of `logging.logger` to use, defaulting to "DataStats".
 
@@ -679,6 +683,7 @@ class DataStats(Transform):
         self.data_shape = data_shape
         self.value_range = value_range
         self.data_value = data_value
+        self.meta_info = meta_info
         if additional_info is not None and not callable(additional_info):
             raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.")
         self.additional_info = additional_info
@@ -705,6 +710,7 @@ class DataStats(Transform):
         data_shape: bool | None = None,
         value_range: bool | None = None,
         data_value: bool | None = None,
+        meta_info: bool | None = None,
         additional_info: Callable | None = None,
     ) -> NdarrayOrTensor:
         """
@@ -725,6 +731,9 @@ class DataStats(Transform):
             lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})")
         if self.data_value if data_value is None else data_value:
             lines.append(f"Value: {img}")
+        if self.meta_info if meta_info is None else meta_info:
+            metadata = getattr(img, "meta", "(input is not a MetaTensor)")
+            lines.append(f"Meta info: {repr(metadata)}")
         additional_info = self.additional_info if additional_info is None else additional_info
         if additional_info is not None:
             lines.append(f"Additional info: {additional_info(img)}")
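These hunks add a `meta_info` flag to `DataStats` that logs the `meta` dict of a `MetaTensor` input; a brief sketch:

import torch
from monai.data import MetaTensor
from monai.transforms import DataStats

img = MetaTensor(torch.rand(1, 8, 8), meta={"pixdim": (1.0, 1.0)})
stats = DataStats(data_shape=True, value_range=True, meta_info=True)
_ = stats(img)  # the log now includes a "Meta info: ..." line with the meta dict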
@@ -1715,3 +1724,143 @@ class RandImageFilter(RandomizableTransform):
         if self._do_transform:
             img = self.filter(img)
         return img
+
+
+class ApplyTransformToPoints(InvertibleTransform, Transform):
+    """
+    Transform points between image coordinates and world coordinates.
+    The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels
+    and N denotes the number of points. It will return a tensor with the same shape as the input.
+
+    Args:
+        dtype: The desired data type for the output.
+        affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
+            from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
+            Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
+            applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
+            The matrix is always converted to float64 for computation, which can be computationally
+            expensive when applied to a large number of points.
+            If None, will try to use the affine matrix from the input data.
+        invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
+            Typically, the affine matrix is derived from an image and represents its location in world space,
+            while the points are in world coordinates. A value of ``True`` represents transforming these
+            world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation.
+        affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system
+            or you're using `ITKReader` with `affine_lps_to_ras=True`.
+            This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
+            and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
+            matrix are in the same coordinate system.
+
+    Use Cases:
+        - Transforming points between world space and image space, and vice versa.
+        - Automatically handling inverse transformations between image space and world space.
+        - If points have an existing affine transformation, the class computes and
+          applies the required delta affine transformation.
+
+    """
+
+    def __init__(
+        self,
+        dtype: DtypeLike | torch.dtype | None = None,
+        affine: torch.Tensor | None = None,
+        invert_affine: bool = True,
+        affine_lps_to_ras: bool = False,
+    ) -> None:
+        self.dtype = dtype
+        self.affine = affine
+        self.invert_affine = invert_affine
+        self.affine_lps_to_ras = affine_lps_to_ras
+
+    def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
+        """
+        Compute the final affine transformation matrix to apply to the point data.
+
+        Args:
+            data: Input coordinates assumed to be in the shape (C, N, 2 or 3).
+            affine: 3x3 or 4x4 affine transformation matrix.
+
+        Returns:
+            Final affine transformation matrix.
+        """
+
+        affine = convert_data_type(affine, dtype=torch.float64)[0]
+
+        if self.affine_lps_to_ras:
+            affine = orientation_ras_lps(affine)
+
+        if self.invert_affine:
+            affine = linalg_inv(affine)
+        if applied_affine is not None:
+            affine = affine @ applied_affine
+
+        return affine
+
+    def transform_coordinates(
+        self, data: torch.Tensor, affine: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, dict]:
+        """
+        Transform coordinates using an affine transformation matrix.
+
+        Args:
+            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+                where C represents the number of channels and N denotes the number of points.
+            affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation,
+                which can be computationally expensive when applied to a large number of points.
+
+        Returns:
+            Transformed coordinates.
+        """
+        data = convert_to_tensor(data, track_meta=get_track_meta())
+        if affine is None and self.invert_affine:
+            raise ValueError("affine must be provided when invert_affine is True.")
+        # applied_affine is the affine transformation matrix that has already been applied to the point data
+        applied_affine: torch.Tensor | None = getattr(data, "affine", None)
+        affine = applied_affine if affine is None else affine
+        if affine is None:
+            raise ValueError("affine must be provided if data does not have an affine matrix.")
+
+        final_affine = self._compute_final_affine(affine, applied_affine)
+        out = apply_affine_to_points(data, final_affine, dtype=self.dtype)
+
+        extra_info = {
+            "invert_affine": self.invert_affine,
+            "dtype": get_dtype_string(self.dtype),
+            "image_affine": affine,
+            "affine_lps_to_ras": self.affine_lps_to_ras,
+        }
+
+        xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)
+        meta_info = TraceableTransform.track_transform_meta(
+            data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
+        )
+
+        return out, meta_info
+
+    def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
+        """
+        Args:
+            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+                where C represents the number of channels and N denotes the number of points.
+            affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``.
+        """
+        if data.ndim != 3 or data.shape[-1] not in (2, 3):
+            raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.")
+        affine = self.affine if affine is None else affine
+        if affine is not None and affine.shape not in ((3, 3), (4, 4)):
+            raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.")
+
+        out, meta_info = self.transform_coordinates(data, affine)
+
+        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
+
+    def inverse(self, data: torch.Tensor) -> torch.Tensor:
+        transform = self.pop_transform(data)
+        inverse_transform = ApplyTransformToPoints(
+            dtype=transform[TraceKeys.EXTRA_INFO]["dtype"],
+            invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"],
+            affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"],
+        )
+        with inverse_transform.trace_transform(False):
+            data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])
+
+        return data
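A minimal sketch of the new transform, mapping world-space points into image space via the inverse of an image affine (values illustrative; assumes the re-export added to `monai.transforms` in this release):

import torch
from monai.transforms import ApplyTransformToPoints

points = torch.tensor([[[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]])  # (C=1, N=2, 3)
affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))  # image affine: 2 mm isotropic spacing

tfm = ApplyTransformToPoints(invert_affine=True)  # world -> image space
img_coords = tfm(points, affine)
world_coords = tfm.inverse(img_coords)  # tracked meta allows the inverse mapping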
monai/transforms/utility/dictionary.py

@@ -35,6 +35,7 @@ from monai.transforms.transform import MapTransform, Randomizable, RandomizableT
 from monai.transforms.utility.array import (
     AddCoordinateChannels,
     AddExtremePointsChannel,
+    ApplyTransformToPoints,
     AsChannelLast,
     CastToType,
     ClassesToIndices,
@@ -180,6 +181,9 @@ __all__ = [
     "ClassesToIndicesd",
     "ClassesToIndicesD",
     "ClassesToIndicesDict",
+    "ApplyTransformToPointsd",
+    "ApplyTransformToPointsD",
+    "ApplyTransformToPointsDict",
 ]
 
 DEFAULT_POST_FIX = PostFix.meta()
@@ -789,6 +793,7 @@ class DataStatsd(MapTransform):
         data_shape: Sequence[bool] | bool = True,
         value_range: Sequence[bool] | bool = True,
         data_value: Sequence[bool] | bool = False,
+        meta_info: Sequence[bool] | bool = False,
         additional_info: Sequence[Callable] | Callable | None = None,
         name: str = "DataStats",
         allow_missing_keys: bool = False,
@@ -808,6 +813,8 @@ class DataStatsd(MapTransform):
             data_value: whether to show the raw value of input data.
                 it also can be a sequence of bool, each element corresponds to a key in ``keys``.
                 a typical example is to print some properties of Nifti image: affine, pixdim, etc.
+            meta_info: whether to show the data of MetaTensor.
+                it also can be a sequence of bool, each element corresponds to a key in ``keys``.
             additional_info: user can define callable function to extract
                 additional info from input data. it also can be a sequence of string, each element
                 corresponds to a key in ``keys``.
@@ -821,15 +828,34 @@ class DataStatsd(MapTransform):
         self.data_shape = ensure_tuple_rep(data_shape, len(self.keys))
         self.value_range = ensure_tuple_rep(value_range, len(self.keys))
         self.data_value = ensure_tuple_rep(data_value, len(self.keys))
+        self.meta_info = ensure_tuple_rep(meta_info, len(self.keys))
         self.additional_info = ensure_tuple_rep(additional_info, len(self.keys))
         self.printer = DataStats(name=name)
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
-        for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator(
-            d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info
+        for (
+            key,
+            prefix,
+            data_type,
+            data_shape,
+            value_range,
+            data_value,
+            meta_info,
+            additional_info,
+        ) in self.key_iterator(
+            d,
+            self.prefix,
+            self.data_type,
+            self.data_shape,
+            self.value_range,
+            self.data_value,
+            self.meta_info,
+            self.additional_info,
         ):
-            d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info)
+            d[key] = self.printer(
+                d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info
+            )
         return d
 
 
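The dictionary wrapper threads the same `meta_info` flag through per key, e.g.:

from monai.transforms import DataStatsd

# per-key control: print the meta dict for "image" but not for "label"
stats = DataStatsd(keys=["image", "label"], meta_info=(True, False))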
@@ -1714,6 +1740,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
             Probability the transform is applied to the data
         allow_missing_keys:
             Don't raise exception if key is missing.
+
+    Note:
+        - This transform does not scale output image values automatically to match the range of the input.
+          The output should be scaled by later transforms to match the input if this is desired.
     """
 
     backend = ImageFilter.backend
@@ -1740,6 +1770,77 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
         return d
 
 
+class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
+    """
+    Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`.
+    The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+    where C represents the number of channels and N denotes the number of points.
+    The output has the same shape as the input.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: monai.transforms.MapTransform
+        refer_keys: The key of the reference item used for transformation.
+            It can directly refer to an affine or an image from which the affine can be derived. It can also be a
+            sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
+        dtype: The desired data type for the output.
+        affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
+            from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
+            Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
+            applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
+            The matrix is always converted to float64 for computation, which can be computationally
+            expensive when applied to a large number of points.
+            If None, will try to use the affine matrix from the refer data.
+        invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
+            Typically, the affine matrix is derived from the image, while the points are in world coordinates.
+            If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``.
+        affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system
+            or you're using `ITKReader` with `affine_lps_to_ras=True`.
+            This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
+            and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
+            matrix are in the same coordinate system.
+        allow_missing_keys: Don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        refer_keys: KeysCollection | None = None,
+        dtype: DtypeLike | torch.dtype = torch.float64,
+        affine: torch.Tensor | None = None,
+        invert_affine: bool = True,
+        affine_lps_to_ras: bool = False,
+        allow_missing_keys: bool = False,
+    ):
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))
+        self.converter = ApplyTransformToPoints(
+            dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
+        )
+
+    def __call__(self, data: Mapping[Hashable, torch.Tensor]):
+        d = dict(data)
+        for key, refer_key in self.key_iterator(d, self.refer_keys):
+            coords = d[key]
+            affine = None  # represents using affine given in constructor
+            if refer_key is not None:
+                if refer_key in d:
+                    refer_data = d[refer_key]
+                else:
+                    raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
+
+                # use the "affine" member of refer_data, or refer_data itself, as the affine matrix
+                affine = getattr(refer_data, "affine", refer_data)
+            d[key] = self.converter(coords, affine)
+        return d
+
+    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.converter.inverse(d[key])
+        return d
+
+
 RandImageFilterD = RandImageFilterDict = RandImageFilterd
 ImageFilterD = ImageFilterDict = ImageFilterd
 IdentityD = IdentityDict = Identityd
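A dictionary-based sketch, using `refer_keys` to derive the affine from an image entry (key names illustrative; assumes the transform is re-exported from `monai.transforms` as the `__init__` changes above suggest):

import torch
from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd

affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
data = {
    "image": MetaTensor(torch.rand(1, 8, 8, 8), affine=affine),
    "points": torch.tensor([[[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]]),
}
tfm = ApplyTransformToPointsd(keys="points", refer_keys="image", invert_affine=True)
out = tfm(data)              # points mapped from world space into image space
restored = tfm.inverse(out)  # and back again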
@@ -1780,3 +1881,4 @@ CuCIMD = CuCIMDict = CuCIMd
 RandCuCIMD = RandCuCIMDict = RandCuCIMd
 AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
 FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
+ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd