radiobject-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1182 @@
"""VolumeCollection - organizes volumes with consistent dimensions indexed by obs_id."""

from __future__ import annotations

from collections.abc import Callable
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence, overload

import numpy as np
import numpy.typing as npt
import pandas as pd
import tiledb

from radiobject._types import TransformFn
from radiobject.ctx import ctx as global_ctx
from radiobject.dataframe import Dataframe
from radiobject.imaging_metadata import (
    DicomMetadata,
    NiftiMetadata,
    extract_dicom_metadata,
    extract_nifti_metadata,
    infer_series_type,
)
from radiobject.indexing import Index
from radiobject.parallel import WriteResult, create_worker_ctx, map_on_threads
from radiobject.volume import Volume

if TYPE_CHECKING:
    from radiobject.ml.datasets.collection_dataset import VolumeCollectionDataset
    from radiobject.query import CollectionQuery


def _normalize_index(idx: int, length: int) -> int:
    """Convert negative index to positive and validate bounds."""
    if idx < 0:
        idx = length + idx
    if idx < 0 or idx >= length:
        raise IndexError(f"Index {idx} out of range [0, {length})")
    return idx


def generate_obs_id(obs_subject_id: str, series_type: str) -> str:
    """Generate a unique obs_id from subject ID and series type."""
    return f"{obs_subject_id}_{series_type}"
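
# Usage sketch (illustrative, not part of the shipped file):
#
#   generate_obs_id("sub-101", "T1w")   # -> "sub-101_T1w"
#   _normalize_index(-1, 10)            # -> 9
#   _normalize_index(10, 10)            # raises IndexError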


def _write_volumes_parallel(
    write_fn,
    write_args: list,
    progress: bool,
    desc: str,
) -> list[WriteResult]:
    """Common helper for parallel volume writes with error handling."""
    results = map_on_threads(write_fn, write_args, progress=progress, desc=desc)
    failures = [r for r in results if not r.success]
    if failures:
        raise RuntimeError(f"Volume write failed: {failures[0].error}")
    return results


class _ILocIndexer:
    """Integer-location based indexer for VolumeCollection (like pandas .iloc)."""

    def __init__(self, collection: VolumeCollection):
        self._collection = collection

    def _get_volume(self, obs_id: str) -> Volume:
        """Construct Volume on-demand by obs_id."""
        root = self._collection._root
        idx = root._index.get_index(obs_id)
        return Volume(f"{root.uri}/volumes/{idx}", ctx=root._ctx)

    @overload
    def __getitem__(self, key: int) -> Volume: ...
    @overload
    def __getitem__(self, key: slice) -> VolumeCollection: ...
    @overload
    def __getitem__(self, key: list[int]) -> VolumeCollection: ...
    @overload
    def __getitem__(self, key: npt.NDArray[np.bool_]) -> VolumeCollection: ...

    def __getitem__(
        self, key: int | slice | list[int] | npt.NDArray[np.bool_]
    ) -> Volume | VolumeCollection:
        """Index by int, slice, list of ints, or boolean mask."""
        obs_ids = self._collection._effective_obs_ids
        n = len(obs_ids)
        if isinstance(key, int):
            idx = _normalize_index(key, n)
            return self._get_volume(obs_ids[idx])

        elif isinstance(key, slice):
            indices = list(range(*key.indices(n)))
            selected_ids = frozenset(obs_ids[i] for i in indices)
            return self._collection._create_view(volume_ids=selected_ids)

        elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
            if len(key) != n:
                raise ValueError(f"Boolean mask length {len(key)} != volume count {n}")
            indices = np.where(key)[0]
            selected_ids = frozenset(obs_ids[int(i)] for i in indices)
            return self._collection._create_view(volume_ids=selected_ids)

        elif isinstance(key, list):
            selected_ids = frozenset(obs_ids[_normalize_index(i, n)] for i in key)
            return self._collection._create_view(volume_ids=selected_ids)

        raise TypeError(
            f"iloc indices must be int, slice, list[int], or boolean array, got {type(key)}"
        )
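
# Usage sketch (illustrative; the URI below is an assumption, not shipped with
# the package). Integer indexing returns a single Volume; slices, int lists,
# and boolean masks return filtered views without copying voxel data:
#
#   vc = VolumeCollection("data/ct_collection")
#   vol = vc.iloc[0]                  # Volume at position 0
#   view = vc.iloc[2:10]              # slice -> VolumeCollection view
#   mask = np.zeros(len(vc), dtype=bool)
#   mask[:3] = True
#   subset = vc.iloc[mask]            # boolean mask -> view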


class _LocIndexer:
    """Label-based indexer for VolumeCollection (like pandas .loc)."""

    def __init__(self, collection: VolumeCollection):
        self._collection = collection

    def _get_volume(self, obs_id: str) -> Volume:
        """Construct Volume on-demand by obs_id."""
        root = self._collection._root
        idx = root._index.get_index(obs_id)
        return Volume(f"{root.uri}/volumes/{idx}", ctx=root._ctx)

    @overload
    def __getitem__(self, key: str) -> Volume: ...
    @overload
    def __getitem__(self, key: list[str]) -> VolumeCollection: ...

    def __getitem__(self, key: str | list[str]) -> Volume | VolumeCollection:
        """Index by obs_id string or list of obs_id strings."""
        if isinstance(key, str):
            # Validate the key is in the effective set
            if self._collection.is_view and key not in self._collection._volume_ids:
                raise KeyError(f"obs_id '{key}' not in view")
            return self._get_volume(key)

        elif isinstance(key, list):
            selected_ids = frozenset(key)
            # Validate all keys are in the effective set
            if self._collection.is_view:
                invalid = selected_ids - self._collection._volume_ids
                if invalid:
                    raise KeyError(f"obs_ids not in view: {sorted(invalid)[:5]}")
            return self._collection._create_view(volume_ids=selected_ids)

        raise TypeError(f"loc indices must be str or list[str], got {type(key)}")
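
# Usage sketch (illustrative; the obs_ids are assumptions). Label indexing
# looks up volumes by obs_id; on views, keys outside the filter raise KeyError:
#
#   vol = vc.loc["sub-101_CT"]                    # single Volume by obs_id
#   pair = vc.loc[["sub-101_CT", "sub-102_CT"]]   # list of obs_ids -> view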


class VolumeCollection:
    """TileDB-backed volume collection indexed by obs_id. Supports uniform or heterogeneous shapes."""

    def __init__(
        self,
        uri: str | None,
        ctx: tiledb.Ctx | None = None,
        *,
        _source: VolumeCollection | None = None,
        _volume_ids: frozenset[str] | None = None,
    ):
        self._uri: str | None = uri
        self._ctx: tiledb.Ctx | None = ctx
        self._source: VolumeCollection | None = _source
        self._volume_ids: frozenset[str] | None = _volume_ids

    @property
    def uri(self) -> str:
        """URI of the underlying storage (raises if view without storage)."""
        if self._uri is not None:
            return self._uri
        if self._source is not None:
            return self._source.uri
        raise ValueError("VolumeCollection view has no URI. Call materialize(uri) first.")

    @property
    def is_view(self) -> bool:
        """True if this VolumeCollection is a filtered view of another."""
        return self._source is not None

    @property
    def _root(self) -> VolumeCollection:
        """The original attached VolumeCollection (follows source chain)."""
        return self._source._root if self._source else self

    def _check_not_view(self, operation: str) -> None:
        """Raise if this is a view (views are immutable)."""
        if self.is_view:
            raise TypeError(f"Cannot {operation} on a view. Call materialize(uri) first.")

    def _create_view(self, volume_ids: frozenset[str]) -> VolumeCollection:
        """Create a view with the given volume IDs, intersecting with current filter."""
        if self._volume_ids is not None:
            volume_ids = self._volume_ids & volume_ids
        return VolumeCollection(
            uri=None,
            ctx=self._ctx,
            _source=self._root,
            _volume_ids=volume_ids,
        )
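
    # Note on view semantics (restating the code above, not new behavior):
    # because _create_view intersects with any existing filter, chained
    # selections only narrow, never widen:
    #
    #   a = vc.iloc[:100]                   # view of the first 100 volumes
    #   b = a.filter("modality == 'CT'")    # subset of those 100
    #   b.is_view                           # True; b.materialize(uri) detaches it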

    @property
    def _effective_obs_ids(self) -> list[str]:
        """Get the list of obs_ids for this collection (filtered if view)."""
        root = self._root
        all_ids = list(root._index.keys)
        if self._volume_ids is not None:
            return [obs_id for obs_id in all_ids if obs_id in self._volume_ids]
        return all_ids

    def _effective_ctx(self) -> tiledb.Ctx:
        return self._ctx if self._ctx else global_ctx()

    def copy(self) -> VolumeCollection:
        """Create detached copy of this collection (views remain views)."""
        if self.is_view:
            return VolumeCollection(
                uri=None,
                ctx=self._ctx,
                _source=self._root,
                _volume_ids=self._volume_ids,
            )
        return VolumeCollection(self._uri, ctx=self._ctx)

    @cached_property
    def iloc(self) -> _ILocIndexer:
        """Integer-location based indexing for selecting volumes by position."""
        return _ILocIndexer(self)

    @cached_property
    def loc(self) -> _LocIndexer:
        """Label-based indexing for selecting volumes by obs_id."""
        return _LocIndexer(self)

    @property
    def obs(self) -> Dataframe:
        """Observational metadata per volume."""
        obs_uri = f"{self._root.uri}/obs"
        return Dataframe(uri=obs_uri, ctx=self._ctx)

    @cached_property
    def _metadata(self) -> dict:
        """Cached group metadata."""
        with tiledb.Group(self._root.uri, "r", ctx=self._effective_ctx()) as grp:
            return dict(grp.meta)

    @cached_property
    def _index(self) -> Index:
        """Cached bidirectional index for obs_id lookups."""
        n = self._metadata["n_volumes"]
        if n == 0:
            return Index.build([])
        obs_data = self.obs.read()
        return Index.build(list(obs_data["obs_id"]))

    @property
    def index(self) -> Index:
        """Volume index for bidirectional ID/position lookups."""
        return self._index

    @property
    def name(self) -> str | None:
        """Collection name (if set during creation)."""
        return self._metadata.get("name")

    @property
    def shape(self) -> tuple[int, int, int] | None:
        """Volume dimensions (X, Y, Z) if uniform, None if heterogeneous."""
        m = self._metadata
        if "x_dim" not in m or "y_dim" not in m or "z_dim" not in m:
            return None
        return (int(m["x_dim"]), int(m["y_dim"]), int(m["z_dim"]))

    @property
    def is_uniform(self) -> bool:
        """Whether all volumes in this collection have the same shape."""
        return self.shape is not None

    def __len__(self) -> int:
        """Number of volumes in collection (respects view filter)."""
        if self._volume_ids is not None:
            return len(self._volume_ids)
        return int(self._metadata["n_volumes"])

    def __iter__(self):
        """Iterate over volumes in index order (respects view filter)."""
        for obs_id in self._effective_obs_ids:
            yield self.loc[obs_id]

    def __repr__(self) -> str:
        """Concise representation of the VolumeCollection."""
        shape = self.shape
        shape_str = "x".join(str(d) for d in shape) if shape else "heterogeneous"
        name_part = f"'{self.name}', " if self.name else ""
        view_part = ", view" if self.is_view else ""
        return f"VolumeCollection({name_part}{len(self)} volumes, shape={shape_str}{view_part})"

    @property
    def obs_ids(self) -> list[str]:
        """All obs_id values in index order (respects view filter)."""
        return self._effective_obs_ids

    @property
    def obs_subject_ids(self) -> list[str]:
        """Get obs_subject_id values for this collection (respects view filter)."""
        obs_df = self.obs.read()
        if self._volume_ids is not None:
            obs_df = obs_df[obs_df["obs_id"].isin(self._volume_ids)]
        # Maintain order consistent with obs_ids
        effective_ids = self._effective_obs_ids
        id_to_subject = dict(zip(obs_df["obs_id"], obs_df["obs_subject_id"]))
        return [id_to_subject[obs_id] for obs_id in effective_ids]

    def get_obs_row_by_obs_id(self, obs_id: str) -> pd.DataFrame:
        """Get observation row by obs_id string identifier."""
        df = self.obs.read()
        filtered = df[df["obs_id"] == obs_id].reset_index(drop=True)
        return filtered
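
    # Usage sketch (illustrative): the collection behaves like a sized,
    # iterable container keyed by obs_id, with per-volume metadata in `obs`:
    #
    #   len(vc)             # volume count (filtered count on views)
    #   for vol in vc: ...  # yields Volume objects in index order
    #   vc.obs.read()       # pandas DataFrame of per-volume metadata
    #   repr(vc)            # e.g. "VolumeCollection('CT', 20 volumes, shape=512x512x64)"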

    @overload
    def __getitem__(self, key: int) -> Volume: ...
    @overload
    def __getitem__(self, key: str) -> Volume: ...
    @overload
    def __getitem__(self, key: slice) -> VolumeCollection: ...
    @overload
    def __getitem__(self, key: list[int]) -> VolumeCollection: ...
    @overload
    def __getitem__(self, key: list[str]) -> VolumeCollection: ...

    def __getitem__(
        self, key: int | str | slice | list[int] | list[str]
    ) -> Volume | VolumeCollection:
        """Index by int, str, slice, or list. Slices/lists return views."""
        if isinstance(key, int):
            return self.iloc[key]
        elif isinstance(key, str):
            return self.loc[key]
        elif isinstance(key, slice):
            return self.iloc[key]
        elif isinstance(key, list):
            if len(key) == 0:
                return self._create_view(volume_ids=frozenset())
            if isinstance(key[0], int):
                return self.iloc[key]
            elif isinstance(key[0], str):
                return self.loc[key]
        raise TypeError(f"Key must be int, str, slice, or list, got {type(key)}")
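
    # Dispatch summary (illustrative): vc[key] routes ints and slices to iloc
    # and strings to loc, inferring list element type from the first item:
    #
    #   vc[0]               # same as vc.iloc[0]
    #   vc["sub-101_CT"]    # same as vc.loc["sub-101_CT"]
    #   vc[1:5]; vc[[0, 2]]; vc[["sub-101_CT"]]   # all return views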

    def validate(self) -> None:
        """Validate internal consistency of obs vs volume metadata."""
        self._check_not_view("validate")

        obs_data = self.obs.read()
        obs_ids_in_dataframe = set(obs_data["obs_id"])

        # Check each volume's obs_id against obs dataframe
        obs_ids_in_volumes = set()
        for i in range(len(self)):
            vol = self.iloc[i]
            if vol.obs_id is None:
                raise ValueError(f"Volume at index {i} lacks required obs_id metadata")
            obs_ids_in_volumes.add(vol.obs_id)

            expected_obs_id = obs_data.iloc[i]["obs_id"]
            if vol.obs_id != expected_obs_id:
                raise ValueError(
                    f"Position mismatch at index {i}: "
                    f"volume.obs_id={vol.obs_id}, obs.iloc[{i}]={expected_obs_id}"
                )

        missing_in_obs = obs_ids_in_volumes - obs_ids_in_dataframe
        if missing_in_obs:
            raise ValueError(f"Volumes without obs rows: {list(missing_in_obs)[:5]}")

        orphan_obs = obs_ids_in_dataframe - obs_ids_in_volumes
        if orphan_obs:
            raise ValueError(f"Obs rows without volumes: {list(orphan_obs)[:5]}")

        with tiledb.Group(f"{self._root.uri}/volumes", "r", ctx=self._effective_ctx()) as grp:
            actual_count = len(list(grp))
        if actual_count != self._metadata["n_volumes"]:
            raise ValueError(
                f"n_volumes mismatch: metadata={self._metadata['n_volumes']}, actual={actual_count}"
            )

    # ===== Lazy Mode (Transform Pipelines) =====

    def lazy(self) -> CollectionQuery:
        """Enter lazy mode for transform pipelines via map()."""
        from radiobject.query import CollectionQuery

        return CollectionQuery(self._root, volume_ids=self._volume_ids)

    # ===== Filtering Methods (Return Views) =====

    def head(self, n: int = 5) -> VolumeCollection:
        """Return view of first n volumes."""
        return self.iloc[:n]

    def tail(self, n: int = 5) -> VolumeCollection:
        """Return view of last n volumes."""
        total = len(self)
        return self.iloc[max(0, total - n) :]

    def sample(self, n: int = 5, seed: int | None = None) -> VolumeCollection:
        """Return view of n randomly sampled volumes."""
        rng = np.random.default_rng(seed)
        obs_ids = self._effective_obs_ids
        n = min(n, len(obs_ids))
        sampled = rng.choice(obs_ids, size=n, replace=False)
        return self._create_view(volume_ids=frozenset(sampled))
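
    # Usage sketch (illustrative): these helpers are thin wrappers over iloc
    # views, so they are cheap regardless of volume size:
    #
    #   vc.head(3)             # first 3 volumes (view)
    #   vc.tail(3)             # last 3 volumes (view)
    #   vc.sample(5, seed=0)   # reproducible random subset (view)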

    def filter(self, expr: str) -> VolumeCollection:
        """Filter volumes using TileDB QueryCondition on obs. Returns view."""
        matching_ids = self._resolve_filter(expr)
        return self._create_view(volume_ids=matching_ids)

    def _resolve_filter(self, expr: str) -> frozenset[str]:
        """Resolve filter expression to set of matching obs_ids."""
        effective_ctx = self._effective_ctx()
        obs_uri = f"{self._root.uri}/obs"

        with tiledb.open(obs_uri, "r", ctx=effective_ctx) as arr:
            result = arr.query(cond=expr, dims=["obs_id"])[:]
            obs_ids = result["obs_id"]
            matching = frozenset(v.decode() if isinstance(v, bytes) else str(v) for v in obs_ids)

        # Intersect with current view filter
        if self._volume_ids is not None:
            matching = matching & self._volume_ids

        return matching
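
    # Usage sketch (illustrative; the column names are assumptions about your
    # obs schema). `expr` is a TileDB QueryCondition string evaluated against
    # the obs array:
    #
    #   ct_only = vc.filter("modality == 'CT'")
    #   strong = vc.filter("magnetic_field_strength >= 3.0")
    #   both = ct_only.filter("kvp > 100")   # filters compose by intersection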

    def map(self, fn: TransformFn) -> CollectionQuery:
        """Apply transform to all volumes during materialization. Returns lazy query."""
        return self.lazy().map(fn)
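
    # Usage sketch (illustrative; `clip_hu` is hypothetical, and it assumes
    # the returned CollectionQuery exposes a materialize(uri) step, as
    # VolumeCollection itself does). map() defers work: the transform runs
    # once per volume when the pipeline is materialized:
    #
    #   def clip_hu(arr):                  # TransformFn: ndarray -> ndarray
    #       return np.clip(arr, -1000, 1000)
    #
    #   vc.map(clip_hu).materialize("data/ct_clipped")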

    # ===== Materialization =====

    def materialize(
        self,
        uri: str,
        name: str | None = None,
        ctx: tiledb.Ctx | None = None,
    ) -> VolumeCollection:
        """Write this collection (or view) to new storage.

        Creates a new VolumeCollection at the target URI containing all volumes
        in this view. For views, only the filtered volumes are written.
        """
        from radiobject.streaming import StreamingWriter

        obs_ids = self._effective_obs_ids
        if not obs_ids:
            raise ValueError("No volumes to materialize")

        # Get obs DataFrame for this view
        obs_df = self.obs.read()
        if self._volume_ids is not None:
            obs_df = obs_df[obs_df["obs_id"].isin(self._volume_ids)].reset_index(drop=True)

        collection_name = name or self._root.name

        # Build obs schema from source
        obs_schema: dict[str, np.dtype] = {}
        for col in self._root.obs.columns:
            if col in ("obs_id", "obs_subject_id"):
                continue
            obs_schema[col] = self._root.obs.dtypes[col]

        effective_ctx = ctx if ctx else self._effective_ctx()

        with StreamingWriter(
            uri=uri,
            shape=self._root.shape,
            obs_schema=obs_schema,
            name=collection_name,
            ctx=effective_ctx,
        ) as writer:
            for obs_id in obs_ids:
                vol = self.loc[obs_id]
                data = vol.to_numpy()

                obs_row = obs_df[obs_df["obs_id"] == obs_id].iloc[0]
                attrs = {k: v for k, v in obs_row.items() if k not in ("obs_id", "obs_subject_id")}
                writer.write_volume(
                    data=data,
                    obs_id=obs_id,
                    obs_subject_id=obs_row["obs_subject_id"],
                    **attrs,
                )

        return VolumeCollection(uri, ctx=effective_ctx)
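
    # Usage sketch (illustrative URIs): materialize() turns a filtered view
    # into standalone storage; volumes are re-read and re-written one at a
    # time through StreamingWriter:
    #
    #   view = vc.filter("modality == 'CT'").head(10)
    #   small = view.materialize("data/ct_subset", name="ct_subset")
    #   small.is_view   # False - a fully detached collection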

    # ===== ML Integration =====

    def to_dataset(
        self,
        patch_size: tuple[int, int, int] | None = None,
        labels: pd.DataFrame | dict | str | None = None,
        transform: Callable[..., Any] | None = None,
    ) -> VolumeCollectionDataset:
        """Create PyTorch Dataset from this collection.

        Convenience method for ML training integration.

        Args:
            patch_size: If provided, extract random patches of this size.
            labels: Label source. Can be:
                - str: Column name in this collection's obs DataFrame
                - pd.DataFrame: With obs_id as column/index and label values
                - dict[str, Any]: Mapping from obs_id to label
                - None: No labels
            transform: Transform function applied to each sample.
                MONAI dict transforms (e.g., RandFlipd) work directly.

        Returns:
            VolumeCollectionDataset ready for use with DataLoader.

        Example::

            # Full volumes with labels from obs column
            dataset = radi.CT.to_dataset(labels="has_tumor")

            # Patch extraction
            dataset = radi.CT.to_dataset(patch_size=(64, 64, 64), labels="grade")

            # With MONAI transforms
            from monai.transforms import NormalizeIntensityd
            dataset = radi.CT.to_dataset(
                labels="has_tumor",
                transform=NormalizeIntensityd(keys="image"),
            )
        """
        from radiobject.ml.config import DatasetConfig, LoadingMode
        from radiobject.ml.datasets.collection_dataset import VolumeCollectionDataset

        loading_mode = LoadingMode.PATCH if patch_size else LoadingMode.FULL_VOLUME
        config = DatasetConfig(loading_mode=loading_mode, patch_size=patch_size)

        return VolumeCollectionDataset(self, config=config, labels=labels, transform=transform)
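
    # Usage sketch (illustrative; "has_tumor" is an assumed obs column). The
    # returned dataset plugs into a standard torch DataLoader:
    #
    #   from torch.utils.data import DataLoader
    #   ds = vc.to_dataset(patch_size=(64, 64, 64), labels="has_tumor")
    #   loader = DataLoader(ds, batch_size=4, num_workers=2)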

    # ===== Append Operations =====

    def append(
        self,
        niftis: Sequence[tuple[str | Path, str]] | None = None,
        dicom_dirs: Sequence[tuple[str | Path, str]] | None = None,
        reorient: bool | None = None,
        progress: bool = False,
    ) -> None:
        """Append new volumes atomically.

        Volume data and obs metadata are written together to maintain consistency.
        Cannot be called on views - use materialize() first.

        Args:
            niftis: List of (nifti_path, obs_subject_id) tuples
            dicom_dirs: List of (dicom_dir, obs_subject_id) tuples
            reorient: Reorient to canonical orientation (None uses config default)
            progress: Show tqdm progress bar during volume writes

        Example:
            radi.T1w.append(
                niftis=[
                    ("sub101_T1w.nii.gz", "sub-101"),
                    ("sub102_T1w.nii.gz", "sub-102"),
                ],
            )
        """
        self._check_not_view("append")

        if niftis is None and dicom_dirs is None:
            raise ValueError("Must provide either niftis or dicom_dirs")
        if niftis is not None and dicom_dirs is not None:
            raise ValueError("Cannot provide both niftis and dicom_dirs")

        effective_ctx = self._effective_ctx()
        current_count = len(self)

        if niftis is not None:
            self._append_niftis(niftis, reorient, effective_ctx, current_count, progress)
        else:
            self._append_dicoms(dicom_dirs, reorient, effective_ctx, current_count, progress)

        # Invalidate cached properties
        for prop in ("_index", "_metadata"):
            if prop in self.__dict__:
                del self.__dict__[prop]

    def _append_niftis(
        self,
        niftis: Sequence[tuple[str | Path, str]],
        reorient: bool | None,
        effective_ctx: tiledb.Ctx,
        start_index: int,
        progress: bool = False,
    ) -> None:
        """Internal: append NIfTI files to this collection."""
        # Extract metadata (no dimension validation for heterogeneous collections)
        metadata_list: list[tuple[Path, str, NiftiMetadata, str]] = []

        for nifti_path, obs_subject_id in niftis:
            path = Path(nifti_path)
            if not path.exists():
                raise FileNotFoundError(f"NIfTI file not found: {path}")

            metadata = extract_nifti_metadata(path)
            series_type = infer_series_type(path)

            # Only validate dimensions if collection has uniform shape requirement
            if self.is_uniform and metadata.dimensions != self.shape:
                raise ValueError(
                    f"Dimension mismatch: {path.name} has shape {metadata.dimensions}, "
                    f"expected {self.shape}"
                )

            metadata_list.append((path, obs_subject_id, metadata, series_type))

        # Check for duplicate obs_ids
        existing_obs_ids = set(self.obs_ids)
        new_obs_ids = {generate_obs_id(sid, st) for _, sid, _, st in metadata_list}
        duplicates = existing_obs_ids & new_obs_ids
        if duplicates:
            raise ValueError(f"obs_ids already exist: {sorted(duplicates)[:5]}")

        # Write volumes
        def write_volume(args: tuple[int, Path, str, NiftiMetadata, str]) -> WriteResult:
            idx, path, obs_subject_id, metadata, series_type = args
            worker_ctx = create_worker_ctx(self._ctx)
            volume_uri = f"{self.uri}/volumes/{idx}"
            obs_id = generate_obs_id(obs_subject_id, series_type)
            try:
                vol = Volume.from_nifti(volume_uri, path, ctx=worker_ctx, reorient=reorient)
                vol.set_obs_id(obs_id)
                return WriteResult(idx, volume_uri, obs_id, success=True)
            except Exception as e:
                return WriteResult(idx, volume_uri, obs_id, success=False, error=e)

        write_args = [
            (start_index + i, path, sid, meta, st)
            for i, (path, sid, meta, st) in enumerate(metadata_list)
        ]
        results = _write_volumes_parallel(
            write_volume, write_args, progress, f"Writing {self.name or 'volumes'}"
        )

        # Register volumes with group
        with tiledb.Group(f"{self.uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
            for result in results:
                vol_grp.add(result.uri, name=str(result.index))

        # Build and write obs rows
        obs_rows: list[dict] = []
        for path, obs_subject_id, metadata, series_type in metadata_list:
            obs_id = generate_obs_id(obs_subject_id, series_type)
            obs_rows.append(metadata.to_obs_dict(obs_id, obs_subject_id, series_type))

        obs_df = pd.DataFrame(obs_rows)
        obs_uri = f"{self.uri}/obs"
        obs_subject_ids = obs_df["obs_subject_id"].astype(str).to_numpy()
        obs_ids = obs_df["obs_id"].astype(str).to_numpy()

        # Only write attributes that exist in the target schema
        existing_columns = set(self.obs.columns)
        with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
            attr_data = {
                col: obs_df[col].to_numpy()
                for col in obs_df.columns
                if col not in ("obs_subject_id", "obs_id") and col in existing_columns
            }
            arr[obs_subject_ids, obs_ids] = attr_data

        # Update n_volumes metadata
        new_count = start_index + len(niftis)
        with tiledb.Group(self.uri, "w", ctx=effective_ctx) as grp:
            grp.meta["n_volumes"] = new_count

    def _append_dicoms(
        self,
        dicom_dirs: Sequence[tuple[str | Path, str]],
        reorient: bool | None,
        effective_ctx: tiledb.Ctx,
        start_index: int,
        progress: bool = False,
    ) -> None:
        """Internal: append DICOM series to this collection."""
        metadata_list: list[tuple[Path, str, DicomMetadata]] = []

        for dicom_dir, obs_subject_id in dicom_dirs:
            path = Path(dicom_dir)
            if not path.exists():
                raise FileNotFoundError(f"DICOM directory not found: {path}")

            metadata = extract_dicom_metadata(path)
            dims = metadata.dimensions
            shape = (dims[1], dims[0], dims[2])

            # Only validate dimensions if collection has uniform shape requirement
            if self.is_uniform and shape != self.shape:
                raise ValueError(
                    f"Dimension mismatch: {path.name} has shape {shape}, expected {self.shape}"
                )

            metadata_list.append((path, obs_subject_id, metadata))

        # Check for duplicate obs_ids
        existing_obs_ids = set(self.obs_ids)
        new_obs_ids = {generate_obs_id(sid, meta.modality) for _, sid, meta in metadata_list}
        duplicates = existing_obs_ids & new_obs_ids
        if duplicates:
            raise ValueError(f"obs_ids already exist: {sorted(duplicates)[:5]}")

        # Write volumes
        def write_volume(args: tuple[int, Path, str, DicomMetadata]) -> WriteResult:
            idx, path, obs_subject_id, metadata = args
            worker_ctx = create_worker_ctx(self._ctx)
            volume_uri = f"{self.uri}/volumes/{idx}"
            obs_id = generate_obs_id(obs_subject_id, metadata.modality)
            try:
                vol = Volume.from_dicom(volume_uri, path, ctx=worker_ctx, reorient=reorient)
                vol.set_obs_id(obs_id)
                return WriteResult(idx, volume_uri, obs_id, success=True)
            except Exception as e:
                return WriteResult(idx, volume_uri, obs_id, success=False, error=e)

        write_args = [
            (start_index + i, path, sid, meta) for i, (path, sid, meta) in enumerate(metadata_list)
        ]
        results = _write_volumes_parallel(
            write_volume, write_args, progress, f"Writing {self.name or 'volumes'}"
        )

        # Register volumes with group
        with tiledb.Group(f"{self.uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
            for result in results:
                vol_grp.add(result.uri, name=str(result.index))

        # Build and write obs rows
        obs_rows: list[dict] = []
        for path, obs_subject_id, metadata in metadata_list:
            obs_id = generate_obs_id(obs_subject_id, metadata.modality)
            obs_rows.append(metadata.to_obs_dict(obs_id, obs_subject_id))

        obs_df = pd.DataFrame(obs_rows)
        for col in ["kvp", "exposure", "repetition_time", "echo_time", "magnetic_field_strength"]:
            if col in obs_df.columns:
                obs_df[col] = obs_df[col].fillna(np.nan)

        obs_uri = f"{self.uri}/obs"
        obs_subject_ids = obs_df["obs_subject_id"].astype(str).to_numpy()
        obs_ids = obs_df["obs_id"].astype(str).to_numpy()

        # Only write attributes that exist in the target schema
        existing_columns = set(self.obs.columns)
        with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
            attr_data = {
                col: obs_df[col].to_numpy()
                for col in obs_df.columns
                if col not in ("obs_subject_id", "obs_id") and col in existing_columns
            }
            arr[obs_subject_ids, obs_ids] = attr_data

        # Update n_volumes metadata
        new_count = start_index + len(dicom_dirs)
        with tiledb.Group(self.uri, "w", ctx=effective_ctx) as grp:
            grp.meta["n_volumes"] = new_count

    @classmethod
    def _create(
        cls,
        uri: str,
        shape: tuple[int, int, int] | None = None,
        obs_schema: dict[str, np.dtype] | None = None,
        n_volumes: int = 0,
        name: str | None = None,
        ctx: tiledb.Ctx | None = None,
    ) -> VolumeCollection:
        """Internal: create empty collection with optional uniform dimensions.

        Args:
            uri: Target URI for the collection
            shape: If provided, enforces uniform dimensions. If None, allows heterogeneous shapes.
            obs_schema: Schema for volume-level obs attributes
            n_volumes: Initial volume count (usually 0)
            name: Collection name
            ctx: TileDB context
        """
        effective_ctx = ctx if ctx else global_ctx()

        tiledb.Group.create(uri, ctx=effective_ctx)

        volumes_uri = f"{uri}/volumes"
        tiledb.Group.create(volumes_uri, ctx=effective_ctx)

        obs_uri = f"{uri}/obs"
        Dataframe.create(obs_uri, schema=obs_schema or {}, ctx=ctx)

        with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
            if shape is not None:
                grp.meta["x_dim"] = shape[0]
                grp.meta["y_dim"] = shape[1]
                grp.meta["z_dim"] = shape[2]
            grp.meta["n_volumes"] = n_volumes
            if name is not None:
                grp.meta["name"] = name
            grp.add(volumes_uri, name="volumes")
            grp.add(obs_uri, name="obs")

        return cls(uri, ctx=ctx)
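
    # Storage layout created above (restating _create, not new behavior):
    #
    #   <uri>/            TileDB group; meta: n_volumes, name, x/y/z_dim (if uniform)
    #   <uri>/volumes/    group holding one array per volume, named "0", "1", ...
    #   <uri>/obs         dataframe keyed by (obs_subject_id, obs_id)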

    @classmethod
    def _from_volumes(
        cls,
        uri: str,
        volumes: Sequence[tuple[str, Volume]],
        obs_data: pd.DataFrame | None = None,
        name: str | None = None,
        ctx: tiledb.Ctx | None = None,
    ) -> VolumeCollection:
        """Internal: create collection from existing volumes (write-once)."""
        if not volumes:
            raise ValueError("At least one volume is required")

        first_shape = volumes[0][1].shape[:3]
        for obs_id, vol in volumes:
            if vol.shape[:3] != first_shape:
                raise ValueError(
                    f"Volume '{obs_id}' has shape {vol.shape[:3]}, expected {first_shape}"
                )

        effective_ctx = ctx if ctx else global_ctx()

        obs_schema = None
        if obs_data is not None:
            obs_schema = {}
            for col in obs_data.columns:
                if col in ("obs_id", "obs_subject_id"):
                    continue
                dtype = obs_data[col].to_numpy().dtype
                if dtype == np.dtype("O"):
                    dtype = np.dtype("U64")
                obs_schema[col] = dtype

        cls._create(
            uri,
            shape=first_shape,
            obs_schema=obs_schema,
            n_volumes=len(volumes),
            name=name,
            ctx=ctx,
        )

        with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
            grp.meta["n_volumes"] = len(volumes)

        def write_volume(args: tuple[int, str, Volume]) -> WriteResult:
            idx, obs_id, vol = args
            worker_ctx = create_worker_ctx(ctx)
            volume_uri = f"{uri}/volumes/{idx}"
            try:
                data = vol.to_numpy()
                new_vol = Volume.from_numpy(volume_uri, data, ctx=worker_ctx)
                new_vol.set_obs_id(obs_id)
                return WriteResult(idx, volume_uri, obs_id, success=True)
            except Exception as e:
                return WriteResult(idx, volume_uri, obs_id, success=False, error=e)

        write_args = [(idx, obs_id, vol) for idx, (obs_id, vol) in enumerate(volumes)]
        results = _write_volumes_parallel(
            write_volume, write_args, progress=False, desc="Writing volumes"
        )

        with tiledb.Group(f"{uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
            for result in results:
                vol_grp.add(result.uri, name=str(result.index))

        obs_ids = np.array([obs_id for obs_id, _ in volumes])
        if obs_data is not None and "obs_subject_id" in obs_data.columns:
            obs_subject_ids = obs_data["obs_subject_id"].astype(str).to_numpy()
        else:
            obs_subject_ids = obs_ids.copy()

        obs_uri = f"{uri}/obs"
        with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
            attr_data = {}
            if obs_data is not None:
                for col in obs_data.columns:
                    if col not in ("obs_id", "obs_subject_id"):
                        attr_data[col] = obs_data[col].to_numpy()
            arr[obs_subject_ids, obs_ids] = attr_data

        return cls(uri, ctx=ctx)

    @classmethod
    def from_niftis(
        cls,
        uri: str,
        niftis: Sequence[tuple[str | Path, str]],
        reorient: bool | None = None,
        validate_dimensions: bool = True,
        valid_subject_ids: set[str] | None = None,
        name: str | None = None,
        ctx: tiledb.Ctx | None = None,
        progress: bool = False,
    ) -> VolumeCollection:
        """Create VolumeCollection from NIfTI files with full metadata capture.

        Args:
            uri: Target URI for the VolumeCollection
            niftis: List of (nifti_path, obs_subject_id) tuples
            reorient: Reorient to canonical orientation (None uses config default)
            validate_dimensions: Raise if dimensions are inconsistent
            valid_subject_ids: Optional whitelist for FK validation
            name: Collection name (stored in metadata)
            ctx: TileDB context
            progress: Show tqdm progress bar during volume writes

        Returns:
            VolumeCollection with obs containing NIfTI metadata
        """
        if not niftis:
            raise ValueError("At least one NIfTI file is required")

        # Validate subject IDs if whitelist provided
        if valid_subject_ids is not None:
            nifti_subject_ids = {sid for _, sid in niftis}
            invalid = nifti_subject_ids - valid_subject_ids
            if invalid:
                raise ValueError(f"Invalid obs_subject_ids: {sorted(invalid)[:5]}")

        # Extract metadata and validate dimensions
        metadata_list: list[tuple[Path, str, NiftiMetadata, str]] = []
        first_shape: tuple[int, int, int] | None = None

        for nifti_path, obs_subject_id in niftis:
            path = Path(nifti_path)
            if not path.exists():
                raise FileNotFoundError(f"NIfTI file not found: {path}")

            metadata = extract_nifti_metadata(path)
            series_type = infer_series_type(path)

            shape = metadata.dimensions
            if first_shape is None:
                first_shape = shape
                all_same_shape = True
            elif shape != first_shape:
                if validate_dimensions:
                    raise ValueError(
                        f"Dimension mismatch: {path.name} has shape {shape}, "
                        f"expected {first_shape}"
                    )
                all_same_shape = False

            metadata_list.append((path, obs_subject_id, metadata, series_type))

        effective_ctx = ctx if ctx else global_ctx()
        # Only set uniform shape if all volumes have same dimensions
        collection_shape = first_shape if all_same_shape else None

        # Build obs schema from NiftiMetadata fields (tuples serialized as strings)
        obs_schema: dict[str, np.dtype] = {
            "series_type": np.dtype("U32"),
            "voxel_spacing": np.dtype("U64"),  # Tuple serialized as string
            "dimensions": np.dtype("U64"),  # Tuple serialized as string
            "datatype": np.int32,
            "bitpix": np.int32,
            "scl_slope": np.float64,
            "scl_inter": np.float64,
            "xyzt_units": np.int32,
            "spatial_units": np.dtype("U16"),
            "qform_code": np.int32,
            "sform_code": np.int32,
            "axcodes": np.dtype("U8"),
            "affine_json": np.dtype("U512"),
            "orientation_source": np.dtype("U32"),
            "source_path": np.dtype("U512"),
        }

        # Create collection
        cls._create(
            uri,
            shape=collection_shape,
            obs_schema=obs_schema,
            n_volumes=len(niftis),
            name=name,
            ctx=ctx,
        )

        # Write volumes in parallel
        def write_volume(args: tuple[int, Path, str, NiftiMetadata, str]) -> WriteResult:
            idx, path, obs_subject_id, metadata, series_type = args
            worker_ctx = create_worker_ctx(ctx)
            volume_uri = f"{uri}/volumes/{idx}"
            obs_id = generate_obs_id(obs_subject_id, series_type)
            try:
                vol = Volume.from_nifti(volume_uri, path, ctx=worker_ctx, reorient=reorient)
                vol.set_obs_id(obs_id)
                return WriteResult(idx, volume_uri, obs_id, success=True)
            except Exception as e:
                return WriteResult(idx, volume_uri, obs_id, success=False, error=e)

        write_args = [
            (idx, path, sid, meta, st) for idx, (path, sid, meta, st) in enumerate(metadata_list)
        ]
        results = _write_volumes_parallel(
            write_volume, write_args, progress, f"Writing {name or 'volumes'}"
        )

        # Register volumes with group
        with tiledb.Group(f"{uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
            for result in results:
                vol_grp.add(result.uri, name=str(result.index))

        # Build obs DataFrame rows
        obs_rows: list[dict] = []
        for path, obs_subject_id, metadata, series_type in metadata_list:
            obs_id = generate_obs_id(obs_subject_id, series_type)
            obs_rows.append(metadata.to_obs_dict(obs_id, obs_subject_id, series_type))

        obs_df = pd.DataFrame(obs_rows)

        # Write obs data
        obs_uri = f"{uri}/obs"
        obs_subject_ids = obs_df["obs_subject_id"].astype(str).to_numpy()
        obs_ids = obs_df["obs_id"].astype(str).to_numpy()
        with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
            attr_data = {
                col: obs_df[col].to_numpy()
                for col in obs_df.columns
                if col not in ("obs_subject_id", "obs_id")
            }
            arr[obs_subject_ids, obs_ids] = attr_data

        return cls(uri, ctx=ctx)
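
    # Usage sketch (illustrative paths): building a collection from NIfTI
    # files; each obs_id becomes "<subject>_<series_type>" with the series
    # type inferred from the filename:
    #
    #   vc = VolumeCollection.from_niftis(
    #       "data/t1_collection",
    #       niftis=[("sub101_T1w.nii.gz", "sub-101"), ("sub102_T1w.nii.gz", "sub-102")],
    #       name="T1w",
    #       progress=True,
    #   )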

    @classmethod
    def from_dicoms(
        cls,
        uri: str,
        dicom_dirs: Sequence[tuple[str | Path, str]],
        reorient: bool | None = None,
        validate_dimensions: bool = True,
        valid_subject_ids: set[str] | None = None,
        name: str | None = None,
        ctx: tiledb.Ctx | None = None,
        progress: bool = False,
    ) -> VolumeCollection:
        """Create VolumeCollection from DICOM series with full metadata capture.

        Args:
            uri: Target URI for the VolumeCollection
            dicom_dirs: List of (dicom_dir, obs_subject_id) tuples
            reorient: Reorient to canonical orientation (None uses config default)
            validate_dimensions: Raise if dimensions are inconsistent
            valid_subject_ids: Optional whitelist for FK validation
            name: Collection name (stored in metadata)
            ctx: TileDB context
            progress: Show tqdm progress bar during volume writes

        Returns:
            VolumeCollection with obs containing DICOM metadata
        """
        if not dicom_dirs:
            raise ValueError("At least one DICOM directory is required")

        # Validate subject IDs if whitelist provided
        if valid_subject_ids is not None:
            dicom_subject_ids = {sid for _, sid in dicom_dirs}
            invalid = dicom_subject_ids - valid_subject_ids
            if invalid:
                raise ValueError(f"Invalid obs_subject_ids: {sorted(invalid)[:5]}")

        # Extract metadata and validate dimensions
        metadata_list: list[tuple[Path, str, DicomMetadata]] = []
        first_shape: tuple[int, int, int] | None = None

        for dicom_dir, obs_subject_id in dicom_dirs:
            path = Path(dicom_dir)
            if not path.exists():
                raise FileNotFoundError(f"DICOM directory not found: {path}")

            metadata = extract_dicom_metadata(path)

            # DICOM dimensions tuple is (rows, columns, n_slices)
            # Swap to (columns, rows, n_slices) to match X, Y, Z convention
            dims = metadata.dimensions
            shape = (dims[1], dims[0], dims[2])
            if first_shape is None:
                first_shape = shape
                all_same_shape = True
            elif shape != first_shape:
                if validate_dimensions:
                    raise ValueError(
                        f"Dimension mismatch: {path.name} has shape {shape}, "
                        f"expected {first_shape}"
                    )
                all_same_shape = False

            metadata_list.append((path, obs_subject_id, metadata))

        effective_ctx = ctx if ctx else global_ctx()
        # Only set uniform shape if all volumes have same dimensions
        collection_shape = first_shape if all_same_shape else None

        # Build obs schema from DicomMetadata fields (tuples serialized as strings)
        obs_schema: dict[str, np.dtype] = {
            "voxel_spacing": np.dtype("U64"),  # Tuple serialized as string
            "dimensions": np.dtype("U64"),  # Tuple serialized as string
            "modality": np.dtype("U16"),
            "series_description": np.dtype("U256"),
            "kvp": np.float64,
            "exposure": np.float64,
            "repetition_time": np.float64,
            "echo_time": np.float64,
            "magnetic_field_strength": np.float64,
            "axcodes": np.dtype("U8"),
            "affine_json": np.dtype("U512"),
            "orientation_source": np.dtype("U32"),
            "source_path": np.dtype("U512"),
        }

        # Create collection
        cls._create(
            uri,
            shape=collection_shape,
            obs_schema=obs_schema,
            n_volumes=len(dicom_dirs),
            name=name,
            ctx=ctx,
        )

        # Write volumes in parallel
        def write_volume(args: tuple[int, Path, str, DicomMetadata]) -> WriteResult:
            idx, path, obs_subject_id, metadata = args
            worker_ctx = create_worker_ctx(ctx)
            volume_uri = f"{uri}/volumes/{idx}"
            obs_id = generate_obs_id(obs_subject_id, metadata.modality)
            try:
                vol = Volume.from_dicom(volume_uri, path, ctx=worker_ctx, reorient=reorient)
                vol.set_obs_id(obs_id)
                return WriteResult(idx, volume_uri, obs_id, success=True)
            except Exception as e:
                return WriteResult(idx, volume_uri, obs_id, success=False, error=e)

        write_args = [(idx, path, sid, meta) for idx, (path, sid, meta) in enumerate(metadata_list)]
        results = _write_volumes_parallel(
            write_volume, write_args, progress, f"Writing {name or 'volumes'}"
        )

        # Register volumes with group
        with tiledb.Group(f"{uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
            for result in results:
                vol_grp.add(result.uri, name=str(result.index))

        # Build obs DataFrame rows
        obs_rows: list[dict] = []
        for path, obs_subject_id, metadata in metadata_list:
            obs_id = generate_obs_id(obs_subject_id, metadata.modality)
            obs_rows.append(metadata.to_obs_dict(obs_id, obs_subject_id))

        obs_df = pd.DataFrame(obs_rows)

        # Handle None values for optional fields
        for col in ["kvp", "exposure", "repetition_time", "echo_time", "magnetic_field_strength"]:
            obs_df[col] = obs_df[col].fillna(np.nan)

        # Write obs data
        obs_uri = f"{uri}/obs"
        obs_subject_ids = obs_df["obs_subject_id"].astype(str).to_numpy()
        obs_ids = obs_df["obs_id"].astype(str).to_numpy()
        with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
            attr_data = {
                col: obs_df[col].to_numpy()
                for col in obs_df.columns
                if col not in ("obs_subject_id", "obs_id")
            }
            arr[obs_subject_ids, obs_ids] = attr_data

        return cls(uri, ctx=ctx)
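
# Usage sketch (illustrative paths): the DICOM constructor mirrors
# from_niftis, deriving obs_ids from "<subject>_<modality>"; pass
# validate_dimensions=False to allow heterogeneous shapes:
#
#   vc = VolumeCollection.from_dicoms(
#       "data/ct_collection",
#       dicom_dirs=[("dicom/sub-101/ct", "sub-101"), ("dicom/sub-102/ct", "sub-102")],
#       name="CT",
#       validate_dimensions=False,
#   )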