radiobject-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,249 @@
+ """Factory functions for creating training dataloaders."""
+
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from typing import TYPE_CHECKING, Any, Sequence
+
+ from torch.utils.data import DataLoader
+
+ from radiobject._types import LabelSource
+ from radiobject.ml.config import DatasetConfig, LoadingMode
+ from radiobject.ml.datasets.collection_dataset import VolumeCollectionDataset
+ from radiobject.ml.datasets.segmentation_dataset import SegmentationDataset
+ from radiobject.ml.utils.worker_init import worker_init_fn
+
+ if TYPE_CHECKING:
+     from radiobject.volume_collection import VolumeCollection
+
+
+ def create_training_dataloader(
+     collections: VolumeCollection | Sequence[VolumeCollection],
+     labels: LabelSource = None,
+     batch_size: int = 4,
+     patch_size: tuple[int, int, int] | None = None,
+     num_workers: int = 4,
+     pin_memory: bool = True,
+     persistent_workers: bool = True,
+     transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+     patches_per_volume: int = 1,
+ ) -> DataLoader:
+     """Create a DataLoader configured for training from VolumeCollection(s).
+
+     Args:
+         collections: Single VolumeCollection or list for multi-modal training.
+             Multi-modal collections are stacked along channel dimension.
+         labels: Label source. Can be:
+             - str: Column name in collection's obs DataFrame
+             - pd.DataFrame: With obs_id as column/index and label values
+             - dict[str, Any]: Mapping from obs_id to label
+             - Callable[[str], Any]: Function taking obs_id, returning label
+             - None: No labels
+         batch_size: Samples per batch.
+         patch_size: If provided, extract random patches of this size.
+         num_workers: DataLoader worker processes.
+         pin_memory: Pin tensors to CUDA memory.
+         persistent_workers: Keep workers alive between epochs.
+         transform: Transform function applied to each sample.
+             MONAI dict transforms (e.g., RandFlipd) work directly.
+         patches_per_volume: Number of patches to extract per volume per epoch.
+
+     Returns:
+         DataLoader configured for training with shuffle enabled.
+     """
+     loading_mode = LoadingMode.PATCH if patch_size else LoadingMode.FULL_VOLUME
+
+     config = DatasetConfig(
+         loading_mode=loading_mode,
+         patch_size=patch_size,
+         patches_per_volume=patches_per_volume,
+     )
+
+     dataset = VolumeCollectionDataset(
+         collections, config=config, labels=labels, transform=transform
+     )
+
+     effective_workers = num_workers if num_workers > 0 else 0
+     effective_persistent = persistent_workers and effective_workers > 0
+
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=effective_workers,
+         pin_memory=pin_memory and effective_workers > 0,
+         persistent_workers=effective_persistent,
+         worker_init_fn=worker_init_fn if effective_workers > 0 else None,
+         drop_last=True,
+     )
+
+
+ def create_validation_dataloader(
+     collections: VolumeCollection | Sequence[VolumeCollection],
+     labels: LabelSource = None,
+     batch_size: int = 4,
+     patch_size: tuple[int, int, int] | None = None,
+     num_workers: int = 4,
+     pin_memory: bool = True,
+     transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+ ) -> DataLoader:
+     """Create a DataLoader configured for validation (no shuffle, no drop_last).
+
+     Args:
+         collections: Single VolumeCollection or list for multi-modal validation.
+         labels: Label source (see create_training_dataloader for options).
+         batch_size: Samples per batch.
+         patch_size: If provided, extract patches of this size.
+         num_workers: DataLoader worker processes.
+         pin_memory: Pin tensors to CUDA memory.
+         transform: Transform function applied to each sample.
+             MONAI dict transforms work directly.
+
+     Returns:
+         DataLoader configured for validation.
+
+     Example::
+
+         from monai.transforms import Compose, NormalizeIntensityd
+
+         transform = Compose([NormalizeIntensityd(keys="image")])
+         loader = create_validation_dataloader(radi.CT, labels="has_tumor", transform=transform)
+     """
+     loading_mode = LoadingMode.PATCH if patch_size else LoadingMode.FULL_VOLUME
+
+     config = DatasetConfig(
+         loading_mode=loading_mode,
+         patch_size=patch_size,
+         patches_per_volume=1,
+     )
+
+     dataset = VolumeCollectionDataset(
+         collections, config=config, labels=labels, transform=transform
+     )
+
+     effective_workers = num_workers if num_workers > 0 else 0
+
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=effective_workers,
+         pin_memory=pin_memory and effective_workers > 0,
+         worker_init_fn=worker_init_fn if effective_workers > 0 else None,
+         drop_last=False,
+     )
+
+
+ def create_inference_dataloader(
+     collections: VolumeCollection | Sequence[VolumeCollection],
+     batch_size: int = 1,
+     num_workers: int = 4,
+     pin_memory: bool = True,
+     transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+ ) -> DataLoader:
+     """Create a DataLoader configured for inference (full volumes, no shuffle).
+
+     Args:
+         collections: Single VolumeCollection or list for multi-modal inference.
+         batch_size: Samples per batch.
+         num_workers: DataLoader worker processes.
+         pin_memory: Pin tensors to CUDA memory.
+         transform: Transform function applied to each sample.
+
+     Returns:
+         DataLoader configured for inference.
+
+     Example::
+
+         from monai.transforms import NormalizeIntensityd
+
+         transform = NormalizeIntensityd(keys="image")
+         loader = create_inference_dataloader(radi.CT, transform=transform)
+     """
+     config = DatasetConfig(loading_mode=LoadingMode.FULL_VOLUME)
+
+     dataset = VolumeCollectionDataset(collections, config=config, transform=transform)
+
+     effective_workers = num_workers if num_workers > 0 else 0
+
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=effective_workers,
+         pin_memory=pin_memory and effective_workers > 0,
+         worker_init_fn=worker_init_fn if effective_workers > 0 else None,
+     )
+
+
+ def create_segmentation_dataloader(
+     image: VolumeCollection,
+     mask: VolumeCollection,
+     batch_size: int = 4,
+     patch_size: tuple[int, int, int] | None = None,
+     num_workers: int = 4,
+     pin_memory: bool = True,
+     persistent_workers: bool = True,
+     image_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+     spatial_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+     foreground_sampling: bool = False,
+     foreground_threshold: float = 0.01,
+     patches_per_volume: int = 1,
+ ) -> DataLoader:
+     """Create a DataLoader for segmentation training with separate image/mask handling.
+
+     Unlike create_training_dataloader which stacks collections as channels, this
+     returns separate "image" and "mask" tensors. This is cleaner for segmentation
+     workflows where different transforms need to be applied to images vs masks.
+
+     Args:
+         image: VolumeCollection containing input images (CT, MRI, etc.).
+         mask: VolumeCollection containing segmentation masks.
+         batch_size: Samples per batch.
+         patch_size: If provided, extract random patches of this size.
+         num_workers: DataLoader worker processes.
+         pin_memory: Pin tensors to CUDA memory.
+         persistent_workers: Keep workers alive between epochs.
+         image_transform: Transform applied only to "image" key (e.g., normalization).
+         spatial_transform: Transform applied to both "image" and "mask" keys
+             (e.g., random flips, rotations).
+         foreground_sampling: If True, bias patch sampling toward regions with
+             foreground (non-zero mask values).
+         foreground_threshold: Minimum fraction of foreground voxels in patch
+             when foreground_sampling is enabled.
+         patches_per_volume: Number of patches to extract per volume per epoch.
+
+     Returns:
+         DataLoader yielding {"image": (B,1,X,Y,Z), "mask": (B,1,X,Y,Z), ...}
+     """
+     loading_mode = LoadingMode.PATCH if patch_size else LoadingMode.FULL_VOLUME
+
+     config = DatasetConfig(
+         loading_mode=loading_mode,
+         patch_size=patch_size,
+         patches_per_volume=patches_per_volume,
+     )
+
+     dataset = SegmentationDataset(
+         image=image,
+         mask=mask,
+         config=config,
+         image_transform=image_transform,
+         spatial_transform=spatial_transform,
+         foreground_sampling=foreground_sampling,
+         foreground_threshold=foreground_threshold,
+     )
+
+     effective_workers = num_workers if num_workers > 0 else 0
+     effective_persistent = persistent_workers and effective_workers > 0
+
+     return DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=effective_workers,
+         pin_memory=pin_memory and effective_workers > 0,
+         persistent_workers=effective_persistent,
+         worker_init_fn=worker_init_fn if effective_workers > 0 else None,
+         drop_last=True,
+     )
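
Usage sketch for the factory functions above. Only the create_* signatures and the MONAI transform pattern come from the diff; the factory module's file name is not shown (the import path below is an assumption), and `radi` stands in for an already-opened radiobject handle with CT and SEG collections, following the `radi.CT` shorthand used in the docstrings. Treat this as an illustration of the intended call pattern, not documented API:

    # Hypothetical usage sketch -- the module path and the `radi` handle are assumptions.
    from monai.transforms import Compose, NormalizeIntensityd, RandFlipd

    from radiobject.ml.dataloaders import (  # assumed path; file name not shown in this diff
        create_segmentation_dataloader,
        create_training_dataloader,
    )

    # Patch-based classification training, labels read from the obs column "has_tumor".
    train_loader = create_training_dataloader(
        radi.CT,  # an opened VolumeCollection, as in the docstring examples
        labels="has_tumor",
        batch_size=2,
        patch_size=(96, 96, 96),
        patches_per_volume=4,
        transform=Compose([NormalizeIntensityd(keys="image"), RandFlipd(keys="image", prob=0.5)]),
    )

    # Segmentation training with separate image/mask tensors and foreground-biased patches.
    seg_loader = create_segmentation_dataloader(
        image=radi.CT,
        mask=radi.SEG,  # hypothetical mask collection
        patch_size=(96, 96, 96),
        image_transform=NormalizeIntensityd(keys="image"),
        spatial_transform=RandFlipd(keys=["image", "mask"], prob=0.5),
        foreground_sampling=True,
    )
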
@@ -0,0 +1,13 @@
+ """ML utilities."""
+
+ from radiobject.ml.utils.labels import LabelSource, load_labels
+ from radiobject.ml.utils.validation import validate_collection_alignment, validate_uniform_shapes
+ from radiobject.ml.utils.worker_init import worker_init_fn
+
+ __all__ = [
+     "LabelSource",
+     "load_labels",
+     "validate_collection_alignment",
+     "validate_uniform_shapes",
+     "worker_init_fn",
+ ]
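
Assuming this file is the radiobject/ml/utils/__init__.py that the import paths elsewhere in the diff point at, the re-exports let the helpers be imported from the subpackage directly:

    from radiobject.ml.utils import load_labels, validate_uniform_shapes, worker_init_fn
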
@@ -0,0 +1,106 @@
+ """Label loading utilities for ML datasets."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any
+
+ import pandas as pd
+
+ from radiobject._types import LabelSource
+
+ __all__ = ["LabelSource", "load_labels"]
+
+ if TYPE_CHECKING:
+     from radiobject.volume_collection import VolumeCollection
+
+
+ def load_labels(
+     collection: VolumeCollection,
+     labels: LabelSource,
+     obs_df: pd.DataFrame | None = None,
+ ) -> dict[int, Any] | None:
+     """Load labels from various sources, indexed by volume position.
+
+     Args:
+         collection: VolumeCollection for the primary collection (used to get obs_ids).
+         labels: Label source - see LabelSource type for options.
+         obs_df: Pre-loaded obs DataFrame from the collection. Required when
+             labels is a column name string.
+
+     Returns:
+         Dict mapping volume index to label value, or None if labels is None.
+
+     Raises:
+         ValueError: If label column not found or DataFrame missing required columns.
+     """
+     if labels is None:
+         return None
+
+     obs_ids = collection.obs_ids
+     n_volumes = len(obs_ids)
+     result: dict[int, Any] = {}
+
+     if isinstance(labels, str):
+         # Column name in obs DataFrame
+         if obs_df is None:
+             raise ValueError(
+                 "obs_df required when labels is a column name. "
+                 "Pass collection.obs.read() as obs_df."
+             )
+
+         if labels not in obs_df.columns:
+             raise ValueError(f"Label column '{labels}' not found in obs DataFrame")
+
+         # Build lookup by obs_id
+         if "obs_id" in obs_df.columns:
+             label_lookup = dict(zip(obs_df["obs_id"], obs_df[labels]))
+         elif obs_df.index.name == "obs_id":
+             label_lookup = obs_df[labels].to_dict()
+         else:
+             raise ValueError("obs DataFrame must have 'obs_id' column or index")
+
+         for idx in range(n_volumes):
+             obs_id = obs_ids[idx]
+             if obs_id in label_lookup:
+                 result[idx] = label_lookup[obs_id]
+
+     elif isinstance(labels, pd.DataFrame):
+         # DataFrame with obs_id mapping
+         if "obs_id" in labels.columns:
+             # obs_id as column - use first non-obs_id column as label
+             label_cols = [c for c in labels.columns if c != "obs_id"]
+             if not label_cols:
+                 raise ValueError("Labels DataFrame must have at least one label column")
+             label_col = label_cols[0]
+             label_lookup = dict(zip(labels["obs_id"], labels[label_col]))
+         elif labels.index.name == "obs_id" or labels.index.dtype == object:
+             # obs_id as index
+             label_col = labels.columns[0]
+             label_lookup = labels[label_col].to_dict()
+         else:
+             raise ValueError("Labels DataFrame must have 'obs_id' as column or index")
+
+         for idx in range(n_volumes):
+             obs_id = obs_ids[idx]
+             if obs_id in label_lookup:
+                 result[idx] = label_lookup[obs_id]
+
+     elif isinstance(labels, dict):
+         # Direct mapping from obs_id to label
+         for idx in range(n_volumes):
+             obs_id = obs_ids[idx]
+             if obs_id in labels:
+                 result[idx] = labels[obs_id]
+
+     elif callable(labels):
+         # Function that takes obs_id and returns label
+         for idx in range(n_volumes):
+             obs_id = obs_ids[idx]
+             result[idx] = labels(obs_id)
+
+     else:
+         raise TypeError(
+             f"labels must be str, DataFrame, dict, callable, or None, got {type(labels)}"
+         )
+
+     return result if result else None
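
A minimal sketch of how load_labels resolves the different label sources. The FakeCollection stub below is purely illustrative, exposing only the obs_ids attribute that load_labels actually reads; a real VolumeCollection would come from the package itself:

    import pandas as pd

    from radiobject.ml.utils.labels import load_labels


    class FakeCollection:
        """Illustrative stand-in with only the attribute load_labels reads."""

        obs_ids = ["sub-01", "sub-02", "sub-03"]


    coll = FakeCollection()

    # 1. Column name: requires the pre-loaded obs DataFrame.
    obs_df = pd.DataFrame({"obs_id": ["sub-01", "sub-02", "sub-03"], "has_tumor": [1, 0, 1]})
    print(load_labels(coll, "has_tumor", obs_df=obs_df))  # {0: 1, 1: 0, 2: 1}

    # 2. Dict keyed by obs_id: ids missing from the dict are silently skipped.
    print(load_labels(coll, {"sub-01": 1, "sub-03": 0}))  # {0: 1, 2: 0}

    # 3. Callable: applied to every obs_id in order.
    print(load_labels(coll, lambda obs_id: int(obs_id.endswith("1"))))  # {0: 1, 1: 0, 2: 0}
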
@@ -0,0 +1,85 @@
+ """Validation utilities for ML datasets."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from radiobject.volume_collection import VolumeCollection
+
+
+ def validate_collection_alignment(collections: dict[str, VolumeCollection]) -> None:
+     """Validate all collections have matching subjects by obs_subject_id.
+
+     For multi-modal training, volumes from different collections must correspond
+     to the same subjects. This validates alignment using the obs_subject_id field
+     directly (no string parsing).
+
+     Args:
+         collections: Dict mapping collection names to VolumeCollection instances.
+
+     Raises:
+         ValueError: If collections have different volume counts or mismatched subjects.
+     """
+     if len(collections) < 2:
+         return
+
+     names = list(collections.keys())
+     first_name = names[0]
+     first_coll = collections[first_name]
+     n_volumes = len(first_coll)
+
+     first_subjects = set(first_coll.obs_subject_ids)
+
+     for name in names[1:]:
+         coll = collections[name]
+         if len(coll) != n_volumes:
+             raise ValueError(f"Collection '{name}' has {len(coll)} volumes, expected {n_volumes}")
+
+         mod_subjects = set(coll.obs_subject_ids)
+
+         if mod_subjects != first_subjects:
+             missing = first_subjects - mod_subjects
+             extra = mod_subjects - first_subjects
+             raise ValueError(
+                 f"Subject mismatch for collection '{name}': "
+                 f"missing={list(missing)[:3]}, extra={list(extra)[:3]}"
+             )
+
+
+ def validate_uniform_shapes(collections: dict[str, VolumeCollection]) -> tuple[int, int, int]:
+     """Validate all collections have uniform shapes and return the common shape.
+
+     Args:
+         collections: Dict mapping collection names to VolumeCollection instances.
+
+     Returns:
+         Common volume shape (X, Y, Z).
+
+     Raises:
+         ValueError: If any collection has non-uniform shapes or shapes don't match.
+     """
+     shape: tuple[int, int, int] | None = None
+
+     for name, coll in collections.items():
+         if not coll.is_uniform:
+             raise ValueError(
+                 f"Collection '{name}' has heterogeneous shapes. "
+                 f"Resample to uniform dimensions before ML training."
+             )
+
+         coll_shape = coll.shape
+         if coll_shape is None:
+             raise ValueError(f"Collection '{name}' has no shape metadata.")
+
+         if shape is None:
+             shape = coll_shape
+         elif coll_shape != shape:
+             raise ValueError(
+                 f"Shape mismatch: collection '{name}' has shape {coll_shape}, " f"expected {shape}"
+             )
+
+     if shape is None:
+         raise ValueError("No collections provided")
+
+     return shape
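
A sketch of the two validators against lightweight stand-ins. The StubCollection class fakes only the attributes these functions read (__len__, obs_subject_ids, is_uniform, shape) and is not part of the package:

    from radiobject.ml.utils.validation import (
        validate_collection_alignment,
        validate_uniform_shapes,
    )


    class StubCollection:
        """Illustrative stand-in exposing only what the validators touch."""

        def __init__(self, subjects, shape, uniform=True):
            self.obs_subject_ids = subjects
            self.shape = shape
            self.is_uniform = uniform

        def __len__(self):
            return len(self.obs_subject_ids)


    ct = StubCollection(["sub-01", "sub-02"], (128, 128, 64))
    pet = StubCollection(["sub-01", "sub-02"], (128, 128, 64))

    validate_collection_alignment({"CT": ct, "PET": pet})  # passes silently
    print(validate_uniform_shapes({"CT": ct, "PET": pet}))  # (128, 128, 64)

    # A mismatched subject set raises ValueError listing up to three offending ids.
    bad = StubCollection(["sub-01", "sub-99"], (128, 128, 64))
    try:
        validate_collection_alignment({"CT": ct, "PET": bad})
    except ValueError as err:
        print(err)
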
@@ -0,0 +1,10 @@
+ """Worker initialization for DataLoader multiprocessing."""
+
+ from __future__ import annotations
+
+ from radiobject.parallel import create_worker_ctx
+
+
+ def worker_init_fn(worker_id: int) -> None:
+     """Initialize TileDB context for each DataLoader worker."""
+     _ = create_worker_ctx()
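
The factories above pass this hook to every DataLoader they build whenever num_workers > 0; a hand-rolled DataLoader can do the same. The TensorDataset below is only a placeholder for one of the package's dataset classes:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from radiobject.ml.utils.worker_init import worker_init_fn

    dataset = TensorDataset(torch.zeros(8, 1, 4, 4, 4))  # placeholder dataset

    loader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=2,
        # Each worker process gets its own TileDB context before it loads any data.
        worker_init_fn=worker_init_fn,
    )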