radiobject 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
radiobject/ml/factory.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""Factory functions for creating training dataloaders."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
|
7
|
+
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
|
|
10
|
+
from radiobject._types import LabelSource
|
|
11
|
+
from radiobject.ml.config import DatasetConfig, LoadingMode
|
|
12
|
+
from radiobject.ml.datasets.collection_dataset import VolumeCollectionDataset
|
|
13
|
+
from radiobject.ml.datasets.segmentation_dataset import SegmentationDataset
|
|
14
|
+
from radiobject.ml.utils.worker_init import worker_init_fn
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from radiobject.volume_collection import VolumeCollection
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def create_training_dataloader(
    collections: VolumeCollection | Sequence[VolumeCollection],
    labels: LabelSource = None,
    batch_size: int = 4,
    patch_size: tuple[int, int, int] | None = None,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    patches_per_volume: int = 1,
) -> DataLoader:
    """Build a shuffling DataLoader for training over one or more VolumeCollections.

    Args:
        collections: A single VolumeCollection, or a sequence of them for
            multi-modal training (stacked along the channel dimension).
        labels: Where labels come from. One of:
            - str: name of a column in the collection's obs DataFrame
            - pd.DataFrame: obs_id as column/index plus label values
            - dict[str, Any]: obs_id -> label mapping
            - Callable[[str], Any]: called with an obs_id, returns the label
            - None: no labels
        batch_size: Number of samples per batch.
        patch_size: When given, random patches of this size are extracted
            instead of loading full volumes.
        num_workers: Number of DataLoader worker processes. Values below
            zero are treated as zero.
        pin_memory: Pin tensors to CUDA memory. Only takes effect when at
            least one worker process is used.
        persistent_workers: Keep workers alive between epochs. Only takes
            effect when at least one worker process is used.
        transform: Callable applied to each sample dict. MONAI dict
            transforms (e.g., RandFlipd) work directly.
        patches_per_volume: How many patches to draw per volume per epoch.

    Returns:
        A DataLoader with shuffle and drop_last both enabled.
    """
    # A truthy patch_size selects patch sampling; anything else loads whole volumes.
    if patch_size:
        mode = LoadingMode.PATCH
    else:
        mode = LoadingMode.FULL_VOLUME

    dataset = VolumeCollectionDataset(
        collections,
        config=DatasetConfig(
            loading_mode=mode,
            patch_size=patch_size,
            patches_per_volume=patches_per_volume,
        ),
        labels=labels,
        transform=transform,
    )

    # Clamp negative worker counts to zero; worker-dependent options are
    # switched off when loading happens in the main process.
    worker_count = max(num_workers, 0)
    uses_workers = worker_count > 0

    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": worker_count,
        "pin_memory": pin_memory and uses_workers,
        "persistent_workers": persistent_workers and uses_workers,
        "worker_init_fn": worker_init_fn if uses_workers else None,
        "drop_last": True,
    }
    return DataLoader(dataset, **loader_options)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def create_validation_dataloader(
    collections: VolumeCollection | Sequence[VolumeCollection],
    labels: LabelSource = None,
    batch_size: int = 4,
    patch_size: tuple[int, int, int] | None = None,
    num_workers: int = 4,
    pin_memory: bool = True,
    transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
) -> DataLoader:
    """Build a deterministic-order DataLoader for validation.

    Shuffling is off and the final partial batch is kept.

    Args:
        collections: A single VolumeCollection, or a sequence of them for
            multi-modal validation.
        labels: Label source (same options as create_training_dataloader).
        batch_size: Number of samples per batch.
        patch_size: When given, patches of this size are extracted instead
            of loading full volumes. One patch per volume is configured.
        num_workers: Number of DataLoader worker processes. Values below
            zero are treated as zero.
        pin_memory: Pin tensors to CUDA memory. Only takes effect when at
            least one worker process is used.
        transform: Callable applied to each sample dict. MONAI dict
            transforms work directly.

    Returns:
        A DataLoader with shuffle and drop_last both disabled.

    Example::

        from monai.transforms import Compose, NormalizeIntensityd

        transform = Compose([NormalizeIntensityd(keys="image")])
        loader = create_validation_dataloader(radi.CT, labels="has_tumor", transform=transform)
    """
    # A truthy patch_size selects patch sampling; anything else loads whole volumes.
    if patch_size:
        mode = LoadingMode.PATCH
    else:
        mode = LoadingMode.FULL_VOLUME

    dataset = VolumeCollectionDataset(
        collections,
        config=DatasetConfig(
            loading_mode=mode,
            patch_size=patch_size,
            patches_per_volume=1,
        ),
        labels=labels,
        transform=transform,
    )

    # Clamp negative worker counts to zero; worker-dependent options are
    # switched off when loading happens in the main process.
    worker_count = max(num_workers, 0)
    uses_workers = worker_count > 0

    loader_options = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": worker_count,
        "pin_memory": pin_memory and uses_workers,
        "worker_init_fn": worker_init_fn if uses_workers else None,
        "drop_last": False,
    }
    return DataLoader(dataset, **loader_options)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def create_inference_dataloader(
    collections: VolumeCollection | Sequence[VolumeCollection],
    batch_size: int = 1,
    num_workers: int = 4,
    pin_memory: bool = True,
    transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
) -> DataLoader:
    """Build a DataLoader for inference: full volumes, no labels, no shuffling.

    Args:
        collections: A single VolumeCollection, or a sequence of them for
            multi-modal inference.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes. Values below
            zero are treated as zero.
        pin_memory: Pin tensors to CUDA memory. Only takes effect when at
            least one worker process is used.
        transform: Callable applied to each sample dict.

    Returns:
        A DataLoader over full volumes with shuffle disabled.

    Example::

        from monai.transforms import NormalizeIntensityd

        transform = NormalizeIntensityd(keys="image")
        loader = create_inference_dataloader(radi.CT, transform=transform)
    """
    dataset = VolumeCollectionDataset(
        collections,
        config=DatasetConfig(loading_mode=LoadingMode.FULL_VOLUME),
        transform=transform,
    )

    # Clamp negative worker counts to zero; worker-dependent options are
    # switched off when loading happens in the main process.
    worker_count = max(num_workers, 0)
    uses_workers = worker_count > 0

    loader_options = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": worker_count,
        "pin_memory": pin_memory and uses_workers,
        "worker_init_fn": worker_init_fn if uses_workers else None,
    }
    return DataLoader(dataset, **loader_options)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def create_segmentation_dataloader(
    image: VolumeCollection,
    mask: VolumeCollection,
    batch_size: int = 4,
    patch_size: tuple[int, int, int] | None = None,
    num_workers: int = 4,
    pin_memory: bool = True,
    persistent_workers: bool = True,
    image_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    spatial_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
    foreground_sampling: bool = False,
    foreground_threshold: float = 0.01,
    patches_per_volume: int = 1,
) -> DataLoader:
    """Build a shuffling DataLoader for segmentation training.

    In contrast to create_training_dataloader, which stacks collections as
    channels, the dataset built here keeps "image" and "mask" as separate
    tensors so that images and masks can receive different transforms.

    Args:
        image: VolumeCollection holding the input images (CT, MRI, etc.).
        mask: VolumeCollection holding the segmentation masks.
        batch_size: Number of samples per batch.
        patch_size: When given, random patches of this size are extracted
            instead of loading full volumes.
        num_workers: Number of DataLoader worker processes. Values below
            zero are treated as zero.
        pin_memory: Pin tensors to CUDA memory. Only takes effect when at
            least one worker process is used.
        persistent_workers: Keep workers alive between epochs. Only takes
            effect when at least one worker process is used.
        image_transform: Transform applied only to the "image" key
            (e.g., normalization).
        spatial_transform: Transform applied to both "image" and "mask"
            keys (e.g., random flips, rotations).
        foreground_sampling: If True, bias patch sampling toward regions
            with foreground (non-zero mask values).
        foreground_threshold: Minimum fraction of foreground voxels in a
            patch when foreground_sampling is enabled.
        patches_per_volume: How many patches to draw per volume per epoch.

    Returns:
        DataLoader yielding {"image": (B,1,X,Y,Z), "mask": (B,1,X,Y,Z), ...}
    """
    # A truthy patch_size selects patch sampling; anything else loads whole volumes.
    if patch_size:
        mode = LoadingMode.PATCH
    else:
        mode = LoadingMode.FULL_VOLUME

    dataset = SegmentationDataset(
        image=image,
        mask=mask,
        config=DatasetConfig(
            loading_mode=mode,
            patch_size=patch_size,
            patches_per_volume=patches_per_volume,
        ),
        image_transform=image_transform,
        spatial_transform=spatial_transform,
        foreground_sampling=foreground_sampling,
        foreground_threshold=foreground_threshold,
    )

    # Clamp negative worker counts to zero; worker-dependent options are
    # switched off when loading happens in the main process.
    worker_count = max(num_workers, 0)
    uses_workers = worker_count > 0

    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": worker_count,
        "pin_memory": pin_memory and uses_workers,
        "persistent_workers": persistent_workers and uses_workers,
        "worker_init_fn": worker_init_fn if uses_workers else None,
        "drop_last": True,
    }
    return DataLoader(dataset, **loader_options)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""ML utilities."""
|
|
2
|
+
|
|
3
|
+
from radiobject.ml.utils.labels import LabelSource, load_labels
|
|
4
|
+
from radiobject.ml.utils.validation import validate_collection_alignment, validate_uniform_shapes
|
|
5
|
+
from radiobject.ml.utils.worker_init import worker_init_fn
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"LabelSource",
|
|
9
|
+
"load_labels",
|
|
10
|
+
"validate_collection_alignment",
|
|
11
|
+
"validate_uniform_shapes",
|
|
12
|
+
"worker_init_fn",
|
|
13
|
+
]
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Label loading utilities for ML datasets."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from radiobject._types import LabelSource
|
|
10
|
+
|
|
11
|
+
__all__ = ["LabelSource", "load_labels"]
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from radiobject.volume_collection import VolumeCollection
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _labels_by_position(obs_ids: Any, label_lookup: Any) -> dict[int, Any]:
    """Map each volume position to its label, skipping obs_ids missing from the lookup."""
    return {
        position: label_lookup[obs_id]
        for position, obs_id in enumerate(obs_ids)
        if obs_id in label_lookup
    }


def load_labels(
    collection: VolumeCollection,
    labels: LabelSource,
    obs_df: pd.DataFrame | None = None,
) -> dict[int, Any] | None:
    """Load labels from various sources, indexed by volume position.

    Positions follow the order of ``collection.obs_ids``. For the column
    name, DataFrame and dict sources, volumes whose obs_id has no entry are
    left out of the result; a callable source is invoked for every obs_id.

    Args:
        collection: VolumeCollection for the primary collection (only its
            ``obs_ids`` attribute is read here).
        labels: Label source - see LabelSource type for options:
            - str: name of a label column in ``obs_df``
            - pd.DataFrame: obs_id as a column or as the index; the first
              non-obs_id column supplies the label values
            - dict: mapping from obs_id to label
            - callable: called with an obs_id, returns the label
            - None: no labels
        obs_df: Pre-loaded obs DataFrame from the collection. Required when
            labels is a column name string.

    Returns:
        Dict mapping volume index to label value, or None if labels is None
        or no volume received a label.

    Raises:
        ValueError: If obs_df is missing for a column-name source, the label
            column is not found, or a DataFrame lacks obs_id or any label
            column.
        TypeError: If labels is not one of the supported source types.
    """
    if labels is None:
        return None

    obs_ids = collection.obs_ids

    if isinstance(labels, str):
        # Column name in obs DataFrame
        if obs_df is None:
            raise ValueError(
                "obs_df required when labels is a column name. "
                "Pass collection.obs.read() as obs_df."
            )

        if labels not in obs_df.columns:
            raise ValueError(f"Label column '{labels}' not found in obs DataFrame")

        # Build lookup by obs_id, which may live in a column or in the index
        if "obs_id" in obs_df.columns:
            label_lookup = dict(zip(obs_df["obs_id"], obs_df[labels]))
        elif obs_df.index.name == "obs_id":
            label_lookup = obs_df[labels].to_dict()
        else:
            raise ValueError("obs DataFrame must have 'obs_id' column or index")

        result = _labels_by_position(obs_ids, label_lookup)

    elif isinstance(labels, pd.DataFrame):
        # DataFrame with obs_id mapping
        if "obs_id" in labels.columns:
            # obs_id as column - use first non-obs_id column as label
            label_cols = [c for c in labels.columns if c != "obs_id"]
            if not label_cols:
                raise ValueError("Labels DataFrame must have at least one label column")
            label_lookup = dict(zip(labels["obs_id"], labels[label_cols[0]]))
        elif labels.index.name == "obs_id" or labels.index.dtype == object:
            # obs_id as index - use first column as label. Guard against a
            # column-less frame so the failure is the documented ValueError
            # rather than an IndexError from columns[0].
            if len(labels.columns) == 0:
                raise ValueError("Labels DataFrame must have at least one label column")
            label_lookup = labels[labels.columns[0]].to_dict()
        else:
            raise ValueError("Labels DataFrame must have 'obs_id' as column or index")

        result = _labels_by_position(obs_ids, label_lookup)

    elif isinstance(labels, dict):
        # Direct mapping from obs_id to label
        result = _labels_by_position(obs_ids, labels)

    elif callable(labels):
        # Function that takes obs_id and returns label
        result = {position: labels(obs_id) for position, obs_id in enumerate(obs_ids)}

    else:
        raise TypeError(
            f"labels must be str, DataFrame, dict, callable, or None, got {type(labels)}"
        )

    # An empty mapping is reported as "no labels" (None), same as labels=None.
    return result if result else None
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Validation utilities for ML datasets."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from radiobject.volume_collection import VolumeCollection
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def validate_collection_alignment(collections: dict[str, VolumeCollection]) -> None:
    """Check that every collection covers the same set of subjects.

    Multi-modal training needs the volumes of each collection to belong to
    the same subjects. The first collection in the dict is the reference:
    every other collection must have the same number of volumes and the same
    set of ``obs_subject_ids`` (compared as sets, without string parsing).

    Args:
        collections: Dict mapping collection names to VolumeCollection instances.
            With fewer than two entries there is nothing to compare.

    Raises:
        ValueError: If collections have different volume counts or mismatched subjects.
    """
    if len(collections) < 2:
        return

    entries = iter(collections.items())
    _, reference = next(entries)
    expected_count = len(reference)
    expected_subjects = set(reference.obs_subject_ids)

    for name, candidate in entries:
        actual_count = len(candidate)
        if actual_count != expected_count:
            raise ValueError(
                f"Collection '{name}' has {actual_count} volumes, expected {expected_count}"
            )

        subjects = set(candidate.obs_subject_ids)
        if subjects == expected_subjects:
            continue

        # Report at most three offending ids per direction to keep the message short.
        missing = expected_subjects - subjects
        extra = subjects - expected_subjects
        raise ValueError(
            f"Subject mismatch for collection '{name}': "
            f"missing={list(missing)[:3]}, extra={list(extra)[:3]}"
        )
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def validate_uniform_shapes(collections: dict[str, VolumeCollection]) -> tuple[int, int, int]:
    """Confirm that all collections share one volume shape and return it.

    Each collection must report ``is_uniform`` as true and expose a non-None
    ``shape``; the first collection's shape becomes the reference that all
    later collections are compared against.

    Args:
        collections: Dict mapping collection names to VolumeCollection instances.

    Returns:
        Common volume shape (X, Y, Z).

    Raises:
        ValueError: If no collections are given, a collection has non-uniform
            shapes or no shape metadata, or shapes differ between collections.
    """
    reference: tuple[int, int, int] | None = None

    for name, coll in collections.items():
        if not coll.is_uniform:
            raise ValueError(
                f"Collection '{name}' has heterogeneous shapes. "
                "Resample to uniform dimensions before ML training."
            )

        observed = coll.shape
        if observed is None:
            raise ValueError(f"Collection '{name}' has no shape metadata.")

        if reference is None:
            # First collection seen establishes the expected shape.
            reference = observed
            continue

        if observed != reference:
            raise ValueError(
                f"Shape mismatch: collection '{name}' has shape {observed}, expected {reference}"
            )

    # reference is still unset only when the loop never ran.
    if reference is None:
        raise ValueError("No collections provided")

    return reference
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Worker initialization for DataLoader multiprocessing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from radiobject.parallel import create_worker_ctx
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def worker_init_fn(worker_id: int) -> None:
    """Set up a TileDB context inside a freshly started DataLoader worker.

    Args:
        worker_id: Index of the worker, as supplied by the DataLoader.
            It is not used here.
    """
    # Called for its effect only; the returned context is intentionally discarded.
    create_worker_ctx()
|