radiobject 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""TorchIO integration for RadiObject."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch.utils.data import Dataset
|
|
9
|
+
|
|
10
|
+
from radiobject._types import LabelSource
|
|
11
|
+
from radiobject.ml.utils.labels import load_labels
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from radiobject.volume_collection import VolumeCollection
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
import torchio as tio
|
|
18
|
+
|
|
19
|
+
HAS_TORCHIO = True
|
|
20
|
+
except ImportError:
|
|
21
|
+
HAS_TORCHIO = False
|
|
22
|
+
tio = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _require_torchio() -> None:
    """Guard for code paths that need the optional TorchIO dependency.

    Raises:
        ImportError: If the module-level ``HAS_TORCHIO`` flag is false,
            i.e. ``import torchio`` failed when this module was loaded.
    """
    if HAS_TORCHIO:
        return
    raise ImportError("TorchIO required. Install with: pip install radiobject[torchio]")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class VolumeCollectionSubjectsDataset(Dataset):
    """TorchIO-compatible dataset yielding Subject objects from VolumeCollection(s).

    Item ``i`` is a ``tio.Subject`` holding one ``tio.ScalarImage`` per
    collection (keyed by the collection name) built from the volume at
    position ``i`` of every collection, plus an optional ``"label"`` entry.
    """

    def __init__(
        self,
        collections: VolumeCollection | Sequence[VolumeCollection],
        labels: LabelSource = None,
        transform: Any | None = None,
    ):
        """Initialize TorchIO-compatible dataset.

        Args:
            collections: Single VolumeCollection or sequence of collections.
                Each collection becomes a separate image in the Subject.
                Collections are paired by position, so they must all have
                the same length.
            labels: Label source. Can be:
                - str: Column name in collection's obs DataFrame
                - pd.DataFrame: With obs_id as column/index and label values
                - dict[str, Any]: Mapping from obs_id to label
                - Callable[[str], Any]: Function taking obs_id, returning label
                - None: No labels
            transform: TorchIO transform (e.g., tio.Compose) applied to each Subject.

        Raises:
            ImportError: If TorchIO is not installed.
            ValueError: If no collection is given or the collections differ
                in length.
        """
        _require_torchio()

        # Normalize to list
        if not isinstance(collections, Sequence):
            collections = [collections]
        if not collections:
            raise ValueError("At least one collection required")

        self._collections = list(collections)
        # Unnamed collections get a positional placeholder name.
        self._collection_names = [
            c.name or f"collection_{i}" for i, c in enumerate(self._collections)
        ]
        self._transform = transform

        first_coll = self._collections[0]
        self._n_subjects = len(first_coll)

        # Images are combined by position in __getitem__, so a shorter or
        # longer collection would either fail late or be silently truncated.
        for name, coll in zip(self._collection_names[1:], self._collections[1:]):
            if len(coll) != self._n_subjects:
                raise ValueError(
                    f"Collection '{name}' has {len(coll)} volumes, expected {self._n_subjects}"
                )

        # Load labels from first collection's obs
        self._labels: dict[int, Any] | None = None
        if labels is not None:
            # The obs table is only read when labels are given as a column name.
            obs_df = first_coll.obs.read() if isinstance(labels, str) else None
            self._labels = load_labels(first_coll, labels, obs_df)

    def __len__(self) -> int:
        """Return the number of subjects (length of the first collection)."""
        return self._n_subjects

    def __getitem__(self, idx: int) -> "tio.Subject":
        """Return TorchIO Subject with images for all collections.

        Args:
            idx: Subject position; negative values count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Labels are looked up by non-negative position, so resolve negative
        # indices first (without mutating the caller's index object).
        if idx < 0:
            idx = idx + self._n_subjects
        if not 0 <= idx < self._n_subjects:
            raise IndexError(f"Index out of range for dataset of size {self._n_subjects}")

        subject_dict: dict[str, Any] = {}

        for name, coll in zip(self._collection_names, self._collections):
            data = coll.iloc[idx].to_numpy()
            # Prepend a channel axis and convert to float32 for ScalarImage.
            tensor = torch.from_numpy(data).unsqueeze(0).float()
            subject_dict[name] = tio.ScalarImage(tensor=tensor)

        if self._labels is not None and idx in self._labels:
            subject_dict["label"] = self._labels[idx]

        subject = tio.Subject(subject_dict)

        # Explicit None check: a transform object that defines __len__
        # must not be skipped just because it is "empty".
        if self._transform is not None:
            subject = self._transform(subject)

        return subject

    @property
    def collection_names(self) -> list[str]:
        """Names of collections in each Subject."""
        return self._collection_names
|
radiobject/ml/config.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""Configuration models for ML training pipeline."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Self
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, field_validator, model_validator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LoadingMode(str, Enum):
    """How a dataset turns stored volumes into samples.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # One sample per whole volume.
    FULL_VOLUME = "full_volume"
    # Samples are sub-volumes (patches).
    PATCH = "patch"
    # Samples are individual 2D slices.
    SLICE_2D = "slice_2d"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DatasetConfig(BaseModel):
    """Configuration for RadiObjectDataset.

    Instances are immutable (``model_config`` sets ``frozen``). Validation
    enforces that ``patch_size`` is present in PATCH mode, that every patch
    dimension is positive, and that ``patches_per_volume`` is at least 1.
    """

    # Sampling strategy; defaults to one sample per full volume.
    loading_mode: LoadingMode = LoadingMode.FULL_VOLUME
    # Patch extent per axis; required when loading_mode is PATCH.
    patch_size: tuple[int, int, int] | None = None
    # Number of patches drawn from each volume in PATCH mode.
    patches_per_volume: int = 1
    modalities: list[str] | None = None
    label_column: str | None = None
    value_filter: str | None = None

    @model_validator(mode="after")
    def validate_patch_config(self) -> Self:
        """Validate patch configuration consistency.

        Raises:
            ValueError: If PATCH mode is selected without a ``patch_size``.
        """
        if self.loading_mode == LoadingMode.PATCH and self.patch_size is None:
            raise ValueError("patch_size required when loading_mode is PATCH")
        return self

    @field_validator("patch_size")
    @classmethod
    def validate_patch_size(
        cls, v: tuple[int, int, int] | None
    ) -> tuple[int, int, int] | None:
        """Ensure every patch dimension is positive when a patch size is given.

        A zero or negative extent would produce empty or nonsensical slices
        during patch extraction, so it is rejected at configuration time.
        """
        if v is not None and any(dim < 1 for dim in v):
            raise ValueError("patch_size dimensions must all be >= 1")
        return v

    @field_validator("patches_per_volume")
    @classmethod
    def validate_patches_per_volume(cls, v: int) -> int:
        """Ensure patches_per_volume is positive."""
        if v < 1:
            raise ValueError("patches_per_volume must be >= 1")
        return v

    model_config = {"frozen": True}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""PyTorch Dataset implementations for RadiObject."""
|
|
2
|
+
|
|
3
|
+
from radiobject.ml.datasets.collection_dataset import VolumeCollectionDataset
|
|
4
|
+
from radiobject.ml.datasets.patch_dataset import GridPatchDataset, PatchVolumeDataset
|
|
5
|
+
from radiobject.ml.datasets.segmentation_dataset import SegmentationDataset
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"VolumeCollectionDataset",
|
|
9
|
+
"GridPatchDataset",
|
|
10
|
+
"PatchVolumeDataset",
|
|
11
|
+
"SegmentationDataset",
|
|
12
|
+
]
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""VolumeCollectionDataset - primary PyTorch Dataset for VolumeCollection(s)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
from radiobject._types import LabelSource
|
|
13
|
+
from radiobject.ml.config import DatasetConfig, LoadingMode
|
|
14
|
+
from radiobject.ml.utils.labels import load_labels
|
|
15
|
+
from radiobject.ml.utils.validation import validate_collection_alignment, validate_uniform_shapes
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from radiobject.volume_collection import VolumeCollection
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class VolumeCollectionDataset(Dataset):
    """PyTorch Dataset for VolumeCollection(s) - primary ML interface.

    Depending on ``config.loading_mode`` an item is a full volume, a patch
    or a single axial slice. Data from all collections is stacked along a
    new leading (channel) axis in the order given by ``collection_names``.
    """

    def __init__(
        self,
        collections: VolumeCollection | Sequence[VolumeCollection],
        config: DatasetConfig | None = None,
        labels: LabelSource = None,
        transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Initialize dataset from VolumeCollection(s).

        Args:
            collections: Single VolumeCollection or sequence of collections.
                Multiple collections are stacked along channel dimension.
            config: Dataset configuration (loading mode, patch size, etc.).
                If None, uses full volume mode.
            labels: Label source. Can be:
                - str: Column name in collection's obs DataFrame
                - pd.DataFrame: With obs_id as column/index and label values
                - dict[str, Any]: Mapping from obs_id to label
                - Callable[[str], Any]: Function taking obs_id, returning label
                - None: No labels
            transform: Transform function applied to each sample dict.
                MONAI dict transforms (e.g., RandFlipd) work directly.

        Raises:
            ValueError: If ``collections`` is empty.
        """
        self._config = config or DatasetConfig()
        self._transform = transform

        # Normalize to list
        if not isinstance(collections, Sequence):
            collections = [collections]
        if not collections:
            raise ValueError("At least one collection required")

        self._collections = list(collections)
        # Unnamed collections get a positional placeholder name.
        self._collection_names = [
            c.name or f"collection_{i}" for i, c in enumerate(self._collections)
        ]

        # Name -> collection mapping expected by the validation helpers.
        collections_dict: dict[str, VolumeCollection] = dict(
            zip(self._collection_names, self._collections)
        )

        # Validate alignment if multi-modal
        if len(collections_dict) > 1:
            validate_collection_alignment(collections_dict)

        # Validate uniform shapes (required for batched loading)
        self._volume_shape = validate_uniform_shapes(collections_dict)

        first_coll = self._collections[0]
        self._n_volumes = len(first_coll)

        # Load labels from first collection's obs
        self._labels: dict[int, Any] | None = None
        if labels is not None:
            # The obs table is only read when labels are given as a column name.
            obs_df = first_coll.obs.read() if isinstance(labels, str) else None
            self._labels = load_labels(first_coll, labels, obs_df)

        # Compute dataset length based on loading mode
        if self._config.loading_mode == LoadingMode.PATCH:
            self._length = self._n_volumes * self._config.patches_per_volume
        elif self._config.loading_mode == LoadingMode.SLICE_2D:
            # One item per position along the third volume axis.
            self._length = self._n_volumes * self._volume_shape[2]
        else:
            self._length = self._n_volumes

    def __len__(self) -> int:
        """Return the number of items for the configured loading mode."""
        return self._length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the sample dict for ``idx`` according to the loading mode.

        Args:
            idx: Item position; negative values count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = self._normalize_index(idx)
        mode = self._config.loading_mode
        if mode == LoadingMode.PATCH:
            return self._get_patch_item(idx)
        if mode == LoadingMode.SLICE_2D:
            return self._get_slice_item(idx)
        return self._get_full_volume_item(idx)

    def _normalize_index(self, idx: int) -> int:
        """Map a possibly negative index to ``[0, len(self))`` or raise IndexError.

        Labels are keyed by non-negative volume position and the patch RNG
        seed must be non-negative, so negative indices have to be resolved
        before any lookup happens.
        """
        if idx < 0:
            # Non in-place add so a caller-owned index object is not mutated.
            idx = idx + self._length
        if not 0 <= idx < self._length:
            raise IndexError(f"Index out of range for dataset of size {self._length}")
        return idx

    def _finalize(self, result: dict[str, Any], volume_idx: int) -> dict[str, Any]:
        """Attach the label for ``volume_idx`` and apply the optional transform."""
        self._add_label(result, volume_idx)
        if self._transform is not None:
            result = self._transform(result)
        return result

    def _get_full_volume_item(self, idx: int) -> dict[str, Any]:
        """Load full volume for all collections."""
        volumes = [coll.iloc[idx].to_numpy() for coll in self._collections]

        # New leading axis = one channel per collection.
        stacked = np.stack(volumes, axis=0)
        result: dict[str, Any] = {
            "image": torch.from_numpy(stacked),
            "idx": idx,
        }
        return self._finalize(result, idx)

    def _get_patch_item(self, idx: int) -> dict[str, Any]:
        """Load a pseudo-random patch from the volume.

        The patch origin comes from a generator seeded with ``idx``, so a
        given item index always yields the same patch location.

        Raises:
            ValueError: If the configuration has no ``patch_size``.
        """
        volume_idx = idx // self._config.patches_per_volume
        patch_idx = idx % self._config.patches_per_volume

        rng = np.random.default_rng(seed=idx)
        patch_size = self._config.patch_size
        if patch_size is None:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError("patch_size required when loading_mode is PATCH")

        # Largest valid start per axis; clamped at 0 when the patch is not
        # smaller than the volume along that axis.
        max_start = tuple(max(0, self._volume_shape[i] - patch_size[i]) for i in range(3))
        start = tuple(
            rng.integers(0, max_start[i] + 1) if max_start[i] > 0 else 0 for i in range(3)
        )

        # The same window is cut from every collection so channels line up.
        window = tuple(slice(start[i], start[i] + patch_size[i]) for i in range(3))
        volumes = [coll.iloc[volume_idx].slice(*window) for coll in self._collections]

        stacked = np.stack(volumes, axis=0)
        result: dict[str, Any] = {
            "image": torch.from_numpy(stacked),
            "idx": volume_idx,
            "patch_idx": patch_idx,
            "patch_start": start,
        }
        return self._finalize(result, volume_idx)

    def _get_slice_item(self, idx: int) -> dict[str, Any]:
        """Load a 2D slice from the volume."""
        # Items are laid out volume-major: all slices of volume 0, then volume 1, ...
        depth = self._volume_shape[2]
        volume_idx = idx // depth
        slice_idx = idx % depth

        slices = [coll.iloc[volume_idx].axial(slice_idx) for coll in self._collections]

        stacked = np.stack(slices, axis=0)
        result: dict[str, Any] = {
            "image": torch.from_numpy(stacked),
            "idx": volume_idx,
            "slice_idx": slice_idx,
        }
        return self._finalize(result, volume_idx)

    def _add_label(self, result: dict[str, Any], volume_idx: int) -> None:
        """Add label column to sample dict from label source.

        Numeric labels are wrapped in a tensor; any other label value is
        stored unchanged. Nothing is added when no label exists for the volume.
        """
        if self._labels is None or volume_idx not in self._labels:
            return
        label = self._labels[volume_idx]
        if isinstance(label, (int, float, np.integer, np.floating)):
            result["label"] = torch.tensor(label)
        else:
            result["label"] = label

    @property
    def collection_names(self) -> list[str]:
        """Names of collections being loaded (channel order)."""
        return self._collection_names

    @property
    def volume_shape(self) -> tuple[int, int, int]:
        """Shape of each volume (X, Y, Z)."""
        return self._volume_shape

    @property
    def n_channels(self) -> int:
        """Number of channels (collections) in output tensors."""
        return len(self._collections)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""Multi-modal dataset for loading aligned volumes from multiple VolumeCollections."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from torch.utils.data import Dataset
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from radiobject.radi_object import RadiObject
|
|
13
|
+
from radiobject.volume_collection import VolumeCollection
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MultiModalDataset(Dataset):
    """Dataset for loading aligned volumes from multiple VolumeCollections.

    Item ``i`` stacks the volume at position ``i`` of every requested
    modality along a new leading (channel) axis, in the order of
    ``modalities``.
    """

    def __init__(
        self,
        radi_object: RadiObject,
        modalities: list[str],
        label_column: str | None = None,
        value_filter: str | None = None,
        transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Initialize the dataset.

        Args:
            radi_object: Source object; ``radi_object.collection(name)`` is
                called once per modality and ``radi_object.obs_meta`` is read
                when labels are requested.
            modalities: Collection names to load; also the channel order.
            label_column: Column of the obs_meta table to use as label.
                No labels are loaded when falsy.
            value_filter: Passed to ``obs_meta.read(value_filter=...)`` when
                loading labels. It only restricts which labels are found;
                it does not remove volumes from the dataset.
            transform: Optional callable applied to each sample dict.

        Raises:
            ValueError: If ``modalities`` is empty, the modalities are not
                aligned, or ``label_column`` is missing from obs_meta.
        """
        if not modalities:
            raise ValueError("At least one modality required")

        self._modalities = modalities
        self._transform = transform
        self._radi_object = radi_object

        self._collections: dict[str, VolumeCollection] = {
            mod: radi_object.collection(mod) for mod in modalities
        }

        # The first modality defines dataset length, shape and label order.
        first_coll = self._collections[modalities[0]]
        self._n_volumes = len(first_coll)
        self._volume_shape = first_coll.shape

        self._validate_alignment()

        self._labels: dict[int, Any] | None = None
        if label_column:
            self._load_labels(radi_object, label_column, value_filter)

    def _validate_alignment(self) -> None:
        """Validate that all modalities have matching subjects in matching order.

        Volumes are combined by position in ``__getitem__``, so besides equal
        length and equal subject sets the subject order must be identical;
        otherwise channels of one sample would come from different subjects.

        Raises:
            ValueError: On a length, subject-set or subject-order mismatch.
        """
        first_mod = self._modalities[0]
        first_order = list(self._collections[first_mod].obs_subject_ids)
        first_subjects = set(first_order)

        for mod in self._modalities[1:]:
            coll = self._collections[mod]
            if len(coll) != self._n_volumes:
                raise ValueError(
                    f"Modality '{mod}' has {len(coll)} volumes, expected {self._n_volumes}"
                )

            mod_order = list(coll.obs_subject_ids)
            mod_subjects = set(mod_order)

            if mod_subjects != first_subjects:
                missing = first_subjects - mod_subjects
                extra = mod_subjects - first_subjects
                raise ValueError(
                    f"Subject mismatch for modality '{mod}': "
                    f"missing={list(missing)[:3]}, extra={list(extra)[:3]}"
                )

            if mod_order != first_order:
                raise ValueError(
                    f"Subject order mismatch for modality '{mod}': volumes must be "
                    f"in the same subject order as modality '{first_mod}'"
                )

    def _load_labels(
        self,
        radi_object: RadiObject,
        label_column: str,
        value_filter: str | None,
    ) -> None:
        """Load labels from obs_meta.

        For each volume of the first modality, the label is taken from the
        first obs_meta row whose ``obs_subject_id`` equals the volume's
        subject id. Volumes without a matching row get no label.

        Raises:
            ValueError: If ``label_column`` is not a column of obs_meta.
        """
        obs_meta = radi_object.obs_meta.read(value_filter=value_filter)
        if label_column not in obs_meta.columns:
            raise ValueError(f"Label column '{label_column}' not found")

        # Single pass: remember the first row position of every subject,
        # instead of filtering the whole table once per volume.
        first_row: dict[Any, int] = {}
        for pos, subject_id in enumerate(obs_meta["obs_subject_id"]):
            first_row.setdefault(subject_id, pos)

        label_values = obs_meta[label_column]
        first_coll = self._collections[self._modalities[0]]
        obs_subject_ids = first_coll.obs_subject_ids

        self._labels = {}
        for idx in range(self._n_volumes):
            pos = first_row.get(obs_subject_ids[idx])
            if pos is not None:
                self._labels[idx] = label_values.iloc[pos]

    def __len__(self) -> int:
        """Return the number of volumes in the first modality."""
        return self._n_volumes

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the stacked multi-modal sample at ``idx``.

        The result holds ``"image"`` (tensor with one channel per modality),
        ``"idx"``, ``"obs_id"`` (from the first modality) and, when
        available, ``"label"``.

        Args:
            idx: Volume position; negative values count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Labels are keyed by non-negative position, so resolve negative
        # indices up front (without mutating the caller's index object).
        if idx < 0:
            idx = idx + self._n_volumes
        if not 0 <= idx < self._n_volumes:
            raise IndexError(f"Index out of range for dataset of size {self._n_volumes}")

        volumes = [self._collections[mod].iloc[idx].to_numpy() for mod in self._modalities]
        stacked = np.stack(volumes, axis=0)

        first_coll = self._collections[self._modalities[0]]
        obs_id = first_coll.obs_ids[idx]

        result: dict[str, Any] = {
            "image": torch.from_numpy(stacked),
            "idx": idx,
            "obs_id": obs_id,
        }

        if self._labels is not None and idx in self._labels:
            label = self._labels[idx]
            # Numeric labels become tensors; anything else is passed through.
            if isinstance(label, (int, float, np.integer, np.floating)):
                result["label"] = torch.tensor(label)
            else:
                result["label"] = label

        if self._transform is not None:
            result = self._transform(result)

        return result

    @property
    def modalities(self) -> list[str]:
        """List of modalities."""
        return self._modalities

    @property
    def volume_shape(self) -> tuple[int, int, int]:
        """Volume dimensions."""
        return self._volume_shape
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""Specialized patch extraction dataset."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from radiobject.volume_collection import VolumeCollection
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PatchVolumeDataset(Dataset):
    """Dataset for extracting patches from a single VolumeCollection.

    The dataset exposes ``patches_per_volume`` items per volume. The patch
    origin for an item comes from a generator seeded with the item index,
    so a given index always maps to the same patch location.
    """

    def __init__(
        self,
        collection: VolumeCollection,
        patch_size: tuple[int, int, int],
        patches_per_volume: int = 1,
        transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Initialize the patch dataset.

        Args:
            collection: Source collection; must report a uniform ``shape``.
            patch_size: Patch extent along each of the three volume axes.
            patches_per_volume: Number of items generated per volume (>= 1).
            transform: Optional callable applied to each sample dict.

        Raises:
            ValueError: If ``patches_per_volume`` < 1, a patch dimension is
                < 1 or exceeds the volume, or the collection has no
                uniform shape.
        """
        # Zero would make every __getitem__ divide by zero; reject up front.
        if patches_per_volume < 1:
            raise ValueError("patches_per_volume must be >= 1")
        if any(dim < 1 for dim in patch_size):
            raise ValueError("patch_size dimensions must all be >= 1")

        self._collection = collection
        self._patch_size = patch_size
        self._patches_per_volume = patches_per_volume
        self._transform = transform

        self._obs_ids = collection.obs_ids
        self._n_volumes = len(self._obs_ids)
        self._volume_shape = collection.shape
        self._length = self._n_volumes * patches_per_volume

        if self._volume_shape is None:
            raise ValueError("Collection must have uniform shape for patch extraction")

        for i, dim in enumerate(patch_size):
            if dim > self._volume_shape[i]:
                raise ValueError(
                    f"Patch dimension {i} ({dim}) exceeds volume dimension ({self._volume_shape[i]})"
                )

    def __len__(self) -> int:
        """Return ``n_volumes * patches_per_volume``."""
        return self._length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the patch sample at ``idx``.

        The result holds ``"image"`` (patch with a leading channel axis),
        ``"idx"`` (volume position), ``"patch_idx"``, ``"patch_start"`` and
        ``"obs_id"``.

        Args:
            idx: Item position; negative values count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # The RNG seed below must be non-negative, so resolve negative
        # indices first (without mutating the caller's index object).
        if idx < 0:
            idx = idx + self._length
        if not 0 <= idx < self._length:
            raise IndexError(f"Index out of range for dataset of size {self._length}")

        volume_idx = idx // self._patches_per_volume
        patch_idx = idx % self._patches_per_volume

        rng = np.random.default_rng(seed=idx)

        # Largest valid start per axis; 0 means the patch spans the whole axis.
        max_start = tuple(max(0, self._volume_shape[i] - self._patch_size[i]) for i in range(3))
        start = tuple(
            rng.integers(0, max_start[i] + 1) if max_start[i] > 0 else 0 for i in range(3)
        )

        vol = self._collection.iloc[volume_idx]
        data = vol.slice(
            slice(start[0], start[0] + self._patch_size[0]),
            slice(start[1], start[1] + self._patch_size[1]),
            slice(start[2], start[2] + self._patch_size[2]),
        )

        result: dict[str, Any] = {
            "image": torch.from_numpy(data).unsqueeze(0),
            "idx": volume_idx,
            "patch_idx": patch_idx,
            "patch_start": start,
            "obs_id": self._obs_ids[volume_idx],
        }

        if self._transform is not None:
            result = self._transform(result)

        return result

    @property
    def volume_shape(self) -> tuple[int, int, int]:
        """Shape of each volume."""
        return self._volume_shape

    @property
    def patch_size(self) -> tuple[int, int, int]:
        """Patch dimensions."""
        return self._patch_size
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class GridPatchDataset(Dataset):
    """Dataset for extracting patches on a regular grid (for inference).

    Patch origins start at 0 on every axis and advance by ``stride`` while
    the patch still fits inside the volume. When ``(dim - patch) % stride``
    is non-zero on an axis, the trailing voxels of that axis are not
    covered by any patch.
    """

    def __init__(
        self,
        collection: VolumeCollection,
        patch_size: tuple[int, int, int],
        stride: tuple[int, int, int] | None = None,
        transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Initialize the grid patch dataset.

        Args:
            collection: Source collection; must report a uniform ``shape``.
            patch_size: Patch extent along each of the three volume axes.
            stride: Step between patch origins per axis. Defaults to
                ``patch_size`` (non-overlapping patches).
            transform: Optional callable applied to each sample dict.

        Raises:
            ValueError: If a patch or stride dimension is < 1, or the
                collection has no uniform shape.
        """
        if any(dim < 1 for dim in patch_size):
            raise ValueError("patch_size dimensions must all be >= 1")

        self._collection = collection
        self._patch_size = patch_size
        self._stride = stride or patch_size
        # A zero stride makes range() fail with an opaque message and a
        # negative one silently yields an empty grid; reject both clearly.
        if any(step < 1 for step in self._stride):
            raise ValueError("stride dimensions must all be >= 1")
        self._transform = transform

        self._obs_ids = collection.obs_ids
        self._n_volumes = len(self._obs_ids)
        self._volume_shape = collection.shape

        if self._volume_shape is None:
            raise ValueError("Collection must have uniform shape for grid patch extraction")

        self._grid_positions = self._compute_grid_positions()
        self._patches_per_volume = len(self._grid_positions)
        self._length = self._n_volumes * self._patches_per_volume

    def _axis_starts(self, axis: int) -> list[int]:
        """Return patch origins along one axis, or ``[0]`` if the patch does not fit."""
        last_start = self._volume_shape[axis] - self._patch_size[axis]
        starts = list(range(0, last_start + 1, self._stride[axis]))
        return starts or [0]

    def _compute_grid_positions(self) -> list[tuple[int, int, int]]:
        """Compute grid patch positions for inference.

        The fallback to origin 0 is applied per axis, so a patch that is
        larger than the volume along one axis does not discard the grid
        along the other axes.
        """
        xs, ys, zs = (self._axis_starts(axis) for axis in range(3))
        # Same ordering as nested x / y / z loops: z varies fastest.
        return [(x, y, z) for x in xs for y in ys for z in zs]

    def __len__(self) -> int:
        """Return ``n_volumes * len(grid_positions)``."""
        return self._length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Return the grid patch sample at ``idx``.

        The result holds ``"image"`` (patch with a leading channel axis),
        ``"idx"`` (volume position), ``"patch_idx"``, ``"patch_start"`` and
        ``"obs_id"``.

        Args:
            idx: Item position; negative values count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            # Non in-place add so a caller-owned index object is not mutated.
            idx = idx + self._length
        if not 0 <= idx < self._length:
            raise IndexError(f"Index out of range for dataset of size {self._length}")

        # Items are laid out volume-major: all grid patches of volume 0 first.
        volume_idx = idx // self._patches_per_volume
        patch_idx = idx % self._patches_per_volume
        start = self._grid_positions[patch_idx]

        vol = self._collection.iloc[volume_idx]
        data = vol.slice(
            slice(start[0], start[0] + self._patch_size[0]),
            slice(start[1], start[1] + self._patch_size[1]),
            slice(start[2], start[2] + self._patch_size[2]),
        )

        result: dict[str, Any] = {
            "image": torch.from_numpy(data).unsqueeze(0),
            "idx": volume_idx,
            "patch_idx": patch_idx,
            "patch_start": start,
            "obs_id": self._obs_ids[volume_idx],
        }

        if self._transform is not None:
            result = self._transform(result)

        return result

    @property
    def grid_positions(self) -> list[tuple[int, int, int]]:
        """All patch start positions in the grid."""
        return self._grid_positions
|