medaugmentx 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- medaugmentx/__init__.py +22 -0
- medaugmentx/core/__init__.py +16 -0
- medaugmentx/core/base.py +81 -0
- medaugmentx/core/compose.py +195 -0
- medaugmentx/core/utils.py +87 -0
- medaugmentx/core/volume.py +117 -0
- medaugmentx/io/__init__.py +18 -0
- medaugmentx/io/dicom.py +195 -0
- medaugmentx/io/nifti.py +101 -0
- medaugmentx/presets.py +226 -0
- medaugmentx/serialization.py +267 -0
- medaugmentx/transforms/__init__.py +54 -0
- medaugmentx/transforms/intensity/__init__.py +18 -0
- medaugmentx/transforms/intensity/bias_field.py +107 -0
- medaugmentx/transforms/intensity/blur.py +165 -0
- medaugmentx/transforms/intensity/brightness_contrast.py +91 -0
- medaugmentx/transforms/intensity/contrast.py +79 -0
- medaugmentx/transforms/intensity/noise.py +130 -0
- medaugmentx/transforms/intensity/window_level.py +116 -0
- medaugmentx/transforms/modality/__init__.py +22 -0
- medaugmentx/transforms/modality/ct/__init__.py +4 -0
- medaugmentx/transforms/modality/ct/beam_hardening.py +108 -0
- medaugmentx/transforms/modality/mri/__init__.py +5 -0
- medaugmentx/transforms/modality/mri/ghosting.py +112 -0
- medaugmentx/transforms/modality/mri/kspace.py +105 -0
- medaugmentx/transforms/modality/tomosynthesis/__init__.py +12 -0
- medaugmentx/transforms/modality/tomosynthesis/blur.py +89 -0
- medaugmentx/transforms/modality/tomosynthesis/dropout.py +82 -0
- medaugmentx/transforms/modality/tomosynthesis/elastic.py +70 -0
- medaugmentx/transforms/modality/tomosynthesis/slab.py +89 -0
- medaugmentx/transforms/spatial/__init__.py +7 -0
- medaugmentx/transforms/spatial/affine.py +187 -0
- medaugmentx/transforms/spatial/crop.py +112 -0
- medaugmentx/transforms/spatial/elastic.py +133 -0
- medaugmentx/transforms/spatial/flip.py +75 -0
- medaugmentx-0.2.0.dist-info/METADATA +330 -0
- medaugmentx-0.2.0.dist-info/RECORD +40 -0
- medaugmentx-0.2.0.dist-info/WHEEL +5 -0
- medaugmentx-0.2.0.dist-info/licenses/LICENSE +21 -0
- medaugmentx-0.2.0.dist-info/top_level.txt +1 -0
medaugmentx/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""MedAugment — clinically-aware medical image augmentation.
|
|
2
|
+
|
|
3
|
+
Public surface for Phase 2.
|
|
4
|
+
"""
|
|
5
|
+
from medaugmentx.core import (
|
|
6
|
+
Compose,
|
|
7
|
+
MedVolume,
|
|
8
|
+
OneOf,
|
|
9
|
+
SomeOf,
|
|
10
|
+
Transform,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
# Package version string, kept in sync with the wheel metadata (0.2.0).
__version__ = "0.2.0"

# Public API: controls what ``from medaugmentx import *`` exposes.
__all__ = [
    "__version__",
    # data container
    "MedVolume",
    # transform base class and pipeline containers
    "Transform", "Compose", "OneOf", "SomeOf",
]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Core data model and pipeline primitives."""
|
|
2
|
+
from medaugmentx.core.base import Transform
|
|
3
|
+
from medaugmentx.core.compose import Compose, OneOf, SomeOf
|
|
4
|
+
from medaugmentx.core.utils import as_float32, derive_rng, resolve_rng
|
|
5
|
+
from medaugmentx.core.volume import MedVolume
|
|
6
|
+
|
|
7
|
+
# Names re-exported by ``medaugmentx.core`` (controls ``from medaugmentx.core import *``).
__all__ = [
    # data model
    "MedVolume", "Transform",
    # pipeline containers
    "Compose", "OneOf", "SomeOf",
    # seeding / dtype helpers
    "as_float32", "derive_rng", "resolve_rng",
]
|
medaugmentx/core/base.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Abstract base class that every MedAugment transform inherits from."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from medaugmentx.core.utils import SeedLike, resolve_rng
|
|
10
|
+
from medaugmentx.core.volume import MedVolume
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Transform(ABC):
    """Base class for all augmentations.

    Subclasses must override :meth:`apply` and accept ``p`` and ``seed``
    through ``super().__init__``. The base class handles probabilistic
    application, seeding, and the ``repr`` / dictionary views of parameters.

    Probabilistic application is gated on ``self.rng`` so that two transforms
    in the same :class:`Compose` with the same seed do not share a random
    stream — see :func:`medaugmentx.core.utils.derive_rng`.

    Example::

        class MyShift(Transform):
            def __init__(self, max_shift=0.1, p=1.0, seed=None):
                super().__init__(p=p, seed=seed)
                self.max_shift = max_shift

            def apply(self, volume):
                delta = self.rng.uniform(-self.max_shift, self.max_shift)
                return volume.replace(image=volume.image + delta)
    """

    def __init__(self, p: float = 1.0, seed: SeedLike = None) -> None:
        """Validate ``p`` and set up this transform's random generator.

        Args:
            p: Probability in ``[0, 1]`` that :meth:`__call__` runs :meth:`apply`.
            seed: ``None``, an integer seed, or an existing
                ``numpy.random.Generator``; handed to ``resolve_rng``.

        Raises:
            ValueError: If ``p`` is outside ``[0, 1]`` (NaN is rejected too).
        """
        if not 0.0 <= float(p) <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")
        self.p: float = float(p)
        # Record the seed for serialisation. NumPy integer scalars (e.g. np.int64)
        # are valid seeds as well, so normalise them to a plain int rather than
        # silently dropping them; a Generator cannot round-trip and is stored as None.
        self._seed: int | None = int(seed) if isinstance(seed, (int, np.integer)) else None
        self.rng: np.random.Generator = resolve_rng(seed)

    def __call__(self, volume: MedVolume) -> MedVolume:
        """Apply the transform with probability ``p``.

        When ``p == 1.0`` no random number is drawn for the gate, so the
        generator is only consumed by :meth:`apply`.

        Raises:
            TypeError: If ``volume`` is not a :class:`MedVolume`.
        """
        if not isinstance(volume, MedVolume):
            raise TypeError(f"Transform expects a MedVolume, got {type(volume).__name__}")
        if self.p < 1.0 and self.rng.random() >= self.p:
            return volume
        return self.apply(volume)

    @abstractmethod
    def apply(self, volume: MedVolume) -> MedVolume:
        """Perform the transform unconditionally — already past probability gate."""

    def set_rng(self, rng: np.random.Generator) -> None:
        """Reseed this transform with a specific :class:`numpy.random.Generator`.

        Used by :class:`Compose` to give each child its own deterministic stream.

        Raises:
            TypeError: If ``rng`` is not a ``numpy.random.Generator``.
        """
        if not isinstance(rng, np.random.Generator):
            raise TypeError("rng must be a numpy.random.Generator")
        self.rng = rng

    def _public_params(self) -> dict[str, Any]:
        """Instance attributes except ``rng`` and names starting with ``_``."""
        return {
            k: v for k, v in self.__dict__.items() if k != "rng" and not k.startswith("_")
        }

    def __repr__(self) -> str:
        attrs = ", ".join(f"{k}={v!r}" for k, v in self._public_params().items())
        return f"{self.__class__.__name__}({attrs})"

    def to_dict(self) -> dict[str, Any]:
        """Dictionary form of this transform's public parameters.

        Returns:
            ``{"name": <class name>, "params": {...}}`` where ``params`` holds
            every instance attribute except ``rng`` and underscore-prefixed
            attributes (so the privately stored seed is not included here).
            Values are returned as-is and may not be plain Python types.
        """
        return {"name": self.__class__.__name__, "params": self._public_params()}
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""Pipeline builders: Compose, OneOf, SomeOf."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from collections.abc import Iterable, Sequence
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from medaugmentx.core.base import Transform
|
|
10
|
+
from medaugmentx.core.utils import SeedLike, derive_rng
|
|
11
|
+
from medaugmentx.core.volume import MedVolume
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Compose(Transform):
    """Chain transforms so each one receives the previous one's output.

    Every child is re-seeded from this container's generator through
    ``derive_rng``, so ``Compose([...], seed=42)`` replays identical random
    choices on every run and every machine (for a fixed NumPy version).
    """

    def __init__(
        self,
        transforms: Iterable[Transform],
        p: float = 1.0,
        seed: SeedLike = None,
    ) -> None:
        super().__init__(p=p, seed=seed)
        self.transforms: list[Transform] = list(transforms)
        for child in self.transforms:
            if isinstance(child, Transform):
                continue
            raise TypeError(
                f"Compose expected Transform instances, got {type(child).__name__}"
            )
        self._reseed_children()

    def _reseed_children(self) -> None:
        """Hand every child its own generator derived from ``self.rng``."""
        count = len(self.transforms)
        if count == 0:
            return
        streams = derive_rng(self.rng, count)
        for index, child in enumerate(self.transforms):
            child.set_rng(streams[index])

    def set_rng(self, rng: np.random.Generator) -> None:
        """Swap in a new generator, then re-derive every child's stream from it."""
        super().set_rng(rng)
        self._reseed_children()

    def apply(self, volume: MedVolume) -> MedVolume:
        """Pass ``volume`` through each child's ``__call__`` in list order.

        Children are invoked through ``__call__``, so their own ``p`` still gates them.
        """
        current = volume
        for child in self.transforms:
            current = child(current)
        return current

    def __len__(self) -> int:
        return len(self.transforms)

    def __iter__(self):
        return iter(self.transforms)

    def to_dict(self) -> dict[str, Any]:
        """Nested dictionary form: children, ``p`` and the stored integer seed."""
        params: dict[str, Any] = {
            "transforms": [child.to_dict() for child in self.transforms],
            "p": self.p,
            "seed": self._seed,
        }
        return {"name": self.__class__.__name__, "params": params}
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class OneOf(Transform):
    """Pick exactly one child at random and apply it.

    The container's ``p`` controls whether *any* child runs at all. The child
    is drawn with probabilities given by ``weights`` (normalised to sum to 1);
    when ``weights`` is omitted the choice is uniform. The selected child's
    own ``p`` is ignored — once chosen it always runs.

    Raises:
        ValueError: If ``transforms`` is empty, or ``weights`` has the wrong
            length, contains NaN/inf or negative entries, or sums to <= 0.
        TypeError: If any child is not a :class:`Transform`.
    """

    def __init__(
        self,
        transforms: Sequence[Transform],
        weights: Sequence[float] | None = None,
        p: float = 1.0,
        seed: SeedLike = None,
    ) -> None:
        super().__init__(p=p, seed=seed)
        self.transforms: list[Transform] = list(transforms)
        if not self.transforms:
            raise ValueError("OneOf requires at least one transform")
        for t in self.transforms:
            if not isinstance(t, Transform):
                raise TypeError(f"OneOf expected Transform, got {type(t).__name__}")

        n = len(self.transforms)
        if weights is None:
            self.weights = np.full(n, 1.0 / n)
        else:
            w = np.asarray(weights, dtype=np.float64)
            if w.shape != (n,):
                raise ValueError("weights length must match number of transforms")
            # NaN slips past both comparisons below, and an inf entry makes the
            # normalised weights NaN; either way rng.choice would only fail later,
            # at call time. Reject them here instead.
            if not np.isfinite(w).all():
                raise ValueError("weights must be finite")
            if (w < 0).any() or w.sum() <= 0:
                raise ValueError("weights must be non-negative and sum to > 0")
            self.weights = w / w.sum()

        self._reseed_children()

    def _reseed_children(self) -> None:
        """Give each child its own generator derived from ``self.rng``."""
        for t, child_rng in zip(self.transforms, derive_rng(self.rng, len(self.transforms))):
            t.set_rng(child_rng)

    def set_rng(self, rng: np.random.Generator) -> None:
        """Replace this container's generator and re-derive the children's."""
        super().set_rng(rng)
        self._reseed_children()

    def apply(self, volume: MedVolume) -> MedVolume:
        """Draw one child index with probabilities ``self.weights`` and run it."""
        idx = int(self.rng.choice(len(self.transforms), p=self.weights))
        # Force the chosen child to run regardless of its own ``p``.
        return self.transforms[idx].apply(volume)

    def to_dict(self) -> dict[str, Any]:
        """Nested dictionary form including the normalised weights."""
        return {
            "name": self.__class__.__name__,
            "params": {
                "transforms": [t.to_dict() for t in self.transforms],
                "weights": self.weights.tolist(),
                "p": self.p,
                "seed": self._seed,
            },
        }
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class SomeOf(Transform):
    """Pick ``n`` children at random (without replacement) and apply them.

    The selected children run in their original list order. ``n`` may be an
    integer (Python or NumPy) or a ``(low, high)`` inclusive range; with a
    range, a count is sampled on every call. Selected children run via
    :meth:`Transform.apply`, so their own ``p`` is ignored.

    Raises:
        ValueError: If ``transforms`` is empty, ``n`` is not a two-element
            range, or the bounds do not satisfy ``0 <= low <= high <= len(transforms)``.
        TypeError: If any child is not a :class:`Transform`.
    """

    def __init__(
        self,
        transforms: Sequence[Transform],
        n: int | tuple[int, int] = 1,
        p: float = 1.0,
        seed: SeedLike = None,
    ) -> None:
        super().__init__(p=p, seed=seed)
        self.transforms: list[Transform] = list(transforms)
        if not self.transforms:
            raise ValueError("SomeOf requires at least one transform")
        for t in self.transforms:
            if not isinstance(t, Transform):
                raise TypeError(f"SomeOf expected Transform, got {type(t).__name__}")

        # NumPy integer scalars are not ``int`` instances; without this they would
        # be indexed like a range and crash.
        if isinstance(n, (int, np.integer)):
            lo = hi = int(n)
        else:
            bounds = tuple(n)
            # A range must have exactly two ends; extra entries used to be dropped silently.
            if len(bounds) != 2:
                raise ValueError(f"n={n} invalid for {len(self.transforms)} transforms")
            lo, hi = int(bounds[0]), int(bounds[1])
        if not 0 <= lo <= hi <= len(self.transforms):
            raise ValueError(f"n={n} invalid for {len(self.transforms)} transforms")
        self.n_range: tuple[int, int] = (lo, hi)

        self._reseed_children()

    def _reseed_children(self) -> None:
        """Give each child its own generator derived from ``self.rng``."""
        for t, child_rng in zip(self.transforms, derive_rng(self.rng, len(self.transforms))):
            t.set_rng(child_rng)

    def set_rng(self, rng: np.random.Generator) -> None:
        """Replace this container's generator and re-derive the children's."""
        super().set_rng(rng)
        self._reseed_children()

    def apply(self, volume: MedVolume) -> MedVolume:
        """Sample a count, choose that many distinct children, and run them in list order."""
        lo, hi = self.n_range
        n = int(self.rng.integers(lo, hi + 1))
        if n == 0:
            return volume
        idxs = self.rng.choice(len(self.transforms), size=n, replace=False)
        # Sorting keeps the children in their declared order rather than draw order.
        idxs.sort()
        out = volume
        for i in idxs:
            out = self.transforms[int(i)].apply(out)
        return out

    def to_dict(self) -> dict[str, Any]:
        """Nested dictionary form; ``n`` is an int when fixed, else ``[low, high]``."""
        lo, hi = self.n_range
        n: Any = lo if lo == hi else list(self.n_range)
        return {
            "name": self.__class__.__name__,
            "params": {
                "transforms": [t.to_dict() for t in self.transforms],
                "n": n,
                "p": self.p,
                "seed": self._seed,
            },
        }
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# Pipeline containers exported by this module.
__all__ = [
    "Compose",
    "OneOf",
    "SomeOf",
]
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Small helpers shared across the library."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
# Anything accepted as a ``seed``: an int, a ready-made Generator, or None (fresh entropy).
SeedLike = Union[int, np.random.Generator, type(None)]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def resolve_rng(seed: SeedLike) -> np.random.Generator:
    """Coerce ``seed`` into a ``numpy.random.Generator``.

    An existing ``Generator`` is passed through untouched (so callers can
    chain a shared stream); ``None`` or an integer seed is handed to
    ``numpy.random.default_rng``.
    """
    already_generator = isinstance(seed, np.random.Generator)
    return seed if already_generator else np.random.default_rng(seed)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def derive_rng(rng: np.random.Generator, n: int) -> list[np.random.Generator]:
    """Deterministically create ``n`` child generators from ``rng``.

    One uint64 seed per child is drawn from ``rng`` and fed to
    ``numpy.random.default_rng``. Containers use this so each child transform
    gets its own stream while the whole pipeline stays reproducible from a
    single top-level seed.
    """
    upper = np.iinfo(np.uint64).max
    child_seeds = rng.integers(0, upper, size=n, dtype=np.uint64, endpoint=False)
    # ``tolist`` yields plain Python ints, matching int(s) per element.
    return [np.random.default_rng(child_seed) for child_seed in child_seeds.tolist()]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def as_float32(image: np.ndarray) -> np.ndarray:
    """Return ``image`` as ``float32``, reusing the same array when it already is."""
    needs_cast = image.dtype != np.float32
    return image.astype(np.float32, copy=False) if needs_cast else image
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def normalize_axes(axes: int | tuple | list | None, ndim: int) -> tuple:
    """Normalise an ``axes`` argument to a sorted, de-duplicated tuple of non-negative ints.

    Args:
        axes: ``None`` (all axes), a single integer (Python ``int`` or NumPy
            integer scalar), or an iterable of integers. Negative axes are
            wrapped relative to ``ndim``.
        ndim: Number of dimensions of the target array.

    Returns:
        Sorted tuple of unique axis indices in ``[0, ndim)``.

    Raises:
        ValueError: If any axis is out of range for ``ndim``.
    """
    if axes is None:
        return tuple(range(ndim))
    # NumPy integer scalars (e.g. from an array of axes) are not ``int`` instances
    # and are not iterable, so treat them as a single axis as well.
    if isinstance(axes, (int, np.integer)):
        axes = (axes,)
    out: list[int] = []
    for a in axes:
        ax = int(a)
        if ax < 0:
            ax += ndim
        if not 0 <= ax < ndim:
            raise ValueError(f"axis {a} out of range for ndim={ndim}")
        out.append(ax)
    return tuple(sorted(set(out)))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def axis_label_to_index(label: str, ndim: int) -> int:
    """Translate an axis label (``"x"``, ``"y"``, ``"z"``, case-insensitive) to a NumPy axis.

    Layout convention used by the library:

    - 3D arrays are ``(D, H, W)`` = ``(z, y, x)``: ``"z"`` -> 0, ``"y"`` -> 1, ``"x"`` -> 2.
    - 2D arrays are ``(H, W)`` = ``(y, x)``: ``"y"`` -> 0, ``"x"`` -> 1; ``"z"`` is invalid.

    Raises:
        ValueError: If ``ndim`` is not 2 or 3, or the label is unknown for it.
    """
    key = label.lower()
    layouts = {
        3: {"z": 0, "y": 1, "x": 2},
        2: {"y": 0, "x": 1},
    }
    try:
        mapping = layouts[ndim]
    except KeyError:
        raise ValueError(f"Only 2D or 3D supported, got ndim={ndim}") from None
    if key in mapping:
        return mapping[key]
    raise ValueError(f"Unknown axis label {key!r} for ndim={ndim}")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def clip_intensity(image: np.ndarray, lo: float | None = None, hi: float | None = None) -> np.ndarray:
    """Clip ``image`` to ``[lo, hi]``; either bound may be ``None`` for one-sided clipping.

    The input array is never modified: when a bound is given, a new clipped
    array is returned (``np.clip`` without ``out=``). When both bounds are
    ``None``, ``image`` itself is returned unchanged (no copy).
    """
    if lo is None and hi is None:
        return image
    # Deliberately not in-place: images may be shared between MedVolume instances.
    return np.clip(image, lo, hi)
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""The MedVolume container — image + optional mask + spacing + metadata."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field, replace
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class MedVolume:
    """A single medical image (2D or 3D) with optional segmentation mask.

    All transforms in MedAugment operate on this container so that masks and
    metadata stay in lockstep with the image array.

    Attributes:
        image: 2D ``(H, W)`` or 3D ``(D, H, W)`` array. Recommended dtype is
            ``float32``; integer inputs are accepted but will be cast where
            arithmetic is required.
        mask: Optional integer label map with the same shape as ``image``.
        spacing: Voxel size in millimetres, one entry per spatial axis.
            For 3D volumes this is ``(slice_thickness, row_mm, col_mm)``.
            Any sequence (including a 1-D NumPy array) is accepted and stored
            as a tuple of floats; an empty value or ``None`` defaults to
            ``1.0`` per axis.
        metadata: Free-form dictionary. Conventional keys: ``modality``
            (``"MR" | "CT" | "DX" | "DBT"``), ``vendor``, ``patient_id``,
            ``original_dtype``.
    """

    image: np.ndarray
    mask: np.ndarray | None = None
    spacing: tuple[float, ...] = ()
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate fields and normalise ``spacing`` to a tuple of floats.

        Raises:
            TypeError: ``image`` or ``mask`` is not an ndarray, or ``metadata`` is not a dict.
            ValueError: ``image`` is not 2D/3D, ``mask`` shape differs from ``image``,
                or ``spacing`` length differs from ``image.ndim``.
        """
        if not isinstance(self.image, np.ndarray):
            raise TypeError(f"image must be a numpy.ndarray, got {type(self.image).__name__}")
        if self.image.ndim not in (2, 3):
            raise ValueError(f"image must be 2D or 3D; got shape {self.image.shape}")

        if self.mask is not None:
            if not isinstance(self.mask, np.ndarray):
                raise TypeError("mask must be a numpy.ndarray or None")
            if self.mask.shape != self.image.shape:
                raise ValueError(
                    f"mask shape {self.mask.shape} does not match image shape {self.image.shape}"
                )

        spacing = self.spacing
        # ``if spacing`` on a multi-element ndarray raises "truth value is ambiguous",
        # so convert NumPy arrays to a tuple before the emptiness test.
        if isinstance(spacing, np.ndarray):
            spacing = tuple(spacing)
        if spacing:
            if len(spacing) != self.image.ndim:
                raise ValueError(
                    f"spacing has {len(spacing)} entries but image is {self.image.ndim}D"
                )
            self.spacing = tuple(float(s) for s in spacing)
        else:
            self.spacing = tuple(1.0 for _ in range(self.image.ndim))

        if not isinstance(self.metadata, dict):
            raise TypeError("metadata must be a dict")

    @property
    def ndim(self) -> int:
        return int(self.image.ndim)

    @property
    def shape(self) -> tuple[int, ...]:
        return tuple(self.image.shape)

    @property
    def is_3d(self) -> bool:
        return self.image.ndim == 3

    @property
    def has_mask(self) -> bool:
        return self.mask is not None

    @property
    def modality(self) -> str | None:
        """``metadata["modality"]`` if present, else ``None``."""
        return self.metadata.get("modality")

    def replace(
        self,
        *,
        image: np.ndarray | None = None,
        mask: np.ndarray | None = None,
        spacing: tuple[float, ...] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> MedVolume:
        """Return a new MedVolume with selected fields swapped out.

        Only keywords given a non-``None`` value are replaced; the rest keep
        their current value. Because ``None`` means "keep", ``replace(mask=None)``
        keeps the existing mask and cannot be used to drop it — construct a new
        :class:`MedVolume` for that. ``metadata`` is shallow-copied into a new
        dict to avoid silent aliasing across volumes, and the result is
        re-validated by ``__post_init__``.
        """
        return replace(
            self,
            image=self.image if image is None else image,
            mask=self.mask if mask is None else mask,
            spacing=self.spacing if spacing is None else tuple(float(s) for s in spacing),
            metadata=dict(self.metadata if metadata is None else metadata),
        )

    def copy(self) -> MedVolume:
        """Copy the image and mask arrays; ``metadata`` is shallow-copied.

        The returned metadata is a new dict, but its values are the same
        objects as in the original.
        """
        return MedVolume(
            image=self.image.copy(),
            mask=None if self.mask is None else self.mask.copy(),
            spacing=tuple(self.spacing),
            metadata=dict(self.metadata),
        )

    def __repr__(self) -> str:
        mask_repr = "None" if self.mask is None else f"shape={self.mask.shape}, dtype={self.mask.dtype}"
        return (
            f"MedVolume(image=shape={self.image.shape}, dtype={self.image.dtype}, "
            f"mask={mask_repr}, spacing={self.spacing}, "
            f"modality={self.modality!r})"
        )
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Unified I/O for medical image formats.
|
|
2
|
+
|
|
3
|
+
Each loader returns a :class:`~medaugmentx.core.volume.MedVolume` with
|
|
4
|
+
``spacing`` populated in millimetres and ``metadata`` carrying the
|
|
5
|
+
modality and any vendor-specific information that callers may need.
|
|
6
|
+
|
|
7
|
+
Optional dependencies:
|
|
8
|
+
|
|
9
|
+
- DICOM I/O requires ``pydicom`` (``pip install medaugmentx[dicom]``).
|
|
10
|
+
- NIfTI I/O requires ``nibabel`` (``pip install medaugmentx[nifti]``).
|
|
11
|
+
|
|
12
|
+
If a backend is missing the loader raises a clear :class:`ImportError`
|
|
13
|
+
when called, not at import time.
|
|
14
|
+
"""
|
|
15
|
+
from medaugmentx.io.dicom import load_dicom_series
|
|
16
|
+
from medaugmentx.io.nifti import load_nifti, save_nifti
|
|
17
|
+
|
|
18
|
+
# Public I/O entry points (backends are imported lazily inside each loader).
__all__ = [
    "load_dicom_series",
    "load_nifti",
    "save_nifti",
]
|