radiobject 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,219 @@
1
+ """SegmentationDataset - specialized dataset for image/mask segmentation training."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import Dataset
11
+
12
+ from radiobject.ml.config import DatasetConfig, LoadingMode
13
+ from radiobject.ml.utils.validation import validate_collection_alignment, validate_uniform_shapes
14
+
15
+ if TYPE_CHECKING:
16
+ from radiobject.volume_collection import VolumeCollection
17
+
18
+
19
class SegmentationDataset(Dataset):
    """PyTorch Dataset for segmentation training with explicit image/mask separation.

    Pairs an image VolumeCollection with a mask VolumeCollection and yields
    dict samples containing "image" and "mask" tensors (each with a leading
    channel dimension added via unsqueeze), loaded as full volumes, random 3D
    patches, or 2D axial slices depending on the configured LoadingMode.
    """

    def __init__(
        self,
        image: VolumeCollection,
        mask: VolumeCollection,
        config: DatasetConfig | None = None,
        image_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        spatial_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        foreground_sampling: bool = False,
        foreground_threshold: float = 0.01,
        foreground_max_retries: int = 10,
    ):
        """Initialize segmentation dataset.

        Args:
            image: VolumeCollection containing input images (CT, MRI, etc.).
            mask: VolumeCollection containing segmentation masks.
            config: Dataset configuration (loading mode, patch size, etc.).
            image_transform: Transform applied to image only (e.g., normalization).
                Should operate on keys=["image"].
            spatial_transform: Transform applied to both image and mask (e.g., flips).
                Should operate on keys=["image", "mask"].
            foreground_sampling: If True, bias patch sampling toward regions with
                foreground (non-zero mask values).
            foreground_threshold: Minimum fraction of foreground voxels in patch
                when foreground_sampling is enabled.
            foreground_max_retries: Maximum random attempts before accepting any patch.
        """
        self._config = config or DatasetConfig()
        self._image_transform = image_transform
        self._spatial_transform = spatial_transform
        self._foreground_sampling = foreground_sampling
        self._foreground_threshold = foreground_threshold
        self._foreground_max_retries = foreground_max_retries

        # Store collections directly
        self._image = image
        self._mask = mask

        # Cache obs_ids and obs_subject_ids for fast access
        self._obs_ids = image.obs_ids
        self._obs_subject_ids = image.obs_subject_ids

        # Validate alignment between image and mask collections
        collections = {"image": self._image, "mask": self._mask}
        validate_collection_alignment(collections)

        # Validate uniform shapes (also yields the common (X, Y, Z) shape)
        self._volume_shape = validate_uniform_shapes(collections)
        self._n_volumes = len(self._image)

        # Dataset length depends on loading mode: one item per patch, per
        # axial slice, or per volume.
        if self._config.loading_mode == LoadingMode.PATCH:
            self._length = self._n_volumes * self._config.patches_per_volume
        elif self._config.loading_mode == LoadingMode.SLICE_2D:
            self._length = self._n_volumes * self._volume_shape[2]
        else:
            self._length = self._n_volumes

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, idx: int) -> dict[str, Any]:
        if self._config.loading_mode == LoadingMode.PATCH:
            return self._get_patch_item(idx)
        elif self._config.loading_mode == LoadingMode.SLICE_2D:
            return self._get_slice_item(idx)
        else:
            return self._get_full_volume_item(idx)

    @staticmethod
    def _random_start(
        rng: np.random.Generator, max_start: tuple[int, ...]
    ) -> tuple[int, int, int]:
        """Draw a uniform random patch origin in [0, max_start] per axis.

        Returns plain Python ints (not numpy scalars) so the value is
        collate-friendly when surfaced as "patch_start".
        """
        return tuple(
            int(rng.integers(0, max_start[i] + 1)) if max_start[i] > 0 else 0
            for i in range(3)
        )

    def _get_full_volume_item(self, idx: int) -> dict[str, Any]:
        """Load full volume for image and mask."""
        image_data = self._image.iloc[idx].to_numpy()
        mask_data = self._mask.iloc[idx].to_numpy()

        result: dict[str, Any] = {
            # unsqueeze(0) adds the channel dimension expected by 3D models
            "image": torch.from_numpy(image_data).unsqueeze(0),
            "mask": torch.from_numpy(mask_data).unsqueeze(0),
            "idx": idx,
            "obs_id": self._obs_ids[idx],
            "obs_subject_id": self._obs_subject_ids[idx],
        }

        return self._apply_transforms(result)

    def _get_patch_item(self, idx: int) -> dict[str, Any]:
        """Load a random patch from image and mask.

        Raises:
            ValueError: If patch loading mode is active without a patch_size.
        """
        volume_idx = idx // self._config.patches_per_volume
        patch_idx = idx % self._config.patches_per_volume

        patch_size = self._config.patch_size
        # Raise instead of assert: assert is stripped under python -O.
        if patch_size is None:
            raise ValueError("patch_size must be set when loading_mode is PATCH")

        max_start = tuple(max(0, self._volume_shape[i] - patch_size[i]) for i in range(3))

        if self._foreground_sampling:
            # Try to find a patch with sufficient foreground
            start = self._sample_foreground_patch(volume_idx, max_start, patch_size, idx)
        else:
            # Seed by global idx so every (volume, patch) pair is reproducible.
            rng = np.random.default_rng(seed=idx)
            start = self._random_start(rng, max_start)

        image_vol = self._image.iloc[volume_idx]
        mask_vol = self._mask.iloc[volume_idx]

        # Same spatial region for image and mask keeps them aligned.
        region = tuple(slice(start[i], start[i] + patch_size[i]) for i in range(3))
        image_data = image_vol.slice(*region)
        mask_data = mask_vol.slice(*region)

        result: dict[str, Any] = {
            "image": torch.from_numpy(image_data).unsqueeze(0),
            "mask": torch.from_numpy(mask_data).unsqueeze(0),
            "idx": volume_idx,
            "patch_idx": patch_idx,
            "patch_start": start,
            "obs_id": self._obs_ids[volume_idx],
            "obs_subject_id": self._obs_subject_ids[volume_idx],
        }

        return self._apply_transforms(result)

    def _sample_foreground_patch(
        self,
        volume_idx: int,
        max_start: tuple[int, ...],
        patch_size: tuple[int, int, int],
        seed: int,
    ) -> tuple[int, int, int]:
        """Sample a patch position biased toward foreground regions.

        Draws up to ``foreground_max_retries`` random positions and returns the
        first whose mask patch has a non-zero voxel fraction of at least
        ``foreground_threshold``; otherwise returns the last sampled position.
        If ``foreground_max_retries <= 0``, returns the origin (0, 0, 0),
        which is always a valid start.
        """
        rng = np.random.default_rng(seed=seed)
        mask_vol = self._mask.iloc[volume_idx]

        # Fix: previously `start` was unbound (UnboundLocalError) when
        # foreground_max_retries <= 0 because the loop body never ran.
        start: tuple[int, int, int] = (0, 0, 0)

        for _ in range(self._foreground_max_retries):
            start = self._random_start(rng, max_start)

            mask_patch = mask_vol.slice(
                slice(start[0], start[0] + patch_size[0]),
                slice(start[1], start[1] + patch_size[1]),
                slice(start[2], start[2] + patch_size[2]),
            )
            # Guard against an empty patch to avoid ZeroDivisionError.
            if mask_patch.size == 0:
                continue
            foreground_ratio = np.count_nonzero(mask_patch) / mask_patch.size

            if foreground_ratio >= self._foreground_threshold:
                return start

        # Fallback: no sampled patch met the threshold.
        return start

    def _get_slice_item(self, idx: int) -> dict[str, Any]:
        """Load a 2D axial slice from image and mask."""
        volume_idx = idx // self._volume_shape[2]
        slice_idx = idx % self._volume_shape[2]

        image_data = self._image.iloc[volume_idx].axial(slice_idx)
        mask_data = self._mask.iloc[volume_idx].axial(slice_idx)

        result: dict[str, Any] = {
            "image": torch.from_numpy(image_data).unsqueeze(0),
            "mask": torch.from_numpy(mask_data).unsqueeze(0),
            "idx": volume_idx,
            "slice_idx": slice_idx,
            "obs_id": self._obs_ids[volume_idx],
            "obs_subject_id": self._obs_subject_ids[volume_idx],
        }

        return self._apply_transforms(result)

    def _apply_transforms(self, result: dict[str, Any]) -> dict[str, Any]:
        """Apply spatial and image transforms to sample dict.

        Order matters: the spatial transform runs first so geometric changes
        hit image and mask identically, then the intensity transform runs on
        the image only.
        """
        # Spatial transform affects both image and mask
        if self._spatial_transform is not None:
            result = self._spatial_transform(result)

        # Image transform affects only image
        if self._image_transform is not None:
            result = self._image_transform(result)

        return result

    @property
    def volume_shape(self) -> tuple[int, int, int]:
        """Shape of each volume (X, Y, Z)."""
        return self._volume_shape

    @property
    def n_volumes(self) -> int:
        """Number of image/mask pairs."""
        return self._n_volumes
@@ -0,0 +1,233 @@
1
+ """Core RadiObjectDataset implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any, Callable
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+ from radiobject.ml.config import DatasetConfig, LoadingMode
12
+
13
+ if TYPE_CHECKING:
14
+ from radiobject.radi_object import RadiObject
15
+ from radiobject.volume_collection import VolumeCollection
16
+
17
+
18
class RadiObjectDataset(Dataset):
    """PyTorch Dataset for loading volumes from RadiObject via TileDB.

    Stacks one or more modalities along a channel axis and yields dict
    samples with an "image" tensor, loaded as full volumes, random 3D
    patches, or 2D axial slices depending on the configured LoadingMode.
    """

    def __init__(
        self,
        radi_object: RadiObject,
        config: DatasetConfig,
        transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    ):
        """Initialize the dataset.

        Args:
            radi_object: Source RadiObject whose collections are read.
            config: Dataset configuration (modalities, loading mode, patch
                size, label column, value filter).
            transform: Optional transform applied to every sample dict.

        Raises:
            ValueError: If no modalities are available, a collection has
                heterogeneous shapes, subjects are misaligned across
                modalities, or the configured label column is missing.
        """
        self._config = config
        self._transform = transform

        modalities = config.modalities or list(radi_object.collection_names)
        if not modalities:
            raise ValueError("No modalities specified and RadiObject has no collections")

        self._modalities = modalities
        self._collections: dict[str, VolumeCollection] = {
            mod: radi_object.collection(mod) for mod in modalities
        }

        # Validate all collections have uniform shapes for batched loading
        for mod in modalities:
            coll = self._collections[mod]
            if not coll.is_uniform:
                raise ValueError(
                    f"Collection '{mod}' has heterogeneous shapes. "
                    f"Call collection.resample_to() to normalize dimensions before ML training."
                )

        first_coll = self._collections[modalities[0]]
        self._n_volumes = len(first_coll)
        self._volume_shape = first_coll.shape  # Guaranteed non-None after uniform check

        if len(modalities) > 1:
            self._validate_subject_alignment()

        self._labels: dict[int, int | float] | None = None
        if config.label_column:
            self._load_labels(radi_object, config.label_column, config.value_filter)

        # Dataset length: one item per patch, per axial slice, or per volume.
        if config.loading_mode == LoadingMode.PATCH:
            self._length = self._n_volumes * config.patches_per_volume
        elif config.loading_mode == LoadingMode.SLICE_2D:
            self._length = self._n_volumes * self._volume_shape[2]
        else:
            self._length = self._n_volumes

    def _validate_subject_alignment(self) -> None:
        """Validate that all modalities have matching subjects.

        Raises:
            ValueError: On a volume-count mismatch or a subject-set mismatch
                relative to the first modality.
        """
        first_mod = self._modalities[0]
        first_coll = self._collections[first_mod]

        first_subjects = set(first_coll.obs_subject_ids)

        for mod in self._modalities[1:]:
            coll = self._collections[mod]
            if len(coll) != self._n_volumes:
                raise ValueError(
                    f"Modality '{mod}' has {len(coll)} volumes, expected {self._n_volumes}"
                )

            mod_subjects = set(coll.obs_subject_ids)

            if mod_subjects != first_subjects:
                missing = first_subjects - mod_subjects
                extra = mod_subjects - first_subjects
                # Show at most three ids per side to keep the message short.
                raise ValueError(
                    f"Subject mismatch for modality '{mod}': "
                    f"missing={list(missing)[:3]}, extra={list(extra)[:3]}"
                )

    @staticmethod
    def _match_label_row(obs_meta: Any, obs_id: str) -> Any:
        """Find the obs_meta row(s) for an observation id.

        Tries, in order: exact obs_id match, obs_subject_id match, then a
        legacy fallback that parses obs_id as "<subject_id>_<suffix>".
        Returns an empty selection when nothing matches.
        """
        match = obs_meta[obs_meta["obs_id"] == obs_id]
        if len(match) == 0:
            match = obs_meta[obs_meta["obs_subject_id"] == obs_id]
        if len(match) == 0:
            parts = obs_id.rsplit("_", 1)
            if len(parts) > 1:
                match = obs_meta[obs_meta["obs_subject_id"] == parts[0]]
        return match

    def _load_labels(
        self,
        radi_object: RadiObject,
        label_column: str,
        value_filter: str | None,
    ) -> None:
        """Load labels from obs_meta dataframe into self._labels.

        Volumes with no matching obs_meta row are simply left unlabeled.

        Raises:
            ValueError: If label_column is absent from obs_meta.
        """
        obs_meta = radi_object.obs_meta.read(value_filter=value_filter)
        if label_column not in obs_meta.columns:
            raise ValueError(f"Label column '{label_column}' not found in obs_meta")

        first_coll = self._collections[self._modalities[0]]
        obs_ids = first_coll.obs_ids
        self._labels = {}
        for idx in range(self._n_volumes):
            match = self._match_label_row(obs_meta, obs_ids[idx])
            if len(match) > 0:
                self._labels[idx] = match[label_column].iloc[0]

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        if self._config.loading_mode == LoadingMode.PATCH:
            return self._get_patch_item(idx)
        elif self._config.loading_mode == LoadingMode.SLICE_2D:
            return self._get_slice_item(idx)
        else:
            return self._get_full_volume_item(idx)

    def _finalize_sample(
        self, result: dict[str, Any], volume_idx: int
    ) -> dict[str, Any]:
        """Attach the label for volume_idx (if loaded) and apply the transform.

        Numeric labels become 0-d tensors so default collation stacks them;
        any other label type passes through unchanged. Shared by all three
        loading-mode item getters (previously triplicated inline).
        """
        if self._labels is not None and volume_idx in self._labels:
            label = self._labels[volume_idx]
            if isinstance(label, (int, float, np.integer, np.floating)):
                result["label"] = torch.tensor(label)
            else:
                result["label"] = label

        if self._transform is not None:
            result = self._transform(result)

        return result

    def _get_full_volume_item(self, idx: int) -> dict[str, torch.Tensor]:
        """Load full volume for all modalities, stacked on a channel axis."""
        volumes = [self._collections[mod].iloc[idx].to_numpy() for mod in self._modalities]

        stacked = np.stack(volumes, axis=0)
        result: dict[str, Any] = {
            "image": torch.from_numpy(stacked),
            "idx": idx,
        }

        return self._finalize_sample(result, idx)

    def _get_patch_item(self, idx: int) -> dict[str, torch.Tensor]:
        """Load a random patch (same region across modalities).

        Raises:
            ValueError: If patch loading mode is active without a patch_size.
        """
        volume_idx = idx // self._config.patches_per_volume
        patch_idx = idx % self._config.patches_per_volume

        # Seed by global idx so every (volume, patch) pair is reproducible.
        rng = np.random.default_rng(seed=idx)
        patch_size = self._config.patch_size
        # Raise instead of assert: assert is stripped under python -O.
        if patch_size is None:
            raise ValueError("patch_size must be set when loading_mode is PATCH")

        max_start = tuple(max(0, self._volume_shape[i] - patch_size[i]) for i in range(3))
        # Plain ints (not numpy scalars) keep "patch_start" collate-friendly.
        start = tuple(
            int(rng.integers(0, max_start[i] + 1)) if max_start[i] > 0 else 0
            for i in range(3)
        )

        region = tuple(slice(start[i], start[i] + patch_size[i]) for i in range(3))
        volumes = []
        for mod in self._modalities:
            vol = self._collections[mod].iloc[volume_idx]
            volumes.append(vol.slice(*region))

        stacked = np.stack(volumes, axis=0)
        result: dict[str, Any] = {
            "image": torch.from_numpy(stacked),
            "idx": volume_idx,
            "patch_idx": patch_idx,
            "patch_start": start,
        }

        return self._finalize_sample(result, volume_idx)

    def _get_slice_item(self, idx: int) -> dict[str, torch.Tensor]:
        """Load a 2D axial slice from all modalities, stacked on a channel axis."""
        volume_idx = idx // self._volume_shape[2]
        slice_idx = idx % self._volume_shape[2]

        slices = [
            self._collections[mod].iloc[volume_idx].axial(slice_idx) for mod in self._modalities
        ]

        stacked = np.stack(slices, axis=0)
        result: dict[str, Any] = {
            "image": torch.from_numpy(stacked),
            "idx": volume_idx,
            "slice_idx": slice_idx,
        }

        return self._finalize_sample(result, volume_idx)

    @property
    def modalities(self) -> list[str]:
        """List of modalities being loaded."""
        return self._modalities

    @property
    def volume_shape(self) -> tuple[int, int, int]:
        """Shape of each volume (X, Y, Z)."""
        return self._volume_shape
@@ -0,0 +1,82 @@
1
+ """Distributed training utilities for DDP."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+ from typing import TYPE_CHECKING, Any, Sequence
7
+
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from radiobject._types import LabelSource
12
+ from radiobject.ml.config import DatasetConfig, LoadingMode
13
+ from radiobject.ml.datasets.collection_dataset import VolumeCollectionDataset
14
+ from radiobject.ml.utils.worker_init import worker_init_fn
15
+
16
+ if TYPE_CHECKING:
17
+ from radiobject.volume_collection import VolumeCollection
18
+
19
+
20
def create_distributed_dataloader(
    collections: VolumeCollection | Sequence[VolumeCollection],
    rank: int,
    world_size: int,
    labels: LabelSource = None,
    batch_size: int = 4,
    patch_size: tuple[int, int, int] | None = None,
    num_workers: int = 4,
    pin_memory: bool = True,
    transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
) -> DataLoader:
    """Build a DDP-ready DataLoader over one or more VolumeCollections.

    Wraps a VolumeCollectionDataset in a DistributedSampler so each rank
    sees a disjoint shard; patch mode is selected automatically whenever a
    patch_size is given.

    Args:
        collections: Single VolumeCollection or list for multi-modal training.
        rank: Current process rank.
        world_size: Total number of processes.
        labels: Label source (see create_training_dataloader for options).
        batch_size: Samples per batch per GPU.
        patch_size: If provided, extract random patches.
        num_workers: DataLoader worker processes.
        pin_memory: Pin tensors to CUDA memory.
        transform: Transform function.

    Returns:
        DataLoader with DistributedSampler configured.
    """
    # Patch mode iff a patch size was requested; otherwise full volumes.
    mode = LoadingMode.PATCH if patch_size else LoadingMode.FULL_VOLUME

    dataset = VolumeCollectionDataset(
        collections,
        config=DatasetConfig(
            loading_mode=mode,
            patch_size=patch_size,
            patches_per_volume=1,
        ),
        labels=labels,
        transform=transform,
    )

    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )

    # Clamp negative worker counts to zero (single-process loading).
    workers = max(num_workers, 0)
    use_workers = workers > 0

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=workers,
        pin_memory=pin_memory and use_workers,
        worker_init_fn=worker_init_fn if use_workers else None,
        drop_last=True,
    )
77
+
78
+
79
def set_epoch(dataloader: DataLoader, epoch: int) -> None:
    """Set epoch for DistributedSampler to ensure proper shuffling.

    No-op when the loader has no sampler or a non-distributed one.
    """
    sampler = getattr(dataloader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(epoch)