medaugmentx-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. medaugmentx/__init__.py +22 -0
  2. medaugmentx/core/__init__.py +16 -0
  3. medaugmentx/core/base.py +81 -0
  4. medaugmentx/core/compose.py +195 -0
  5. medaugmentx/core/utils.py +87 -0
  6. medaugmentx/core/volume.py +117 -0
  7. medaugmentx/io/__init__.py +18 -0
  8. medaugmentx/io/dicom.py +195 -0
  9. medaugmentx/io/nifti.py +101 -0
  10. medaugmentx/presets.py +226 -0
  11. medaugmentx/serialization.py +267 -0
  12. medaugmentx/transforms/__init__.py +54 -0
  13. medaugmentx/transforms/intensity/__init__.py +18 -0
  14. medaugmentx/transforms/intensity/bias_field.py +107 -0
  15. medaugmentx/transforms/intensity/blur.py +165 -0
  16. medaugmentx/transforms/intensity/brightness_contrast.py +91 -0
  17. medaugmentx/transforms/intensity/contrast.py +79 -0
  18. medaugmentx/transforms/intensity/noise.py +130 -0
  19. medaugmentx/transforms/intensity/window_level.py +116 -0
  20. medaugmentx/transforms/modality/__init__.py +22 -0
  21. medaugmentx/transforms/modality/ct/__init__.py +4 -0
  22. medaugmentx/transforms/modality/ct/beam_hardening.py +108 -0
  23. medaugmentx/transforms/modality/mri/__init__.py +5 -0
  24. medaugmentx/transforms/modality/mri/ghosting.py +112 -0
  25. medaugmentx/transforms/modality/mri/kspace.py +105 -0
  26. medaugmentx/transforms/modality/tomosynthesis/__init__.py +12 -0
  27. medaugmentx/transforms/modality/tomosynthesis/blur.py +89 -0
  28. medaugmentx/transforms/modality/tomosynthesis/dropout.py +82 -0
  29. medaugmentx/transforms/modality/tomosynthesis/elastic.py +70 -0
  30. medaugmentx/transforms/modality/tomosynthesis/slab.py +89 -0
  31. medaugmentx/transforms/spatial/__init__.py +7 -0
  32. medaugmentx/transforms/spatial/affine.py +187 -0
  33. medaugmentx/transforms/spatial/crop.py +112 -0
  34. medaugmentx/transforms/spatial/elastic.py +133 -0
  35. medaugmentx/transforms/spatial/flip.py +75 -0
  36. medaugmentx-0.2.0.dist-info/METADATA +330 -0
  37. medaugmentx-0.2.0.dist-info/RECORD +40 -0
  38. medaugmentx-0.2.0.dist-info/WHEEL +5 -0
  39. medaugmentx-0.2.0.dist-info/licenses/LICENSE +21 -0
  40. medaugmentx-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,22 @@
1
+ """MedAugment — clinically-aware medical image augmentation.
2
+
3
+ Public surface for Phase 2.
4
+ """
5
+ from medaugmentx.core import (
6
+ Compose,
7
+ MedVolume,
8
+ OneOf,
9
+ SomeOf,
10
+ Transform,
11
+ )
12
+
13
# Package version string.
__version__ = "0.2.0"

# Public API: the version plus the core primitives re-exported from
# medaugmentx.core above.
__all__ = [
    "__version__",
    "MedVolume",
    "Transform",
    "Compose",
    "OneOf",
    "SomeOf",
]
@@ -0,0 +1,16 @@
1
+ """Core data model and pipeline primitives."""
2
+ from medaugmentx.core.base import Transform
3
+ from medaugmentx.core.compose import Compose, OneOf, SomeOf
4
+ from medaugmentx.core.utils import as_float32, derive_rng, resolve_rng
5
+ from medaugmentx.core.volume import MedVolume
6
+
7
# Names re-exported from the base, compose, utils and volume submodules.
__all__ = [
    "MedVolume",
    "Transform",
    "Compose",
    "OneOf",
    "SomeOf",
    "as_float32",
    "derive_rng",
    "resolve_rng",
]
@@ -0,0 +1,81 @@
1
+ """Abstract base class that every MedAugment transform inherits from."""
2
+ from __future__ import annotations
3
+
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+ from medaugmentx.core.utils import SeedLike, resolve_rng
10
+ from medaugmentx.core.volume import MedVolume
11
+
12
+
13
class Transform(ABC):
    """Base class for all augmentations.

    Subclasses must override :meth:`apply` and accept ``p`` and ``seed``
    through ``super().__init__``. The base class handles probabilistic
    application, seeding, and dictionary introspection (:meth:`to_dict`).

    Probabilistic application is gated on ``self.rng`` so that two transforms
    in the same :class:`Compose` with the same seed do not share a random
    stream — see :func:`medaugmentx.core.utils.derive_rng`.

    Example::

        class MyShift(Transform):
            def __init__(self, max_shift=0.1, p=1.0, seed=None):
                super().__init__(p=p, seed=seed)
                self.max_shift = max_shift

            def apply(self, volume):
                delta = self.rng.uniform(-self.max_shift, self.max_shift)
                return volume.replace(image=volume.image + delta)
    """

    def __init__(self, p: float = 1.0, seed: SeedLike = None) -> None:
        """Validate ``p`` and set up this transform's random generator.

        Args:
            p: Probability in ``[0, 1]`` that :meth:`__call__` runs :meth:`apply`.
            seed: ``None`` (fresh entropy), an integer seed (Python ``int`` or
                NumPy integer scalar), or an existing ``numpy.random.Generator``.

        Raises:
            ValueError: If ``p`` is outside ``[0, 1]`` (including NaN).
        """
        if not 0.0 <= float(p) <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")
        self.p: float = float(p)
        # Store the seed for serialisation (int or None only; a Generator can't
        # round-trip). NumPy integer scalars are equally valid seeds, so they
        # are normalised to a plain int instead of being silently dropped.
        self._seed: int | None = (
            int(seed) if isinstance(seed, (int, np.integer)) else None
        )
        self.rng: np.random.Generator = resolve_rng(seed)

    def __call__(self, volume: MedVolume) -> MedVolume:
        """Apply the transform with probability ``p``.

        When ``p == 1`` no random number is drawn, so the generator's stream
        is left untouched.

        Raises:
            TypeError: If ``volume`` is not a :class:`MedVolume`.
        """
        if not isinstance(volume, MedVolume):
            raise TypeError(f"Transform expects a MedVolume, got {type(volume).__name__}")
        if self.p < 1.0 and self.rng.random() >= self.p:
            return volume
        return self.apply(volume)

    @abstractmethod
    def apply(self, volume: MedVolume) -> MedVolume:
        """Perform the transform unconditionally — already past probability gate."""

    def set_rng(self, rng: np.random.Generator) -> None:
        """Reseed this transform with a specific :class:`numpy.random.Generator`.

        Used by :class:`Compose` to give each child its own deterministic stream.

        Raises:
            TypeError: If ``rng`` is not a ``numpy.random.Generator``.
        """
        if not isinstance(rng, np.random.Generator):
            raise TypeError("rng must be a numpy.random.Generator")
        self.rng = rng

    def __repr__(self) -> str:
        # Show public attributes only; ``rng`` and underscore-prefixed state
        # (such as ``_seed``) are omitted.
        attrs = ", ".join(
            f"{k}={v!r}"
            for k, v in self.__dict__.items()
            if k != "rng" and not k.startswith("_")
        )
        return f"{self.__class__.__name__}({attrs})"

    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary form of this transform's parameters.

        The result has the shape ``{"name": <class name>, "params": {...}}``,
        where ``params`` holds every instance attribute except ``rng`` and
        underscore-prefixed attributes. The stored seed lives in ``_seed`` and
        is therefore not part of ``params`` here.
        """
        params = {
            k: v for k, v in self.__dict__.items() if k != "rng" and not k.startswith("_")
        }
        return {"name": self.__class__.__name__, "params": params}
@@ -0,0 +1,195 @@
1
+ """Pipeline builders: Compose, OneOf, SomeOf."""
2
+ from __future__ import annotations
3
+
4
+ from collections.abc import Iterable, Sequence
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+ from medaugmentx.core.base import Transform
10
+ from medaugmentx.core.utils import SeedLike, derive_rng
11
+ from medaugmentx.core.volume import MedVolume
12
+
13
+
14
class Compose(Transform):
    """Run a sequence of transforms one after another.

    Every child is handed its own generator derived from this container's
    ``rng`` (see ``derive_rng``). As a result ``Compose([...], seed=42)``
    replays identically on every run and every machine for a given NumPy
    version. Construction overrides whatever seeds the children were built
    with.
    """

    def __init__(
        self,
        transforms: Iterable[Transform],
        p: float = 1.0,
        seed: SeedLike = None,
    ) -> None:
        super().__init__(p=p, seed=seed)
        children: list[Transform] = list(transforms)
        self.transforms: list[Transform] = children
        for child in children:
            if isinstance(child, Transform):
                continue
            raise TypeError(
                f"Compose expected Transform instances, got {type(child).__name__}"
            )
        self._reseed_children()

    def _reseed_children(self) -> None:
        # Skip the derivation entirely for an empty pipeline so no random
        # numbers are consumed from our own stream.
        count = len(self.transforms)
        if count == 0:
            return
        streams = derive_rng(self.rng, count)
        for child, stream in zip(self.transforms, streams):
            child.set_rng(stream)

    def set_rng(self, rng: np.random.Generator) -> None:
        """Replace our generator, then re-derive every child's stream from it."""
        super().set_rng(rng)
        self._reseed_children()

    def apply(self, volume: MedVolume) -> MedVolume:
        """Feed ``volume`` through each child in order (each honours its own ``p``)."""
        result = volume
        for step in self.transforms:
            result = step(result)
        return result

    def __len__(self) -> int:
        return len(self.transforms)

    def __iter__(self):
        return iter(self.transforms)

    def to_dict(self) -> dict[str, Any]:
        """Nested dictionary form: children, ``p`` and the stored seed."""
        params: dict[str, Any] = {}
        params["transforms"] = [child.to_dict() for child in self.transforms]
        params["p"] = self.p
        params["seed"] = self._seed
        return {"name": self.__class__.__name__, "params": params}
68
+
69
+
70
class OneOf(Transform):
    """Pick exactly one child at random and apply it.

    The container's ``p`` controls whether *any* child runs at all. When
    ``weights`` are provided they are normalised to sum to 1 and used as the
    selection probabilities; otherwise the choice is uniform. The chosen
    child's :meth:`apply` is called directly, so its own ``p`` is ignored.
    """

    def __init__(
        self,
        transforms: Sequence[Transform],
        weights: Sequence[float] | None = None,
        p: float = 1.0,
        seed: SeedLike = None,
    ) -> None:
        """Build the selector.

        Args:
            transforms: Non-empty sequence of candidate transforms.
            weights: Optional relative selection weights, one per transform.
                Must be finite, non-negative and have a positive sum.
            p: Probability that any child runs.
            seed: Seed for this container; children are reseeded from it.

        Raises:
            ValueError: If ``transforms`` is empty or ``weights`` is invalid.
            TypeError: If any element of ``transforms`` is not a Transform.
        """
        super().__init__(p=p, seed=seed)
        self.transforms: list[Transform] = list(transforms)
        if not self.transforms:
            raise ValueError("OneOf requires at least one transform")
        for t in self.transforms:
            if not isinstance(t, Transform):
                raise TypeError(f"OneOf expected Transform, got {type(t).__name__}")

        count = len(self.transforms)
        if weights is None:
            self.weights = np.full(count, 1.0 / count)
        else:
            w = np.asarray(weights, dtype=np.float64)
            if w.shape != (count,):
                raise ValueError("weights length must match number of transforms")
            total = w.sum()
            # NaN/inf slip past the sign/sum checks below and would only
            # surface later as an opaque error inside ``rng.choice``.
            if not (np.isfinite(w).all() and np.isfinite(total)):
                raise ValueError("weights must be finite")
            if (w < 0).any() or total <= 0:
                raise ValueError("weights must be non-negative and sum to > 0")
            self.weights = w / total

        self._reseed_children()

    def _reseed_children(self) -> None:
        # Give every child an independent stream derived from ours.
        for t, child_rng in zip(self.transforms, derive_rng(self.rng, len(self.transforms))):
            t.set_rng(child_rng)

    def set_rng(self, rng: np.random.Generator) -> None:
        """Replace our generator and re-derive the children's streams from it."""
        super().set_rng(rng)
        self._reseed_children()

    def apply(self, volume: MedVolume) -> MedVolume:
        """Choose one child according to ``weights`` and apply it unconditionally."""
        idx = int(self.rng.choice(len(self.transforms), p=self.weights))
        # Force the chosen child to run regardless of its own ``p``.
        return self.transforms[idx].apply(volume)

    def to_dict(self) -> dict[str, Any]:
        """Nested dictionary form: children, normalised weights, ``p`` and seed."""
        return {
            "name": self.__class__.__name__,
            "params": {
                "transforms": [t.to_dict() for t in self.transforms],
                "weights": self.weights.tolist(),
                "p": self.p,
                "seed": self._seed,
            },
        }
127
+
128
+
129
class SomeOf(Transform):
    """Pick ``n`` children at random (without replacement) and apply them in order.

    ``n`` may be an int or a ``(low, high)`` inclusive range — when a range,
    a value is sampled per call. Selected children keep their original
    relative order, and their own ``p`` is ignored (``apply`` is called
    directly).
    """

    def __init__(
        self,
        transforms: Sequence[Transform],
        n: int | tuple[int, int] = 1,
        p: float = 1.0,
        seed: SeedLike = None,
    ) -> None:
        """Build the selector.

        Args:
            transforms: Non-empty sequence of candidate transforms.
            n: Number of children to apply: an int (Python or NumPy integer)
                or an inclusive ``(low, high)`` pair.
            p: Probability that the container runs at all.
            seed: Seed for this container; children are reseeded from it.

        Raises:
            ValueError: If ``transforms`` is empty, ``n`` is not a pair, or
                the range does not satisfy ``0 <= low <= high <= len(transforms)``.
            TypeError: If an element of ``transforms`` is not a Transform, or
                ``n`` is neither an integer nor iterable.
        """
        super().__init__(p=p, seed=seed)
        self.transforms: list[Transform] = list(transforms)
        if not self.transforms:
            raise ValueError("SomeOf requires at least one transform")
        for t in self.transforms:
            if not isinstance(t, Transform):
                raise TypeError(f"SomeOf expected Transform, got {type(t).__name__}")

        # NumPy integer scalars (e.g. from a config array) are accepted as a
        # fixed count rather than being mistaken for a (low, high) pair.
        if isinstance(n, (int, np.integer)):
            lo = hi = int(n)
        else:
            pair = tuple(n)
            if len(pair) != 2:
                raise ValueError(f"n must be an int or a (low, high) pair, got {n!r}")
            lo, hi = int(pair[0]), int(pair[1])
        if not 0 <= lo <= hi <= len(self.transforms):
            raise ValueError(f"n={n} invalid for {len(self.transforms)} transforms")
        self.n_range: tuple[int, int] = (lo, hi)

        self._reseed_children()

    def _reseed_children(self) -> None:
        # Give every child an independent stream derived from ours.
        for t, child_rng in zip(self.transforms, derive_rng(self.rng, len(self.transforms))):
            t.set_rng(child_rng)

    def set_rng(self, rng: np.random.Generator) -> None:
        """Replace our generator and re-derive the children's streams from it."""
        super().set_rng(rng)
        self._reseed_children()

    def apply(self, volume: MedVolume) -> MedVolume:
        """Sample a count, pick that many distinct children and apply them in order."""
        lo, hi = self.n_range
        n = int(self.rng.integers(lo, hi + 1))
        if n == 0:
            return volume
        idxs = self.rng.choice(len(self.transforms), size=n, replace=False)
        # Sorting keeps the chosen children in their declared order.
        idxs.sort()
        out = volume
        for i in idxs:
            out = self.transforms[int(i)].apply(out)
        return out

    def to_dict(self) -> dict[str, Any]:
        """Nested dictionary form; ``n`` is an int when the range is degenerate."""
        lo, hi = self.n_range
        n: Any = lo if lo == hi else list(self.n_range)
        return {
            "name": self.__class__.__name__,
            "params": {
                "transforms": [t.to_dict() for t in self.transforms],
                "n": n,
                "p": self.p,
                "seed": self._seed,
            },
        }
193
+
194
+
195
# Public pipeline builders exported by this module.
__all__ = ["Compose", "OneOf", "SomeOf"]
@@ -0,0 +1,87 @@
1
+ """Small helpers shared across the library."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+
8
# Anything accepted as a seed: an int, an existing Generator, or None for fresh entropy.
SeedLike = Union[int, np.random.Generator, None]
9
+
10
+
11
def resolve_rng(seed: SeedLike) -> np.random.Generator:
    """Turn any accepted seed input into a ``numpy.random.Generator``.

    An existing ``Generator`` is passed through unchanged, which makes
    chaining possible. ``None`` yields a freshly-seeded generator, and an
    ``int`` yields a deterministic one.
    """
    return seed if isinstance(seed, np.random.Generator) else np.random.default_rng(seed)
20
+
21
+
22
def derive_rng(rng: np.random.Generator, n: int) -> list[np.random.Generator]:
    """Deterministically create ``n`` independent child generators from ``rng``.

    One 64-bit seed per child is drawn from ``rng``, and each seed initialises
    a new generator. Containers such as ``Compose`` use this so that every
    child transform owns its stream while the whole pipeline stays
    reproducible from a single top-level seed.
    """
    upper = np.iinfo(np.uint64).max
    child_seeds = rng.integers(0, upper, size=n, dtype=np.uint64, endpoint=False)
    return list(map(np.random.default_rng, map(int, child_seeds)))
31
+
32
+
33
def as_float32(image: np.ndarray) -> np.ndarray:
    """Return ``image`` as ``float32``, reusing the input when it already is."""
    already_float32 = image.dtype == np.float32
    return image if already_float32 else image.astype(np.float32, copy=False)
38
+
39
+
40
def normalize_axes(axes: int | tuple | list | None, ndim: int) -> tuple[int, ...]:
    """Normalise an ``axes`` argument to a sorted tuple of unique non-negative ints.

    Args:
        axes: ``None`` (meaning all axes), a single integer axis (Python
            ``int`` or NumPy integer scalar), or an iterable of axes.
            Negative axes are wrapped relative to ``ndim``.
        ndim: Number of dimensions of the target array.

    Returns:
        Sorted tuple of distinct axis indices in ``[0, ndim)``.

    Raises:
        ValueError: If any axis is out of range for ``ndim``.
    """
    if axes is None:
        return tuple(range(ndim))
    # NumPy integer scalars are not iterable; treat them like a single int
    # instead of failing with a TypeError in the loop below.
    if isinstance(axes, (int, np.integer)):
        axes = (axes,)
    seen: set[int] = set()
    for a in axes:
        ax = int(a)
        if ax < 0:
            ax += ndim
        if not 0 <= ax < ndim:
            raise ValueError(f"axis {a} out of range for ndim={ndim}")
        seen.add(ax)
    return tuple(sorted(seen))
58
+
59
+
60
def axis_label_to_index(label: str, ndim: int) -> int:
    """Translate an axis label (``"x"``, ``"y"``, ``"z"``) into a NumPy axis index.

    Library-wide layout convention:

    - 3D arrays are ``(D, H, W)``, i.e. ``(z, y, x)``.
    - 2D arrays are ``(H, W)``, i.e. ``(y, x)``.

    The label is matched case-insensitively. For 3D, ``"z"`` -> 0, ``"y"`` -> 1
    and ``"x"`` -> 2. For 2D, ``"y"`` -> 0 and ``"x"`` -> 1. ``"z"`` has no
    meaning for 2D arrays.

    Raises:
        ValueError: If ``ndim`` is not 2 or 3, or the label is unknown for it.
    """
    key = label.lower()
    tables = {
        3: {"z": 0, "y": 1, "x": 2},
        2: {"y": 0, "x": 1},
    }
    mapping = tables.get(ndim)
    if mapping is None:
        raise ValueError(f"Only 2D or 3D supported, got ndim={ndim}")
    try:
        return mapping[key]
    except KeyError:
        raise ValueError(f"Unknown axis label {key!r} for ndim={ndim}") from None
81
+
82
+
83
def clip_intensity(image: np.ndarray, lo: float | None = None, hi: float | None = None) -> np.ndarray:
    """Clip ``image`` to ``[lo, hi]``, leaving the input array unmodified.

    Either bound may be ``None`` to leave that side open. When both are
    ``None`` the input array itself is returned unchanged. Otherwise a new
    clipped array is always returned, even when the input is writeable.
    Callers that want in-place clipping should use ``np.clip(image, lo, hi,
    out=image)``.

    Args:
        image: Array to clip.
        lo: Lower bound, or ``None`` for no lower bound.
        hi: Upper bound, or ``None`` for no upper bound.

    Returns:
        The clipped array (a new array unless both bounds are ``None``).
    """
    if lo is None and hi is None:
        return image
    return np.clip(image, lo, hi)
@@ -0,0 +1,117 @@
1
+ """The MedVolume container — image + optional mask + spacing + metadata."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, field, replace
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+
9
+
10
@dataclass
class MedVolume:
    """A single medical image (2D or 3D) with optional segmentation mask.

    All transforms in MedAugment operate on this container so that masks and
    metadata stay in lockstep with the image array.

    Equality compares array contents (``np.array_equal``) plus spacing and
    metadata. The dataclass-generated ``__eq__`` would call ``bool()`` on an
    element-wise array comparison and raise ``ValueError``.

    Attributes:
        image: 2D ``(H, W)`` or 3D ``(D, H, W)`` array. Recommended dtype is
            ``float32``; integer inputs are accepted but will be cast where
            arithmetic is required.
        mask: Optional label map with the same shape as ``image`` (an integer
            dtype is expected but not enforced).
        spacing: Voxel size in millimetres, one entry per spatial axis.
            For 3D volumes this is ``(slice_thickness, row_mm, col_mm)``.
            Defaults to all ones when omitted.
        metadata: Free-form dictionary. Conventional keys: ``modality``
            (``"MR" | "CT" | "DX" | "DBT"``), ``vendor``, ``patient_id``,
            ``original_dtype``.
    """

    image: np.ndarray
    mask: np.ndarray | None = None
    spacing: tuple[float, ...] = ()
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate field types and shapes, and normalise ``spacing`` to floats.

        Raises:
            TypeError: If ``image``/``mask`` are not ndarrays or ``metadata``
                is not a dict.
            ValueError: If ``image`` is not 2D/3D, the mask shape differs from
                the image shape, or ``spacing`` has the wrong length.
        """
        if not isinstance(self.image, np.ndarray):
            raise TypeError(f"image must be a numpy.ndarray, got {type(self.image).__name__}")
        if self.image.ndim not in (2, 3):
            raise ValueError(f"image must be 2D or 3D; got shape {self.image.shape}")

        if self.mask is not None:
            if not isinstance(self.mask, np.ndarray):
                raise TypeError("mask must be a numpy.ndarray or None")
            if self.mask.shape != self.image.shape:
                raise ValueError(
                    f"mask shape {self.mask.shape} does not match image shape {self.image.shape}"
                )

        if self.spacing:
            if len(self.spacing) != self.image.ndim:
                raise ValueError(
                    f"spacing has {len(self.spacing)} entries but image is {self.image.ndim}D"
                )
            self.spacing = tuple(float(s) for s in self.spacing)
        else:
            # Empty spacing means "unknown": assume unit (1 mm) voxels.
            self.spacing = tuple(1.0 for _ in range(self.image.ndim))

        if not isinstance(self.metadata, dict):
            raise TypeError("metadata must be a dict")

    def __eq__(self, other: object) -> bool:
        """Value equality that is safe for NumPy array fields.

        Images and masks are compared with ``np.array_equal`` (shape and
        values). A mask present on only one side makes the volumes unequal.
        Spacing and metadata are compared with ``==``. Like the dataclass
        default, only instances of exactly the same class are comparable.
        """
        if self is other:
            # Identity short-circuit; also keeps ``v == v`` True for NaN images.
            return True
        if other.__class__ is not self.__class__:
            return NotImplemented
        if self.mask is None or other.mask is None:
            if self.mask is not other.mask:
                return False
        elif not np.array_equal(self.mask, other.mask):
            return False
        return bool(
            np.array_equal(self.image, other.image)
            and self.spacing == other.spacing
            and self.metadata == other.metadata
        )

    @property
    def ndim(self) -> int:
        return int(self.image.ndim)

    @property
    def shape(self) -> tuple[int, ...]:
        return tuple(self.image.shape)

    @property
    def is_3d(self) -> bool:
        return self.image.ndim == 3

    @property
    def has_mask(self) -> bool:
        return self.mask is not None

    @property
    def modality(self) -> str | None:
        return self.metadata.get("modality")

    def replace(
        self,
        *,
        image: np.ndarray | None = None,
        mask: np.ndarray | None = None,
        spacing: tuple[float, ...] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> MedVolume:
        """Return a new MedVolume with selected fields swapped out.

        Omitted (or ``None``) arguments keep the current value. In particular,
        ``mask=None`` cannot drop an existing mask; construct a new
        ``MedVolume`` for that. Arrays are not copied, so the result shares
        any array that is not replaced. ``metadata`` is always shallow-copied
        into a fresh dict to avoid silent aliasing across volumes. The new
        instance is re-validated by ``__post_init__``.
        """
        return replace(
            self,
            image=self.image if image is None else image,
            mask=self.mask if mask is None else mask,
            spacing=self.spacing if spacing is None else tuple(float(s) for s in spacing),
            metadata=dict(self.metadata if metadata is None else metadata),
        )

    def copy(self) -> MedVolume:
        """Copy the image and mask arrays.

        ``metadata`` is copied into a new dict, but the copy is shallow:
        nested values are shared with the original.
        """
        return MedVolume(
            image=self.image.copy(),
            mask=None if self.mask is None else self.mask.copy(),
            spacing=tuple(self.spacing),
            metadata=dict(self.metadata),
        )

    def __repr__(self) -> str:
        # Summarise arrays by shape/dtype rather than dumping their contents.
        mask_repr = "None" if self.mask is None else f"shape={self.mask.shape}, dtype={self.mask.dtype}"
        return (
            f"MedVolume(image=shape={self.image.shape}, dtype={self.image.dtype}, "
            f"mask={mask_repr}, spacing={self.spacing}, "
            f"modality={self.modality!r})"
        )
@@ -0,0 +1,18 @@
1
+ """Unified I/O for medical image formats.
2
+
3
+ Each loader returns a :class:`~medaugmentx.core.volume.MedVolume` with
4
+ ``spacing`` populated in millimetres and ``metadata`` carrying the
5
+ modality and any vendor-specific information that callers may need.
6
+
7
+ Optional dependencies:
8
+
9
+ - DICOM I/O requires ``pydicom`` (``pip install medaugmentx[dicom]``).
10
+ - NIfTI I/O requires ``nibabel`` (``pip install medaugmentx[nifti]``).
11
+
12
+ If a backend is missing the loader raises a clear :class:`ImportError`
13
+ when called, not at import time.
14
+ """
15
+ from medaugmentx.io.dicom import load_dicom_series
16
+ from medaugmentx.io.nifti import load_nifti, save_nifti
17
+
18
# Public I/O entry points re-exported from the dicom and nifti submodules.
__all__ = ["load_dicom_series", "load_nifti", "save_nifti"]