radiobject 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""SegmentationDataset - specialized dataset for image/mask segmentation training."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
from radiobject.ml.config import DatasetConfig, LoadingMode
|
|
13
|
+
from radiobject.ml.utils.validation import validate_collection_alignment, validate_uniform_shapes
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from radiobject.volume_collection import VolumeCollection
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SegmentationDataset(Dataset):
    """PyTorch Dataset for segmentation training with explicit image/mask separation.

    Every item is a dict holding an ``"image"`` and a ``"mask"`` tensor, each with
    a leading channel axis of size 1 added via ``unsqueeze(0)``, plus bookkeeping
    keys (``"idx"``, ``"obs_id"``, ``"obs_subject_id"`` and mode-specific extras).
    ``config.loading_mode`` decides whether an item is the full volume, one patch
    (``LoadingMode.PATCH``) or one slice along the third axis
    (``LoadingMode.SLICE_2D``).
    """

    def __init__(
        self,
        image: VolumeCollection,
        mask: VolumeCollection,
        config: DatasetConfig | None = None,
        image_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        spatial_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        foreground_sampling: bool = False,
        foreground_threshold: float = 0.01,
        foreground_max_retries: int = 10,
    ):
        """Initialize segmentation dataset.

        Args:
            image: VolumeCollection containing input images (CT, MRI, etc.).
            mask: VolumeCollection containing segmentation masks.
            config: Dataset configuration (loading mode, patch size, etc.).
                A default ``DatasetConfig()`` is used when omitted.
            image_transform: Transform applied to image only (e.g., normalization).
                Should operate on keys=["image"].
            spatial_transform: Transform applied to both image and mask (e.g., flips).
                Should operate on keys=["image", "mask"].
            foreground_sampling: If True, bias patch sampling toward regions with
                foreground (non-zero mask values).
            foreground_threshold: Minimum fraction of foreground voxels in patch
                when foreground_sampling is enabled.
            foreground_max_retries: Maximum random attempts before accepting any
                patch. Values below 1 are treated as a single attempt.

        The image/mask pair is checked with ``validate_collection_alignment`` and
        ``validate_uniform_shapes``; any error those helpers raise propagates.
        """
        self._config = config or DatasetConfig()
        self._image_transform = image_transform
        self._spatial_transform = spatial_transform
        self._foreground_sampling = foreground_sampling
        self._foreground_threshold = foreground_threshold
        self._foreground_max_retries = foreground_max_retries

        self._image = image
        self._mask = mask

        # Per-item metadata comes from the image collection; the mask is assumed
        # to be aligned with it (see validate_collection_alignment below).
        self._obs_ids = image.obs_ids
        self._obs_subject_ids = image.obs_subject_ids

        collections = {"image": self._image, "mask": self._mask}
        validate_collection_alignment(collections)
        self._volume_shape = validate_uniform_shapes(collections)
        self._n_volumes = len(self._image)

        if self._config.loading_mode == LoadingMode.PATCH:
            self._length = self._n_volumes * self._config.patches_per_volume
        elif self._config.loading_mode == LoadingMode.SLICE_2D:
            # One item per index along the third axis of every volume.
            self._length = self._n_volumes * self._volume_shape[2]
        else:
            self._length = self._n_volumes

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        if self._config.loading_mode == LoadingMode.PATCH:
            return self._get_patch_item(idx)
        elif self._config.loading_mode == LoadingMode.SLICE_2D:
            return self._get_slice_item(idx)
        else:
            return self._get_full_volume_item(idx)

    def _get_full_volume_item(self, idx: int) -> dict[str, Any]:
        """Load full volume for image and mask."""
        image_data = self._image.iloc[idx].to_numpy()
        mask_data = self._mask.iloc[idx].to_numpy()

        result: dict[str, Any] = {
            "image": torch.from_numpy(image_data).unsqueeze(0),
            "mask": torch.from_numpy(mask_data).unsqueeze(0),
            "idx": idx,
            "obs_id": self._obs_ids[idx],
            "obs_subject_id": self._obs_subject_ids[idx],
        }

        return self._apply_transforms(result)

    def _get_patch_item(self, idx: int) -> dict[str, Any]:
        """Load one patch of image and mask for flat index ``idx``.

        ``idx`` is split into a volume index and a patch number. The patch origin
        is drawn from an RNG seeded with ``idx``, so a given index always yields
        the same patch. With foreground sampling enabled the origin is biased
        toward non-zero mask voxels.

        Raises:
            ValueError: If the config has no ``patch_size``.
        """
        volume_idx, patch_idx = divmod(idx, self._config.patches_per_volume)

        patch_size = self._config.patch_size
        if patch_size is None:
            raise ValueError("patch_size must be set in DatasetConfig for PATCH loading mode")

        # Largest valid origin per axis; 0 when the patch spans the whole axis.
        max_start = tuple(max(0, self._volume_shape[i] - patch_size[i]) for i in range(3))

        if self._foreground_sampling:
            start = self._sample_foreground_patch(volume_idx, max_start, patch_size, idx)
        else:
            start = self._random_start(np.random.default_rng(seed=idx), max_start)

        window = self._patch_slices(start, patch_size)
        image_data = self._image.iloc[volume_idx].slice(*window)
        mask_data = self._mask.iloc[volume_idx].slice(*window)

        result: dict[str, Any] = {
            "image": torch.from_numpy(image_data).unsqueeze(0),
            "mask": torch.from_numpy(mask_data).unsqueeze(0),
            "idx": volume_idx,
            "patch_idx": patch_idx,
            "patch_start": start,
            "obs_id": self._obs_ids[volume_idx],
            "obs_subject_id": self._obs_subject_ids[volume_idx],
        }

        return self._apply_transforms(result)

    def _sample_foreground_patch(
        self,
        volume_idx: int,
        max_start: tuple[int, ...],
        patch_size: tuple[int, int, int],
        seed: int,
    ) -> tuple[int, int, int]:
        """Sample a patch origin biased toward foreground (non-zero mask) regions.

        Up to ``foreground_max_retries`` origins (at least one) are drawn from an
        RNG seeded with ``seed``. The first origin whose mask patch has a non-zero
        fraction of at least ``foreground_threshold`` is returned. If none
        qualifies, the last drawn origin is returned.
        """
        rng = np.random.default_rng(seed=seed)
        mask_vol = self._mask.iloc[volume_idx]
        # Always draw at least once so ``start`` is defined even when the
        # configured retry count is zero or negative.
        attempts = max(1, self._foreground_max_retries)

        for _ in range(attempts):
            start = self._random_start(rng, max_start)
            mask_patch = mask_vol.slice(*self._patch_slices(start, patch_size))
            # An empty patch counts as having no foreground (avoids 0/0).
            ratio = np.count_nonzero(mask_patch) / mask_patch.size if mask_patch.size else 0.0
            if ratio >= self._foreground_threshold:
                return start

        # Fallback: no attempt met the threshold; use the last sampled origin.
        return start

    @staticmethod
    def _random_start(
        rng: np.random.Generator, max_start: tuple[int, ...]
    ) -> tuple[int, int, int]:
        """Draw a patch origin uniformly from ``[0, max_start[i]]`` on each of three axes."""
        sx, sy, sz = (
            int(rng.integers(0, upper + 1)) if upper > 0 else 0 for upper in max_start[:3]
        )
        return sx, sy, sz

    @staticmethod
    def _patch_slices(
        start: tuple[int, ...], patch_size: tuple[int, ...]
    ) -> tuple[slice, slice, slice]:
        """Return the per-axis slices covering a ``patch_size`` patch at ``start``."""
        return (
            slice(start[0], start[0] + patch_size[0]),
            slice(start[1], start[1] + patch_size[1]),
            slice(start[2], start[2] + patch_size[2]),
        )

    def _get_slice_item(self, idx: int) -> dict[str, Any]:
        """Load a 2D slice (via ``axial``) from image and mask."""
        volume_idx, slice_idx = divmod(idx, self._volume_shape[2])

        image_data = self._image.iloc[volume_idx].axial(slice_idx)
        mask_data = self._mask.iloc[volume_idx].axial(slice_idx)

        result: dict[str, Any] = {
            "image": torch.from_numpy(image_data).unsqueeze(0),
            "mask": torch.from_numpy(mask_data).unsqueeze(0),
            "idx": volume_idx,
            "slice_idx": slice_idx,
            "obs_id": self._obs_ids[volume_idx],
            "obs_subject_id": self._obs_subject_ids[volume_idx],
        }

        return self._apply_transforms(result)

    def _apply_transforms(self, result: dict[str, Any]) -> dict[str, Any]:
        """Apply the spatial transform (image and mask), then the image-only transform."""
        if self._spatial_transform is not None:
            result = self._spatial_transform(result)

        if self._image_transform is not None:
            result = self._image_transform(result)

        return result

    @property
    def volume_shape(self) -> tuple[int, int, int]:
        """Shape of each volume (X, Y, Z)."""
        return self._volume_shape

    @property
    def n_volumes(self) -> int:
        """Number of image/mask pairs."""
        return self._n_volumes
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
"""Core RadiObjectDataset implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from torch.utils.data import Dataset
|
|
10
|
+
|
|
11
|
+
from radiobject.ml.config import DatasetConfig, LoadingMode
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from radiobject.radi_object import RadiObject
|
|
15
|
+
from radiobject.volume_collection import VolumeCollection
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RadiObjectDataset(Dataset):
    """PyTorch Dataset for loading volumes from RadiObject via TileDB.

    Volumes of every selected modality (a collection of the RadiObject) are read
    at the same position and stacked along a new leading channel axis under the
    ``"image"`` key. ``config.loading_mode`` decides whether an item is the full
    volume, a random patch (``LoadingMode.PATCH``) or one slice read with
    ``axial`` (``LoadingMode.SLICE_2D``). When ``config.label_column`` is set, a
    per-volume label read from ``obs_meta`` is attached under ``"label"``.
    """

    def __init__(
        self,
        radi_object: RadiObject,
        config: DatasetConfig,
        transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Initialize the dataset.

        Args:
            radi_object: Source RadiObject.
            config: Dataset configuration. ``config.modalities`` selects the
                collections to load; all collections are used when it is empty.
            transform: Optional callable applied to every sample dict.

        Raises:
            ValueError: If no modality is available, a collection has
                heterogeneous shapes, modalities differ in shape, volume count or
                subject order, or the configured label column is missing.
        """
        self._config = config
        self._transform = transform

        modalities = config.modalities or list(radi_object.collection_names)
        if not modalities:
            raise ValueError("No modalities specified and RadiObject has no collections")

        self._modalities = modalities
        self._collections: dict[str, VolumeCollection] = {
            mod: radi_object.collection(mod) for mod in modalities
        }

        # Batched loading requires every volume of a collection to share a shape.
        for mod in modalities:
            if not self._collections[mod].is_uniform:
                raise ValueError(
                    f"Collection '{mod}' has heterogeneous shapes. "
                    "Call collection.resample_to() to normalize dimensions before ML training."
                )

        first_coll = self._collections[modalities[0]]
        self._n_volumes = len(first_coll)
        self._volume_shape = first_coll.shape  # Non-None after the uniform check.

        # Modalities are stacked channel-wise, so they must also agree with each other.
        for mod in modalities[1:]:
            shape = self._collections[mod].shape
            if tuple(shape) != tuple(self._volume_shape):
                raise ValueError(
                    f"Collection '{mod}' has shape {tuple(shape)}, expected "
                    f"{tuple(self._volume_shape)} (shape of '{modalities[0]}')"
                )

        if len(modalities) > 1:
            self._validate_subject_alignment()

        self._labels: dict[int, int | float] | None = None
        if config.label_column:
            self._load_labels(radi_object, config.label_column, config.value_filter)

        if config.loading_mode == LoadingMode.PATCH:
            self._length = self._n_volumes * config.patches_per_volume
        elif config.loading_mode == LoadingMode.SLICE_2D:
            self._length = self._n_volumes * self._volume_shape[2]
        else:
            self._length = self._n_volumes

    def _validate_subject_alignment(self) -> None:
        """Validate that all modalities list the same subjects in the same order.

        Volumes are paired across modalities by position (``iloc``), so equal
        subject *sets* are not enough: the per-position subject IDs must agree.

        Raises:
            ValueError: On a volume-count, subject-set or subject-order mismatch.
        """
        reference_mod = self._modalities[0]
        reference_order = list(self._collections[reference_mod].obs_subject_ids)
        reference_subjects = set(reference_order)

        for mod in self._modalities[1:]:
            coll = self._collections[mod]
            if len(coll) != self._n_volumes:
                raise ValueError(
                    f"Modality '{mod}' has {len(coll)} volumes, expected {self._n_volumes}"
                )

            mod_order = list(coll.obs_subject_ids)
            mod_subjects = set(mod_order)
            if mod_subjects != reference_subjects:
                missing = reference_subjects - mod_subjects
                extra = mod_subjects - reference_subjects
                raise ValueError(
                    f"Subject mismatch for modality '{mod}': "
                    f"missing={list(missing)[:3]}, extra={list(extra)[:3]}"
                )

            if mod_order != reference_order:
                position = next(
                    (i for i, (a, b) in enumerate(zip(reference_order, mod_order)) if a != b),
                    min(len(reference_order), len(mod_order)),
                )
                raise ValueError(
                    f"Subject order mismatch for modality '{mod}': first difference from "
                    f"'{reference_mod}' at position {position}. Modalities are paired by "
                    "position, so all collections must list subjects in the same order."
                )

    @staticmethod
    def _first_positions(frame: Any, column: str) -> dict[Any, int]:
        """Map each value of ``frame[column]`` to the row position of its first occurrence.

        Returns an empty dict when the column is absent, so lookups fall through
        to the next matching strategy instead of raising ``KeyError``.
        """
        if column not in frame.columns:
            return {}
        positions: dict[Any, int] = {}
        for pos, value in enumerate(frame[column]):
            positions.setdefault(value, pos)
        return positions

    def _load_labels(
        self,
        radi_object: RadiObject,
        label_column: str,
        value_filter: str | None,
    ) -> None:
        """Load per-volume labels from the obs_meta dataframe into ``self._labels``.

        For each volume of the first modality, the first obs_meta row is chosen by:
        (1) ``obs_id`` equal to the volume's obs_id; else (2) ``obs_subject_id``
        equal to the obs_id; else (3) legacy: ``obs_subject_id`` equal to the
        obs_id with its last ``_suffix`` removed. Volumes without a match get no
        label.

        Raises:
            ValueError: If ``label_column`` is not a column of obs_meta.
        """
        obs_meta = radi_object.obs_meta.read(value_filter=value_filter)
        if label_column not in obs_meta.columns:
            raise ValueError(f"Label column '{label_column}' not found in obs_meta")

        # Index once instead of scanning the whole frame per volume.
        by_obs_id = self._first_positions(obs_meta, "obs_id")
        by_subject = self._first_positions(obs_meta, "obs_subject_id")
        label_values = obs_meta[label_column]

        obs_ids = self._collections[self._modalities[0]].obs_ids
        self._labels = {}
        for idx in range(self._n_volumes):
            obs_id = obs_ids[idx]
            pos = by_obs_id.get(obs_id)
            if pos is None:
                pos = by_subject.get(obs_id)
            if pos is None:
                parts = obs_id.rsplit("_", 1)
                if len(parts) > 1:
                    pos = by_subject.get(parts[0])
            if pos is not None:
                self._labels[idx] = label_values.iloc[pos]

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        if self._config.loading_mode == LoadingMode.PATCH:
            return self._get_patch_item(idx)
        elif self._config.loading_mode == LoadingMode.SLICE_2D:
            return self._get_slice_item(idx)
        else:
            return self._get_full_volume_item(idx)

    def _finalize(self, result: dict[str, Any], volume_idx: int) -> dict[str, Any]:
        """Attach the volume's label (if any) and apply the user transform."""
        if self._labels is not None and volume_idx in self._labels:
            label = self._labels[volume_idx]
            # Numeric labels are wrapped in a tensor; anything else passes through.
            if isinstance(label, (int, float, np.integer, np.floating)):
                result["label"] = torch.tensor(label)
            else:
                result["label"] = label

        if self._transform is not None:
            result = self._transform(result)

        return result

    def _get_full_volume_item(self, idx: int) -> dict[str, Any]:
        """Load the full volume of every modality, stacked on axis 0."""
        volumes = [self._collections[mod].iloc[idx].to_numpy() for mod in self._modalities]
        result: dict[str, Any] = {
            "image": torch.from_numpy(np.stack(volumes, axis=0)),
            "idx": idx,
        }
        return self._finalize(result, idx)

    def _get_patch_item(self, idx: int) -> dict[str, Any]:
        """Load the same random patch from every modality, stacked on axis 0.

        The patch origin comes from an RNG seeded with ``idx``, so a given index
        always yields the same patch.

        Raises:
            ValueError: If the config has no ``patch_size``.
        """
        volume_idx, patch_idx = divmod(idx, self._config.patches_per_volume)

        patch_size = self._config.patch_size
        if patch_size is None:
            raise ValueError("patch_size must be set in DatasetConfig for PATCH loading mode")

        rng = np.random.default_rng(seed=idx)
        # Largest valid origin per axis; 0 when the patch spans the whole axis.
        max_start = tuple(max(0, self._volume_shape[i] - patch_size[i]) for i in range(3))
        start = tuple(int(rng.integers(0, upper + 1)) if upper > 0 else 0 for upper in max_start)
        window = tuple(slice(s, s + p) for s, p in zip(start, patch_size))

        patches = [
            self._collections[mod].iloc[volume_idx].slice(*window) for mod in self._modalities
        ]
        result: dict[str, Any] = {
            "image": torch.from_numpy(np.stack(patches, axis=0)),
            "idx": volume_idx,
            "patch_idx": patch_idx,
            "patch_start": start,
        }
        return self._finalize(result, volume_idx)

    def _get_slice_item(self, idx: int) -> dict[str, Any]:
        """Load one slice (via ``axial``) from every modality, stacked on axis 0."""
        volume_idx, slice_idx = divmod(idx, self._volume_shape[2])

        slices = [
            self._collections[mod].iloc[volume_idx].axial(slice_idx) for mod in self._modalities
        ]
        result: dict[str, Any] = {
            "image": torch.from_numpy(np.stack(slices, axis=0)),
            "idx": volume_idx,
            "slice_idx": slice_idx,
        }
        return self._finalize(result, volume_idx)

    @property
    def modalities(self) -> list[str]:
        """List of modalities being loaded."""
        return self._modalities

    @property
    def volume_shape(self) -> tuple[int, int, int]:
        """Shape of each volume (X, Y, Z)."""
        return self._volume_shape
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Distributed training utilities for DDP."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
|
7
|
+
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from torch.utils.data.distributed import DistributedSampler
|
|
10
|
+
|
|
11
|
+
from radiobject._types import LabelSource
|
|
12
|
+
from radiobject.ml.config import DatasetConfig, LoadingMode
|
|
13
|
+
from radiobject.ml.datasets.collection_dataset import VolumeCollectionDataset
|
|
14
|
+
from radiobject.ml.utils.worker_init import worker_init_fn
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from radiobject.volume_collection import VolumeCollection
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def create_distributed_dataloader(
    collections: VolumeCollection | Sequence[VolumeCollection],
    rank: int,
    world_size: int,
    labels: LabelSource = None,
    batch_size: int = 4,
    patch_size: tuple[int, int, int] | None = None,
    num_workers: int = 4,
    pin_memory: bool = True,
    transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    *,
    patches_per_volume: int = 1,
) -> DataLoader:
    """Create a DataLoader for distributed training with DDP.

    Args:
        collections: Single VolumeCollection or list for multi-modal training.
        rank: Current process rank.
        world_size: Total number of processes.
        labels: Label source (see create_training_dataloader for options).
        batch_size: Samples per batch per GPU.
        patch_size: If provided, extract random patches (PATCH loading mode);
            otherwise full volumes are loaded.
        num_workers: DataLoader worker processes. Negative values are treated as 0.
        pin_memory: Request pinned (page-locked) host memory for batches. Only
            honored when ``num_workers > 0``; it is disabled otherwise.
        transform: Transform function.
        patches_per_volume: Patches drawn per volume per epoch in PATCH mode.
            Ignored (forced to 1) when ``patch_size`` is not given.

    Returns:
        DataLoader with a shuffling DistributedSampler configured and
        ``drop_last=True``. Call ``set_epoch`` every epoch to reshuffle.
    """
    use_patches = bool(patch_size)
    loading_mode = LoadingMode.PATCH if use_patches else LoadingMode.FULL_VOLUME

    config = DatasetConfig(
        loading_mode=loading_mode,
        patch_size=patch_size,
        # Full-volume mode keeps the historical config value of 1.
        patches_per_volume=patches_per_volume if use_patches else 1,
    )

    dataset = VolumeCollectionDataset(
        collections, config=config, labels=labels, transform=transform
    )

    # Sharding and shuffling are done by the sampler, so the DataLoader itself
    # must not shuffle.
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )

    effective_workers = max(0, num_workers)
    uses_workers = effective_workers > 0

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=effective_workers,
        pin_memory=pin_memory and uses_workers,
        worker_init_fn=worker_init_fn if uses_workers else None,
        drop_last=True,
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def set_epoch(dataloader: DataLoader, epoch: int) -> None:
    """Forward ``epoch`` to the loader's DistributedSampler so each epoch shuffles differently.

    Objects without a ``sampler`` attribute, or whose sampler is not a
    DistributedSampler, are left untouched.
    """
    sampler = getattr(dataloader, "sampler", None)
    if not isinstance(sampler, DistributedSampler):
        return
    sampler.set_epoch(epoch)
|