monai-weekly 1.4.dev2435__py3-none-any.whl → 1.4.dev2436__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +44 -2
- monai/_version.py +3 -3
- monai/apps/vista3d/inferer.py +1 -1
- monai/bundle/config_parser.py +5 -3
- monai/bundle/scripts.py +2 -2
- monai/bundle/utils.py +35 -1
- monai/handlers/__init__.py +1 -0
- monai/handlers/trt_handler.py +61 -0
- monai/metrics/generalized_dice.py +77 -48
- monai/networks/__init__.py +2 -0
- monai/networks/nets/swin_unetr.py +4 -4
- monai/networks/nets/vista3d.py +10 -6
- monai/networks/trt_compiler.py +569 -0
- monai/networks/utils.py +224 -40
- monai/transforms/__init__.py +12 -0
- monai/transforms/spatial/array.py +44 -0
- monai/transforms/spatial/dictionary.py +61 -0
- monai/transforms/spatial/functional.py +70 -1
- monai/transforms/utility/array.py +153 -4
- monai/transforms/utility/dictionary.py +101 -3
- monai/transforms/utils.py +31 -4
- monai/utils/__init__.py +1 -0
- monai/utils/type_conversion.py +8 -0
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/METADATA +3 -1
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/RECORD +28 -26
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2435.dist-info → monai_weekly-1.4.dev2436.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from monai.config import DtypeLike
|
|
31
31
|
from monai.config.type_definitions import NdarrayOrTensor
|
32
32
|
from monai.data.meta_obj import get_track_meta
|
33
33
|
from monai.data.meta_tensor import MetaTensor
|
34
|
-
from monai.data.utils import is_no_channel, no_collation
|
34
|
+
from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps
|
35
35
|
from monai.networks.layers.simplelayers import (
|
36
36
|
ApplyFilter,
|
37
37
|
EllipticalFilter,
|
@@ -42,16 +42,17 @@ from monai.networks.layers.simplelayers import (
|
|
42
42
|
SharpenFilter,
|
43
43
|
median_filter,
|
44
44
|
)
|
45
|
-
from monai.transforms.inverse import InvertibleTransform
|
45
|
+
from monai.transforms.inverse import InvertibleTransform, TraceableTransform
|
46
46
|
from monai.transforms.traits import MultiSampleTrait
|
47
47
|
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
|
48
48
|
from monai.transforms.utils import (
|
49
|
+
apply_affine_to_points,
|
49
50
|
extreme_points_to_image,
|
50
51
|
get_extreme_points,
|
51
52
|
map_binary_to_indices,
|
52
53
|
map_classes_to_indices,
|
53
54
|
)
|
54
|
-
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
|
55
|
+
from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices
|
55
56
|
from monai.utils import (
|
56
57
|
MetaKeys,
|
57
58
|
TraceKeys,
|
@@ -66,7 +67,7 @@ from monai.utils import (
|
|
66
67
|
)
|
67
68
|
from monai.utils.enums import TransformBackends
|
68
69
|
from monai.utils.misc import is_module_ver_at_least
|
69
|
-
from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype
|
70
|
+
from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
|
70
71
|
|
71
72
|
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
|
72
73
|
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
|
@@ -106,6 +107,7 @@ __all__ = [
|
|
106
107
|
"ToCupy",
|
107
108
|
"ImageFilter",
|
108
109
|
"RandImageFilter",
|
110
|
+
"ApplyTransformToPoints",
|
109
111
|
]
|
110
112
|
|
111
113
|
|
@@ -654,6 +656,7 @@ class DataStats(Transform):
|
|
654
656
|
data_shape: bool = True,
|
655
657
|
value_range: bool = True,
|
656
658
|
data_value: bool = False,
|
659
|
+
meta_info: bool = False,
|
657
660
|
additional_info: Callable | None = None,
|
658
661
|
name: str = "DataStats",
|
659
662
|
) -> None:
|
@@ -665,6 +668,7 @@ class DataStats(Transform):
|
|
665
668
|
value_range: whether to show the value range of input data.
|
666
669
|
data_value: whether to show the raw value of input data.
|
667
670
|
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
|
671
|
+
meta_info: whether to show the metadata (the ``meta`` attribute) of the input when it is a MetaTensor.
|
668
672
|
additional_info: user can define callable function to extract additional info from input data.
|
669
673
|
name: identifier of `logging.logger` to use, defaulting to "DataStats".
|
670
674
|
|
@@ -679,6 +683,7 @@ class DataStats(Transform):
|
|
679
683
|
self.data_shape = data_shape
|
680
684
|
self.value_range = value_range
|
681
685
|
self.data_value = data_value
|
686
|
+
self.meta_info = meta_info
|
682
687
|
if additional_info is not None and not callable(additional_info):
|
683
688
|
raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.")
|
684
689
|
self.additional_info = additional_info
|
@@ -705,6 +710,7 @@ class DataStats(Transform):
|
|
705
710
|
data_shape: bool | None = None,
|
706
711
|
value_range: bool | None = None,
|
707
712
|
data_value: bool | None = None,
|
713
|
+
meta_info: bool | None = None,
|
708
714
|
additional_info: Callable | None = None,
|
709
715
|
) -> NdarrayOrTensor:
|
710
716
|
"""
|
@@ -725,6 +731,9 @@ class DataStats(Transform):
|
|
725
731
|
lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})")
|
726
732
|
if self.data_value if data_value is None else data_value:
|
727
733
|
lines.append(f"Value: {img}")
|
734
|
+
if self.meta_info if meta_info is None else meta_info:
|
735
|
+
metadata = getattr(img, "meta", "(input is not a MetaTensor)")
|
736
|
+
lines.append(f"Meta info: {repr(metadata)}")
|
728
737
|
additional_info = self.additional_info if additional_info is None else additional_info
|
729
738
|
if additional_info is not None:
|
730
739
|
lines.append(f"Additional info: {additional_info(img)}")
|
@@ -1715,3 +1724,143 @@ class RandImageFilter(RandomizableTransform):
|
|
1715
1724
|
if self._do_transform:
|
1716
1725
|
img = self.filter(img)
|
1717
1726
|
return img
|
1727
|
+
|
1728
|
+
|
1729
|
+
class ApplyTransformToPoints(InvertibleTransform, Transform):
    """
    Transform points between image coordinates and world coordinates.
    The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels
    and N denotes the number of points. It will return a tensor with the same shape as the input.

    Args:
        dtype: The desired data type for the output.
        affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
            from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
            Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
            applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
            The matrix is always converted to float64 for computation, which can be computationally
            expensive when applied to a large number of points.
            If None, will try to use the affine matrix from the input data.
        invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
            Typically, the affine matrix is derived from an image and represents its location in world space,
            while the points are in world coordinates. A value of ``True`` represents transforming these
            world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation.
        affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system
            or you're using `ITKReader` with `affine_lps_to_ras=True`.
            This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
            and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
            matrix are in the same coordinate system.

    Use Cases:
        - Transforming points between world space and image space, and vice versa.
        - Automatically handling inverse transformations between image space and world space.
        - If points have an existing affine transformation, the class computes and
          applies the required delta affine transformation.

    """

    def __init__(
        self,
        dtype: DtypeLike | torch.dtype | None = None,
        affine: torch.Tensor | None = None,
        invert_affine: bool = True,
        affine_lps_to_ras: bool = False,
    ) -> None:
        self.dtype = dtype
        self.affine = affine
        self.invert_affine = invert_affine
        self.affine_lps_to_ras = affine_lps_to_ras

    def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
        """
        Compute the final affine transformation matrix to apply to the point data.

        Args:
            affine: 3x3 or 4x4 affine transformation matrix.
            applied_affine: affine matrix already associated with the point data (its ``affine`` attribute),
                or None when the points carry none. It is only composed in when ``invert_affine`` is True.

        Returns:
            Final affine transformation matrix.
        """

        affine = convert_data_type(affine, dtype=torch.float64)[0]

        if self.affine_lps_to_ras:
            affine = orientation_ras_lps(affine)

        if self.invert_affine:
            affine = linalg_inv(affine)
            if applied_affine is not None:
                # compose with the affine the points already carry, so only the remaining delta is applied
                affine = affine @ applied_affine

        return affine

    def transform_coordinates(
        self, data: torch.Tensor, affine: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, dict]:
        """
        Transform coordinates using an affine transformation matrix.

        Args:
            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
                where C represents the number of channels and N denotes the number of points.
            affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation,
                which can be computationally expensive when applied to a large number of points.

        Returns:
            A pair ``(out, meta_info)``: the transformed coordinates, and the object returned by
            ``TraceableTransform.track_transform_meta`` carrying the traced transform information.

        Raises:
            ValueError: When ``affine`` is None and ``invert_affine`` is True, or when no affine is given
                and ``data`` has no ``affine`` attribute to fall back on.
        """
        data = convert_to_tensor(data, track_meta=get_track_meta())
        if affine is None and self.invert_affine:
            raise ValueError("affine must be provided when invert_affine is True.")
        # applied_affine is the affine transformation matrix that has already been applied to the point data
        applied_affine: torch.Tensor | None = getattr(data, "affine", None)
        affine = applied_affine if affine is None else affine
        if affine is None:
            raise ValueError("affine must be provided if data does not have an affine matrix.")

        final_affine = self._compute_final_affine(affine, applied_affine)
        out = apply_affine_to_points(data, final_affine, dtype=self.dtype)

        extra_info = {
            "invert_affine": self.invert_affine,
            # A ``None`` dtype has no dtype-string form (stringifying it yields "e", NumPy's float16 code),
            # so record the TraceKeys.NONE sentinel and let ``inverse`` map it back to ``None``.
            "dtype": TraceKeys.NONE if self.dtype is None else get_dtype_string(self.dtype),
            "image_affine": affine,
            "affine_lps_to_ras": self.affine_lps_to_ras,
        }

        # invert once; the LPS/RAS flip, when requested, is applied on top of the same inverse
        inverse_final_affine = linalg_inv(final_affine)
        xform = orientation_ras_lps(inverse_final_affine) if self.affine_lps_to_ras else inverse_final_affine
        meta_info = TraceableTransform.track_transform_meta(
            data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
        )

        return out, meta_info

    def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
        """
        Args:
            data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
                where C represents the number of channels and N denotes the number of points.
            affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``.

        Raises:
            ValueError: When ``data`` is not of shape (C, N, 2 or 3), or the affine is not 3x3 or 4x4.
        """
        if data.ndim != 3 or data.shape[-1] not in (2, 3):
            raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.")
        affine = self.affine if affine is None else affine
        if affine is not None and affine.shape not in ((3, 3), (4, 4)):
            raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.")

        out, meta_info = self.transform_coordinates(data, affine)

        return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out

    def inverse(self, data: torch.Tensor) -> torch.Tensor:
        """
        Undo the most recent application of this transform recorded on ``data``.

        The traced ``extra_info`` is used to rebuild a transform with ``invert_affine`` flipped, which is
        then applied (with tracing disabled) using the image affine stored at forward time.
        """
        transform = self.pop_transform(data)
        extra_info = transform[TraceKeys.EXTRA_INFO]
        traced_dtype = extra_info["dtype"]
        inverse_transform = ApplyTransformToPoints(
            # the sentinel means the forward pass ran with ``dtype=None``
            dtype=None if traced_dtype == TraceKeys.NONE else traced_dtype,
            invert_affine=not extra_info["invert_affine"],
            affine_lps_to_ras=extra_info["affine_lps_to_ras"],
        )
        with inverse_transform.trace_transform(False):
            data = inverse_transform(data, extra_info["image_affine"])

        return data
|
@@ -35,6 +35,7 @@ from monai.transforms.transform import MapTransform, Randomizable, RandomizableT
|
|
35
35
|
from monai.transforms.utility.array import (
|
36
36
|
AddCoordinateChannels,
|
37
37
|
AddExtremePointsChannel,
|
38
|
+
ApplyTransformToPoints,
|
38
39
|
AsChannelLast,
|
39
40
|
CastToType,
|
40
41
|
ClassesToIndices,
|
@@ -180,6 +181,9 @@ __all__ = [
|
|
180
181
|
"ClassesToIndicesd",
|
181
182
|
"ClassesToIndicesD",
|
182
183
|
"ClassesToIndicesDict",
|
184
|
+
"ApplyTransformToPointsd",
|
185
|
+
"ApplyTransformToPointsD",
|
186
|
+
"ApplyTransformToPointsDict",
|
183
187
|
]
|
184
188
|
|
185
189
|
DEFAULT_POST_FIX = PostFix.meta()
|
@@ -789,6 +793,7 @@ class DataStatsd(MapTransform):
|
|
789
793
|
data_shape: Sequence[bool] | bool = True,
|
790
794
|
value_range: Sequence[bool] | bool = True,
|
791
795
|
data_value: Sequence[bool] | bool = False,
|
796
|
+
meta_info: Sequence[bool] | bool = False,
|
792
797
|
additional_info: Sequence[Callable] | Callable | None = None,
|
793
798
|
name: str = "DataStats",
|
794
799
|
allow_missing_keys: bool = False,
|
@@ -808,6 +813,8 @@ class DataStatsd(MapTransform):
|
|
808
813
|
data_value: whether to show the raw value of input data.
|
809
814
|
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
|
810
815
|
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
|
816
|
+
meta_info: whether to show the metadata (the ``meta`` attribute) of the input when it is a MetaTensor.
|
817
|
+
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
|
811
818
|
additional_info: user can define callable function to extract
|
812
819
|
additional info from input data. it also can be a sequence of string, each element
|
813
820
|
corresponds to a key in ``keys``.
|
@@ -821,15 +828,34 @@ class DataStatsd(MapTransform):
|
|
821
828
|
self.data_shape = ensure_tuple_rep(data_shape, len(self.keys))
|
822
829
|
self.value_range = ensure_tuple_rep(value_range, len(self.keys))
|
823
830
|
self.data_value = ensure_tuple_rep(data_value, len(self.keys))
|
831
|
+
self.meta_info = ensure_tuple_rep(meta_info, len(self.keys))
|
824
832
|
self.additional_info = ensure_tuple_rep(additional_info, len(self.keys))
|
825
833
|
self.printer = DataStats(name=name)
|
826
834
|
|
827
835
|
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
|
828
836
|
d = dict(data)
|
829
|
-
for
|
830
|
-
|
837
|
+
for (
|
838
|
+
key,
|
839
|
+
prefix,
|
840
|
+
data_type,
|
841
|
+
data_shape,
|
842
|
+
value_range,
|
843
|
+
data_value,
|
844
|
+
meta_info,
|
845
|
+
additional_info,
|
846
|
+
) in self.key_iterator(
|
847
|
+
d,
|
848
|
+
self.prefix,
|
849
|
+
self.data_type,
|
850
|
+
self.data_shape,
|
851
|
+
self.value_range,
|
852
|
+
self.data_value,
|
853
|
+
self.meta_info,
|
854
|
+
self.additional_info,
|
831
855
|
):
|
832
|
-
d[key] = self.printer(
|
856
|
+
d[key] = self.printer(
|
857
|
+
d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info
|
858
|
+
)
|
833
859
|
return d
|
834
860
|
|
835
861
|
|
@@ -1744,6 +1770,77 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
|
|
1744
1770
|
return d
|
1745
1771
|
|
1746
1772
|
|
1773
|
+
class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`.
    Each selected item is expected to hold coordinates of shape (C, N, 2 or 3), with C the number of
    channels and N the number of points; the transformed item keeps that shape.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: monai.transforms.MapTransform
        refer_keys: key of the reference item supplying the affine. The referenced item may be the affine
            itself, or an object (such as an image) exposing an ``affine`` attribute. A sequence of keys
            may be given, one per entry of ``keys``.
        dtype: The desired data type for the output.
        affine: A 3x3 or 4x4 affine transformation matrix applied to points, typically originating from
            the image. A 3x3 matrix is enough for 2D points; a 4x4 matrix is required for 3D points and,
            when used with 2D points, the additional dimensions are handled accordingly.
            The matrix is always converted to float64 for computation, which can be computationally
            expensive when applied to a large number of points.
            If None, will try to use the affine matrix from the refer data.
        invert_affine: Whether to invert the affine transformation matrix applied to the points.
            Defaults to ``True``. Typically the affine comes from the image while the points are in world
            coordinates; use ``True`` to align the points with the image and ``False`` otherwise.
        affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate
            system or you're using `ITKReader` with `affine_lps_to_ras=True`, so that the points and the
            affine matrix are expressed in the same coordinate system (LPS: left-posterior-superior,
            RAS: right-anterior-superior).
        allow_missing_keys: Don't raise exception if key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        refer_keys: KeysCollection | None = None,
        dtype: DtypeLike | torch.dtype = torch.float64,
        affine: torch.Tensor | None = None,
        invert_affine: bool = True,
        affine_lps_to_ras: bool = False,
        allow_missing_keys: bool = False,
    ):
        MapTransform.__init__(self, keys, allow_missing_keys)
        # one reference key per transformed key
        self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))
        self.converter = ApplyTransformToPoints(
            dtype=dtype,
            affine=affine,
            invert_affine=invert_affine,
            affine_lps_to_ras=affine_lps_to_ras,
        )

    def __call__(self, data: Mapping[Hashable, torch.Tensor]):
        """Apply the point transform to every selected key and return the updated shallow copy of ``data``."""
        result = dict(data)
        for key, refer_key in self.key_iterator(result, self.refer_keys):
            points = result[key]
            if refer_key is None:
                # no reference item: the converter uses the affine given in the constructor
                matrix = None
            elif refer_key not in result:
                raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
            else:
                reference = result[refer_key]
                # prefer the reference's ``affine`` attribute; otherwise the reference is the matrix itself
                matrix = getattr(reference, "affine", reference)
            result[key] = self.converter(points, matrix)
        return result

    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
        """Invert the point transform on every selected key and return the updated shallow copy of ``data``."""
        result = dict(data)
        for key in self.key_iterator(result):
            restored = self.converter.inverse(result[key])
            result[key] = restored
        return result
|
1842
|
+
|
1843
|
+
|
1747
1844
|
RandImageFilterD = RandImageFilterDict = RandImageFilterd
|
1748
1845
|
ImageFilterD = ImageFilterDict = ImageFilterd
|
1749
1846
|
IdentityD = IdentityDict = Identityd
|
@@ -1784,3 +1881,4 @@ CuCIMD = CuCIMDict = CuCIMd
|
|
1784
1881
|
RandCuCIMD = RandCuCIMDict = RandCuCIMd
|
1785
1882
|
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
|
1786
1883
|
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
|
1884
|
+
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
|
monai/transforms/utils.py
CHANGED
@@ -27,6 +27,7 @@ from torch import Tensor
|
|
27
27
|
import monai
|
28
28
|
from monai.config import DtypeLike, IndexSelection
|
29
29
|
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
|
30
|
+
from monai.data.utils import to_affine_nd
|
30
31
|
from monai.networks.layers import GaussianFilter
|
31
32
|
from monai.networks.utils import meshgrid_ij
|
32
33
|
from monai.transforms.compose import Compose
|
@@ -35,6 +36,7 @@ from monai.transforms.utils_morphological_ops import erode
|
|
35
36
|
from monai.transforms.utils_pytorch_numpy_unification import (
|
36
37
|
any_np_pt,
|
37
38
|
ascontiguousarray,
|
39
|
+
concatenate,
|
38
40
|
cumsum,
|
39
41
|
isfinite,
|
40
42
|
nonzero,
|
@@ -1861,7 +1863,7 @@ class Fourier:
|
|
1861
1863
|
"""
|
1862
1864
|
|
1863
1865
|
@staticmethod
|
1864
|
-
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
|
1866
|
+
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:
|
1865
1867
|
"""
|
1866
1868
|
Applies fourier transform and shifts the zero-frequency component to the
|
1867
1869
|
center of the spectrum. Only the spatial dimensions get transformed.
|
@@ -1869,6 +1871,7 @@ class Fourier:
|
|
1869
1871
|
Args:
|
1870
1872
|
x: Image to transform.
|
1871
1873
|
spatial_dims: Number of spatial dimensions.
|
1874
|
+
as_contiguous: Whether to return the result as a contiguous NumPy array or PyTorch tensor.
|
1872
1875
|
|
1873
1876
|
Returns
|
1874
1877
|
k: K-space data.
|
@@ -1883,10 +1886,12 @@ class Fourier:
|
|
1883
1886
|
k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
|
1884
1887
|
else:
|
1885
1888
|
k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
|
1886
|
-
return k
|
1889
|
+
return ascontiguousarray(k) if as_contiguous else k
|
1887
1890
|
|
1888
1891
|
@staticmethod
|
1889
|
-
def inv_shift_fourier(
|
1892
|
+
def inv_shift_fourier(
|
1893
|
+
k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False
|
1894
|
+
) -> NdarrayOrTensor:
|
1890
1895
|
"""
|
1891
1896
|
Applies inverse shift and fourier transform. Only the spatial
|
1892
1897
|
dimensions are transformed.
|
@@ -1894,6 +1899,7 @@ class Fourier:
|
|
1894
1899
|
Args:
|
1895
1900
|
k: K-space data.
|
1896
1901
|
spatial_dims: Number of spatial dimensions.
|
1902
|
+
as_contiguous: Whether to return the result as a contiguous NumPy array or PyTorch tensor.
|
1897
1903
|
|
1898
1904
|
Returns:
|
1899
1905
|
x: Tensor in image space.
|
@@ -1908,7 +1914,7 @@ class Fourier:
|
|
1908
1914
|
out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
|
1909
1915
|
else:
|
1910
1916
|
out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
|
1911
|
-
return out
|
1917
|
+
return ascontiguousarray(out) if as_contiguous else out
|
1912
1918
|
|
1913
1919
|
|
1914
1920
|
def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int:
|
@@ -2555,5 +2561,26 @@ def distance_transform_edt(
|
|
2555
2561
|
return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]
|
2556
2562
|
|
2557
2563
|
|
2564
|
+
def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None):
    """
    Apply an affine transformation to a set of points.

    Args:
        data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3),
            where C represents the number of channels and N denotes the number of points.
        affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4).
        dtype: output data dtype.

    Returns:
        The transformed points with the same shape as ``data``, passed through ``convert_to_dst_type``
        with ``data`` as the destination and ``dtype`` as the requested dtype.
    """
    # the arithmetic is always carried out in float64, whatever the input dtype
    data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64)
    affine = to_affine_nd(data_.shape[-1], affine)
    # keep the matrix on the points' device and dtype so the matmul below cannot fail on a mismatch
    affine_ = torch.as_tensor(affine, dtype=data_.dtype, device=data_.device)

    # the homogeneous column must match the points' dtype and device
    # (a bare ``torch.ones`` would be float32 on the default device)
    ones = torch.ones((data_.shape[0], data_.shape[1], 1), dtype=data_.dtype, device=data_.device)
    homogeneous = torch.cat((data_, ones), dim=2)
    transformed_homogeneous = torch.matmul(homogeneous, affine_.T)
    # drop the homogeneous coordinate again
    transformed_coordinates = transformed_homogeneous[:, :, :-1]
    out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)

    return out
|
2583
|
+
|
2584
|
+
|
2558
2585
|
if __name__ == "__main__":
|
2559
2586
|
print_transform_backends()
|
monai/utils/__init__.py
CHANGED
monai/utils/type_conversion.py
CHANGED
@@ -33,6 +33,7 @@ __all__ = [
|
|
33
33
|
"get_equivalent_dtype",
|
34
34
|
"convert_data_type",
|
35
35
|
"get_dtype",
|
36
|
+
"get_dtype_string",
|
36
37
|
"convert_to_cupy",
|
37
38
|
"convert_to_numpy",
|
38
39
|
"convert_to_tensor",
|
@@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype:
|
|
102
103
|
return type(data)
|
103
104
|
|
104
105
|
|
106
|
+
def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str:
    """
    Get a string representation of the dtype.

    Args:
        dtype: a ``torch.dtype``, or anything ``numpy.dtype`` accepts (NumPy scalar type, ``numpy.dtype``
            instance, Python type, dtype name), or None.

    Returns:
        The bare dtype name, e.g. ``"float32"`` for both ``torch.float32`` and ``np.float32``.
        ``None`` yields ``"None"``. Objects NumPy cannot interpret fall back to ``str(dtype)``.
    """
    import numpy as np  # local import: only part of this module is visible here

    if dtype is None:
        # np.dtype(None) would silently resolve to float64, so handle None explicitly
        return "None"
    if isinstance(dtype, torch.dtype):
        # str(torch.float32) == "torch.float32"; keep only the part after the module prefix
        return str(dtype).split(".")[-1]
    try:
        return str(np.dtype(dtype).name)
    except TypeError:
        # not a NumPy-compatible dtype specification; keep the function non-raising
        return str(dtype)
|
111
|
+
|
112
|
+
|
105
113
|
def convert_to_tensor(
|
106
114
|
data: Any,
|
107
115
|
dtype: DtypeLike | torch.dtype = None,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: monai-weekly
|
3
|
-
Version: 1.4.
|
3
|
+
Version: 1.4.dev2436
|
4
4
|
Summary: AI Toolkit for Healthcare Imaging
|
5
5
|
Home-page: https://monai.io/
|
6
6
|
Author: MONAI Consortium
|
@@ -120,6 +120,8 @@ Provides-Extra: pandas
|
|
120
120
|
Requires-Dist: pandas; extra == "pandas"
|
121
121
|
Provides-Extra: pillow
|
122
122
|
Requires-Dist: pillow!=8.3.0; extra == "pillow"
|
123
|
+
Provides-Extra: polygraphy
|
124
|
+
Requires-Dist: polygraphy; extra == "polygraphy"
|
123
125
|
Provides-Extra: psutil
|
124
126
|
Requires-Dist: psutil; extra == "psutil"
|
125
127
|
Provides-Extra: pyamg
|