radiobject 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1665 @@
|
|
|
1
|
+
"""RadiObject - top-level container for multi-collection radiology data."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from collections.abc import Iterator
|
|
7
|
+
from functools import cached_property
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING, Sequence, overload
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import numpy.typing as npt
|
|
13
|
+
import pandas as pd
|
|
14
|
+
import tiledb
|
|
15
|
+
|
|
16
|
+
from radiobject.ctx import ctx as global_ctx
|
|
17
|
+
from radiobject.dataframe import Dataframe
|
|
18
|
+
from radiobject.imaging_metadata import (
|
|
19
|
+
extract_dicom_metadata,
|
|
20
|
+
extract_nifti_metadata,
|
|
21
|
+
infer_series_type,
|
|
22
|
+
)
|
|
23
|
+
from radiobject.indexing import Index
|
|
24
|
+
from radiobject.parallel import WriteResult, create_worker_ctx
|
|
25
|
+
from radiobject.volume import Volume
|
|
26
|
+
from radiobject.volume_collection import (
|
|
27
|
+
VolumeCollection,
|
|
28
|
+
_normalize_index,
|
|
29
|
+
_write_volumes_parallel,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from radiobject.query import Query
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class _SubjectILocIndexer:
|
|
37
|
+
"""Integer-location based indexer for RadiObject subjects."""
|
|
38
|
+
|
|
39
|
+
def __init__(self, radi_object: RadiObject):
|
|
40
|
+
self._radi_object = radi_object
|
|
41
|
+
|
|
42
|
+
@overload
|
|
43
|
+
def __getitem__(self, key: int) -> RadiObject: ...
|
|
44
|
+
@overload
|
|
45
|
+
def __getitem__(self, key: slice) -> RadiObject: ...
|
|
46
|
+
@overload
|
|
47
|
+
def __getitem__(self, key: list[int]) -> RadiObject: ...
|
|
48
|
+
@overload
|
|
49
|
+
def __getitem__(self, key: npt.NDArray[np.bool_]) -> RadiObject: ...
|
|
50
|
+
|
|
51
|
+
def __getitem__(self, key: int | slice | list[int] | npt.NDArray[np.bool_]) -> RadiObject:
|
|
52
|
+
"""Returns a RadiObject view filtered to selected subject indices."""
|
|
53
|
+
n = len(self._radi_object)
|
|
54
|
+
if isinstance(key, int):
|
|
55
|
+
idx = _normalize_index(key, n)
|
|
56
|
+
return self._radi_object._filter_by_indices([idx])
|
|
57
|
+
elif isinstance(key, slice):
|
|
58
|
+
indices = list(range(*key.indices(n)))
|
|
59
|
+
return self._radi_object._filter_by_indices(indices)
|
|
60
|
+
elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
|
|
61
|
+
if len(key) != n:
|
|
62
|
+
raise ValueError(f"Boolean mask length {len(key)} != subject count {n}")
|
|
63
|
+
indices = list(np.where(key)[0])
|
|
64
|
+
return self._radi_object._filter_by_indices(indices)
|
|
65
|
+
elif isinstance(key, list):
|
|
66
|
+
indices = [_normalize_index(i, n) for i in key]
|
|
67
|
+
return self._radi_object._filter_by_indices(indices)
|
|
68
|
+
raise TypeError(
|
|
69
|
+
f"iloc indices must be int, slice, list[int], or boolean array, got {type(key)}"
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class _SubjectLocIndexer:
|
|
74
|
+
"""Label-based indexer for RadiObject subjects."""
|
|
75
|
+
|
|
76
|
+
def __init__(self, radi_object: RadiObject):
|
|
77
|
+
self._radi_object = radi_object
|
|
78
|
+
|
|
79
|
+
@overload
|
|
80
|
+
def __getitem__(self, key: str) -> RadiObject: ...
|
|
81
|
+
@overload
|
|
82
|
+
def __getitem__(self, key: list[str]) -> RadiObject: ...
|
|
83
|
+
|
|
84
|
+
def __getitem__(self, key: str | list[str]) -> RadiObject:
|
|
85
|
+
"""Returns a RadiObject view filtered to selected obs_subject_ids."""
|
|
86
|
+
if isinstance(key, str):
|
|
87
|
+
return self._radi_object._filter_by_subject_ids([key])
|
|
88
|
+
elif isinstance(key, list):
|
|
89
|
+
return self._radi_object._filter_by_subject_ids(key)
|
|
90
|
+
raise TypeError(f"loc indices must be str or list[str], got {type(key)}")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class RadiObject:
|
|
94
|
+
"""Top-level container for multi-collection radiology data with subject metadata.
|
|
95
|
+
|
|
96
|
+
RadiObject can be either "attached" (backed by storage at a URI) or a "view"
|
|
97
|
+
(filtered subset referencing a source RadiObject). Views are created by
|
|
98
|
+
filtering operations and read data from their source with filters applied.
|
|
99
|
+
|
|
100
|
+
Attached (has URI):
|
|
101
|
+
radi = RadiObject("s3://bucket/dataset")
|
|
102
|
+
radi.is_view # False
|
|
103
|
+
radi.uri # "s3://bucket/dataset"
|
|
104
|
+
|
|
105
|
+
View (filtered, no URI):
|
|
106
|
+
subset = radi.filter("age > 40")
|
|
107
|
+
subset.is_view # True
|
|
108
|
+
subset.uri # None
|
|
109
|
+
subset._root # Original RadiObject
|
|
110
|
+
|
|
111
|
+
Views are immutable. To persist a view, use materialize(uri).
|
|
112
|
+
"""
|
|
113
|
+
|
|
114
|
+
def __init__(
|
|
115
|
+
self,
|
|
116
|
+
uri: str | None,
|
|
117
|
+
ctx: tiledb.Ctx | None = None,
|
|
118
|
+
*,
|
|
119
|
+
# View state (internal use only)
|
|
120
|
+
_source: RadiObject | None = None,
|
|
121
|
+
_subject_ids: frozenset[str] | None = None,
|
|
122
|
+
_collection_names: frozenset[str] | None = None,
|
|
123
|
+
):
|
|
124
|
+
self._uri: str | None = uri
|
|
125
|
+
self._ctx: tiledb.Ctx | None = ctx
|
|
126
|
+
# View state
|
|
127
|
+
self._source: RadiObject | None = _source
|
|
128
|
+
self._subject_ids: frozenset[str] | None = _subject_ids
|
|
129
|
+
self._collection_names_filter: frozenset[str] | None = _collection_names
|
|
130
|
+
|
|
131
|
+
@property
|
|
132
|
+
def uri(self) -> str | None:
|
|
133
|
+
"""URI of this RadiObject, or None if this is a view."""
|
|
134
|
+
return self._uri
|
|
135
|
+
|
|
136
|
+
@property
|
|
137
|
+
def is_view(self) -> bool:
|
|
138
|
+
"""True if this RadiObject is a filtered view of another."""
|
|
139
|
+
return self._source is not None
|
|
140
|
+
|
|
141
|
+
@property
|
|
142
|
+
def _root(self) -> RadiObject:
|
|
143
|
+
"""The original attached RadiObject (follows source chain)."""
|
|
144
|
+
if self._source is None:
|
|
145
|
+
return self
|
|
146
|
+
return self._source._root
|
|
147
|
+
|
|
148
|
+
def _effective_ctx(self) -> tiledb.Ctx:
|
|
149
|
+
if self._source is not None:
|
|
150
|
+
return self._source._effective_ctx()
|
|
151
|
+
return self._ctx if self._ctx else global_ctx()
|
|
152
|
+
|
|
153
|
+
def _effective_uri(self) -> str:
|
|
154
|
+
"""Get the storage URI (from root if this is a view)."""
|
|
155
|
+
if self._source is not None:
|
|
156
|
+
return self._source._effective_uri()
|
|
157
|
+
if self._uri is None:
|
|
158
|
+
raise ValueError("RadiObject has no URI")
|
|
159
|
+
return self._uri
|
|
160
|
+
|
|
161
|
+
# ===== View Factory =====
|
|
162
|
+
|
|
163
|
+
def _create_view(
|
|
164
|
+
self,
|
|
165
|
+
subject_ids: frozenset[str] | None = None,
|
|
166
|
+
collection_names: frozenset[str] | None = None,
|
|
167
|
+
) -> RadiObject:
|
|
168
|
+
"""Create a view with specified filters, intersecting with current filters."""
|
|
169
|
+
# Intersect subject_ids with current filter
|
|
170
|
+
if self._subject_ids is not None and subject_ids is not None:
|
|
171
|
+
subject_ids = self._subject_ids & subject_ids
|
|
172
|
+
elif self._subject_ids is not None:
|
|
173
|
+
subject_ids = self._subject_ids
|
|
174
|
+
# subject_ids stays as passed if self._subject_ids is None
|
|
175
|
+
|
|
176
|
+
# Intersect collection_names with current filter
|
|
177
|
+
if self._collection_names_filter is not None and collection_names is not None:
|
|
178
|
+
collection_names = self._collection_names_filter & collection_names
|
|
179
|
+
elif self._collection_names_filter is not None:
|
|
180
|
+
collection_names = self._collection_names_filter
|
|
181
|
+
# collection_names stays as passed if self._collection_names_filter is None
|
|
182
|
+
|
|
183
|
+
return RadiObject(
|
|
184
|
+
uri=None,
|
|
185
|
+
ctx=self._ctx,
|
|
186
|
+
_source=self._root, # Always point to root to avoid deep chains
|
|
187
|
+
_subject_ids=subject_ids,
|
|
188
|
+
_collection_names=collection_names,
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
# ===== Subject Indexing =====
|
|
192
|
+
|
|
193
|
+
@cached_property
|
|
194
|
+
def iloc(self) -> _SubjectILocIndexer:
|
|
195
|
+
"""Integer-location based indexing for selecting subjects by position."""
|
|
196
|
+
return _SubjectILocIndexer(self)
|
|
197
|
+
|
|
198
|
+
@cached_property
|
|
199
|
+
def loc(self) -> _SubjectLocIndexer:
|
|
200
|
+
"""Label-based indexing for selecting subjects by obs_subject_id."""
|
|
201
|
+
return _SubjectLocIndexer(self)
|
|
202
|
+
|
|
203
|
+
# ===== ObsMeta (Subject Metadata) =====
|
|
204
|
+
|
|
205
|
+
@property
|
|
206
|
+
def obs_meta(self) -> pd.DataFrame | Dataframe:
|
|
207
|
+
"""Subject-level observational metadata.
|
|
208
|
+
|
|
209
|
+
Returns Dataframe for attached RadiObject, pd.DataFrame for views.
|
|
210
|
+
"""
|
|
211
|
+
if self.is_view:
|
|
212
|
+
# Return filtered DataFrame
|
|
213
|
+
full_obs_meta = self._root.obs_meta.read()
|
|
214
|
+
if self._subject_ids is not None:
|
|
215
|
+
return full_obs_meta[
|
|
216
|
+
full_obs_meta["obs_subject_id"].isin(self._subject_ids)
|
|
217
|
+
].reset_index(drop=True)
|
|
218
|
+
return full_obs_meta
|
|
219
|
+
obs_meta_uri = f"{self._effective_uri()}/obs_meta"
|
|
220
|
+
return Dataframe(uri=obs_meta_uri, ctx=self._ctx)
|
|
221
|
+
|
|
222
|
+
@cached_property
|
|
223
|
+
def _index(self) -> Index:
|
|
224
|
+
"""Cached bidirectional index for obs_subject_id lookups."""
|
|
225
|
+
if self.is_view:
|
|
226
|
+
# Build index from filtered subject_ids
|
|
227
|
+
if self._subject_ids is not None:
|
|
228
|
+
# Preserve order from root
|
|
229
|
+
root_ids = self._root.obs_subject_ids
|
|
230
|
+
filtered = [sid for sid in root_ids if sid in self._subject_ids]
|
|
231
|
+
return Index.build(filtered)
|
|
232
|
+
return self._root._index
|
|
233
|
+
n = self._metadata.get("subject_count", 0)
|
|
234
|
+
if n == 0:
|
|
235
|
+
return Index.build([])
|
|
236
|
+
# Only load the index column for efficiency
|
|
237
|
+
data = self.obs_meta.read(columns=["obs_subject_id"])
|
|
238
|
+
return Index.build(list(data["obs_subject_id"]))
|
|
239
|
+
|
|
240
|
+
@property
|
|
241
|
+
def index(self) -> Index:
|
|
242
|
+
"""Subject index for bidirectional ID/position lookups."""
|
|
243
|
+
return self._index
|
|
244
|
+
|
|
245
|
+
@property
|
|
246
|
+
def obs_subject_ids(self) -> list[str]:
|
|
247
|
+
"""All obs_subject_id values in index order."""
|
|
248
|
+
return list(self._index.keys)
|
|
249
|
+
|
|
250
|
+
def get_obs_row_by_obs_subject_id(self, obs_subject_id: str) -> pd.DataFrame:
|
|
251
|
+
"""Get obs_meta row by obs_subject_id string identifier."""
|
|
252
|
+
if self.is_view:
|
|
253
|
+
obs_meta_df = self.obs_meta
|
|
254
|
+
return obs_meta_df[obs_meta_df["obs_subject_id"] == obs_subject_id].reset_index(
|
|
255
|
+
drop=True
|
|
256
|
+
)
|
|
257
|
+
df = self.obs_meta.read()
|
|
258
|
+
filtered = df[df["obs_subject_id"] == obs_subject_id].reset_index(drop=True)
|
|
259
|
+
return filtered
|
|
260
|
+
|
|
261
|
+
# ===== Volume Access Across Collections =====
|
|
262
|
+
|
|
263
|
+
@cached_property
|
|
264
|
+
def all_obs_ids(self) -> list[str]:
|
|
265
|
+
"""All obs_ids across all collections (for uniqueness checks)."""
|
|
266
|
+
obs_ids = []
|
|
267
|
+
for name in self.collection_names:
|
|
268
|
+
obs_ids.extend(self.collection(name).obs_ids)
|
|
269
|
+
return obs_ids
|
|
270
|
+
|
|
271
|
+
def get_volume(self, obs_id: str) -> Volume:
|
|
272
|
+
"""Get a volume by obs_id from any collection."""
|
|
273
|
+
for name in self.collection_names:
|
|
274
|
+
coll = self.collection(name)
|
|
275
|
+
if obs_id in coll.index:
|
|
276
|
+
return coll.loc[obs_id]
|
|
277
|
+
raise KeyError(f"obs_id '{obs_id}' not found in any collection")
|
|
278
|
+
|
|
279
|
+
# ===== VolumeCollections =====
|
|
280
|
+
|
|
281
|
+
@cached_property
|
|
282
|
+
def _metadata(self) -> dict:
|
|
283
|
+
"""Cached group metadata."""
|
|
284
|
+
uri = self._effective_uri()
|
|
285
|
+
with tiledb.Group(uri, "r", ctx=self._effective_ctx()) as grp:
|
|
286
|
+
return dict(grp.meta)
|
|
287
|
+
|
|
288
|
+
@cached_property
|
|
289
|
+
def collection_names(self) -> tuple[str, ...]:
|
|
290
|
+
"""Names of all VolumeCollections."""
|
|
291
|
+
if self.is_view and self._collection_names_filter is not None:
|
|
292
|
+
# Return filtered collection names (preserving root order)
|
|
293
|
+
root_names = self._root.collection_names
|
|
294
|
+
return tuple(name for name in root_names if name in self._collection_names_filter)
|
|
295
|
+
uri = self._effective_uri()
|
|
296
|
+
collections_uri = f"{uri}/collections"
|
|
297
|
+
with tiledb.Group(collections_uri, "r", ctx=self._effective_ctx()) as grp:
|
|
298
|
+
return tuple(obj.name for obj in grp)
|
|
299
|
+
|
|
300
|
+
def collection(self, name: str) -> VolumeCollection:
|
|
301
|
+
"""Get a VolumeCollection by name."""
|
|
302
|
+
if name not in self.collection_names:
|
|
303
|
+
raise KeyError(f"Collection '{name}' not found. Available: {self.collection_names}")
|
|
304
|
+
uri = self._effective_uri()
|
|
305
|
+
collection_uri = f"{uri}/collections/{name}"
|
|
306
|
+
return VolumeCollection(collection_uri, ctx=self._ctx)
|
|
307
|
+
|
|
308
|
+
def __getattr__(self, name: str) -> VolumeCollection:
|
|
309
|
+
"""Attribute access to collections (e.g., radi.T1w)."""
|
|
310
|
+
if name.startswith("_"):
|
|
311
|
+
raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
|
|
312
|
+
try:
|
|
313
|
+
return self.collection(name)
|
|
314
|
+
except KeyError:
|
|
315
|
+
raise AttributeError(f"'{type(self).__name__}' has no collection '{name}'")
|
|
316
|
+
|
|
317
|
+
    def rename_collection(self, old_name: str, new_name: str) -> None:
        """Rename a collection.

        Args:
            old_name: Existing collection name.
            new_name: Desired new name; must not already exist.

        Raises:
            ValueError: if called on a view, or *new_name* already exists.
            KeyError: if *old_name* does not exist.
        """
        self._check_not_view("rename_collection")
        if old_name not in self.collection_names:
            raise KeyError(f"Collection '{old_name}' not found")
        if new_name in self.collection_names:
            raise ValueError(f"Collection '{new_name}' already exists")

        effective_ctx = self._effective_ctx()
        uri = self._effective_uri()
        collections_uri = f"{uri}/collections"
        old_uri = f"{collections_uri}/{old_name}"

        # Update the collection's own metadata first, then re-register it in the
        # parent group under the new member name. The underlying group URI stays
        # the same; only the member name and "name" metadata change.
        with tiledb.Group(old_uri, "w", ctx=effective_ctx) as grp:
            grp.meta["name"] = new_name

        with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
            grp.remove(old_name)
            grp.add(old_uri, name=new_name)

        # Invalidate the cached_property so the next access re-reads the group.
        if "collection_names" in self.__dict__:
            del self.__dict__["collection_names"]
|
|
339
|
+
|
|
340
|
+
# ===== Length / Iteration =====
|
|
341
|
+
|
|
342
|
+
def __len__(self) -> int:
|
|
343
|
+
"""Number of subjects."""
|
|
344
|
+
if self.is_view:
|
|
345
|
+
return len(self._index)
|
|
346
|
+
return int(self._metadata.get("subject_count", 0))
|
|
347
|
+
|
|
348
|
+
@property
|
|
349
|
+
def n_collections(self) -> int:
|
|
350
|
+
"""Number of VolumeCollections."""
|
|
351
|
+
return len(self.collection_names)
|
|
352
|
+
|
|
353
|
+
def __iter__(self) -> Iterator[str]:
|
|
354
|
+
"""Iterate over collection names."""
|
|
355
|
+
return iter(self.collection_names)
|
|
356
|
+
|
|
357
|
+
    @overload
    def __getitem__(self, key: str) -> RadiObject: ...
    @overload
    def __getitem__(self, key: list[str]) -> RadiObject: ...

    def __getitem__(self, key: str | list[str]) -> RadiObject:
        """Bracket indexing for subjects by obs_subject_id.

        Alias for .loc[] - allows `radi["BraTS001"]` as shorthand for `radi.loc["BraTS001"]`.
        """
        # Pure delegation; all validation/raising happens in _SubjectLocIndexer.
        return self.loc[key]
|
|
368
|
+
|
|
369
|
+
def __repr__(self) -> str:
|
|
370
|
+
"""Concise representation of the RadiObject."""
|
|
371
|
+
collections = ", ".join(self.collection_names) if self.collection_names else "none"
|
|
372
|
+
view_indicator = " (view)" if self.is_view else ""
|
|
373
|
+
return (
|
|
374
|
+
f"RadiObject({len(self)} subjects, {self.n_collections} collections: "
|
|
375
|
+
f"[{collections}]){view_indicator}"
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
    def describe(self) -> str:
        """Return a summary: subjects, collections, shapes, and label distributions."""
        lines = [
            "RadiObject Summary",
            "==================",
            f"URI: {self.uri or '(view)'}",
            f"Subjects: {len(self)}",
            f"Collections: {self.n_collections}",
            "",
            "Collections:",
        ]

        # One line per collection: volume count, shape (or "heterogeneous"
        # when the collection has no single shape), and a mixed-shape marker.
        for name in self.collection_names:
            coll = self.collection(name)
            shape = coll.shape
            shape_str = "x".join(str(d) for d in shape) if shape else "heterogeneous"
            uniform_str = "" if coll.is_uniform else " (mixed shapes)"
            lines.append(f"  - {name}: {len(coll)} volumes, shape={shape_str}{uniform_str}")

        # Find label columns: heuristically, any non-ID numeric/object column
        # with at most 10 distinct values is treated as a label.
        obs_meta_df = self.obs_meta if self.is_view else self.obs_meta.read()
        label_cols = []
        for col in obs_meta_df.columns:
            if col in ("obs_subject_id", "obs_id"):
                continue
            dtype = obs_meta_df[col].dtype
            if dtype in (np.int64, np.int32, np.float64, np.float32, object):
                n_unique = obs_meta_df[col].nunique()
                if n_unique <= 10:
                    label_cols.append(col)

        if label_cols:
            lines.append("")
            lines.append("Label Columns:")
            for col in label_cols:
                value_counts = obs_meta_df[col].value_counts().to_dict()
                lines.append(f"  - {col}: {value_counts}")

        return "\n".join(lines)
|
|
417
|
+
|
|
418
|
+
# ===== Lazy Mode (returns Query for transform pipelines) =====
|
|
419
|
+
|
|
420
|
+
    def lazy(self) -> Query:
        """Enter lazy mode for transform pipelines.

        Returns a Query that accumulates transforms without executing them.
        Use this when you need to apply transforms via .map().

        Example:
            normalized = (
                radi.CT
                .lazy()
                .filter("quality == 'good'")
                .map(normalize_intensity)
                .materialize("./normalized")
            )
        """
        # Local import avoids a circular dependency (query.py imports this module).
        from radiobject.query import Query

        # The Query starts from the root; any view filters carry over as the
        # initial subject/collection selection.
        return Query(
            self._root,
            subject_ids=self._subject_ids,
            output_collections=self._collection_names_filter,
        )
|
|
442
|
+
|
|
443
|
+
# ===== Immutability Check =====
|
|
444
|
+
|
|
445
|
+
def _check_not_view(self, operation: str) -> None:
|
|
446
|
+
"""Raise if attempting to modify a view."""
|
|
447
|
+
if self.is_view:
|
|
448
|
+
raise ValueError(
|
|
449
|
+
f"Cannot {operation} on a view. Call materialize(uri) first to create "
|
|
450
|
+
"an attached RadiObject."
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
# ===== Filtering (returns RadiObject view) =====
|
|
454
|
+
|
|
455
|
+
def _filter_by_indices(self, indices: list[int]) -> RadiObject:
|
|
456
|
+
"""Create a view filtered to specific subject indices."""
|
|
457
|
+
subject_ids = frozenset(self._index.get_key(i) for i in indices)
|
|
458
|
+
return self._create_view(subject_ids=subject_ids)
|
|
459
|
+
|
|
460
|
+
def _filter_by_subject_ids(self, obs_subject_ids: list[str]) -> RadiObject:
|
|
461
|
+
"""Create a view filtered to specific obs_subject_ids."""
|
|
462
|
+
current_ids = set(self._index.keys)
|
|
463
|
+
for sid in obs_subject_ids:
|
|
464
|
+
if sid not in current_ids:
|
|
465
|
+
raise KeyError(f"obs_subject_id '{sid}' not found")
|
|
466
|
+
return self._create_view(subject_ids=frozenset(obs_subject_ids))
|
|
467
|
+
|
|
468
|
+
def select_collections(self, names: list[str]) -> RadiObject:
|
|
469
|
+
"""Create a view with only specified collections."""
|
|
470
|
+
current_names = set(self.collection_names)
|
|
471
|
+
for name in names:
|
|
472
|
+
if name not in current_names:
|
|
473
|
+
raise KeyError(f"Collection '{name}' not found")
|
|
474
|
+
return self._create_view(collection_names=frozenset(names))
|
|
475
|
+
|
|
476
|
+
def filter(self, expr: str) -> RadiObject:
|
|
477
|
+
"""Filter subjects using a query expression on obs_meta.
|
|
478
|
+
|
|
479
|
+
Args:
|
|
480
|
+
expr: TileDB QueryCondition string (e.g., "tumor_grade == 'HGG' and age > 40")
|
|
481
|
+
|
|
482
|
+
Returns:
|
|
483
|
+
RadiObject view filtered to matching subjects
|
|
484
|
+
"""
|
|
485
|
+
if self.is_view:
|
|
486
|
+
# Filter from the obs_meta DataFrame
|
|
487
|
+
obs_meta_df = self.obs_meta
|
|
488
|
+
# Use pandas query for view filtering
|
|
489
|
+
filtered = obs_meta_df.query(expr)
|
|
490
|
+
subject_ids = frozenset(filtered["obs_subject_id"])
|
|
491
|
+
else:
|
|
492
|
+
# Use TileDB QueryCondition for attached RadiObject
|
|
493
|
+
filtered = self.obs_meta.read(value_filter=expr)
|
|
494
|
+
subject_ids = frozenset(filtered["obs_subject_id"])
|
|
495
|
+
return self._create_view(subject_ids=subject_ids)
|
|
496
|
+
|
|
497
|
+
def head(self, n: int = 5) -> RadiObject:
|
|
498
|
+
"""Return view of first n subjects."""
|
|
499
|
+
n = min(n, len(self))
|
|
500
|
+
return self._filter_by_indices(list(range(n)))
|
|
501
|
+
|
|
502
|
+
def tail(self, n: int = 5) -> RadiObject:
|
|
503
|
+
"""Return view of last n subjects."""
|
|
504
|
+
total = len(self)
|
|
505
|
+
n = min(n, total)
|
|
506
|
+
return self._filter_by_indices(list(range(total - n, total)))
|
|
507
|
+
|
|
508
|
+
def sample(self, n: int = 5, seed: int | None = None) -> RadiObject:
|
|
509
|
+
"""Return view of n randomly sampled subjects."""
|
|
510
|
+
rng = np.random.default_rng(seed)
|
|
511
|
+
total = len(self)
|
|
512
|
+
n = min(n, total)
|
|
513
|
+
indices = list(rng.choice(total, size=n, replace=False))
|
|
514
|
+
return self._filter_by_indices(sorted(indices))
|
|
515
|
+
|
|
516
|
+
# ===== Materialization =====
|
|
517
|
+
|
|
518
|
+
    def materialize(
        self,
        uri: str,
        streaming: bool = True,
        ctx: tiledb.Ctx | None = None,
    ) -> RadiObject:
        """Write this RadiObject (or view) to storage.

        For attached RadiObjects, this copies the entire dataset.
        For views, this writes only the filtered subset.

        Args:
            uri: Target URI for the new RadiObject
            streaming: Use streaming writer for memory efficiency (default: True)
            ctx: TileDB context

        Returns:
            New attached RadiObject at the target URI
        """
        # Get filtered obs_meta
        if self.is_view:
            filtered_obs_meta = self.obs_meta  # Already filtered DataFrame
        else:
            filtered_obs_meta = self.obs_meta.read()

        # Build obs_meta schema from the column dtypes; ID columns are handled
        # separately by the writers and are skipped here.
        obs_meta_schema: dict[str, np.dtype] = {}
        for col in filtered_obs_meta.columns:
            if col in ("obs_subject_id", "obs_id"):
                continue
            dtype = filtered_obs_meta[col].to_numpy().dtype
            if dtype == np.dtype("O"):
                # Object (string) columns are stored as fixed-width U64.
                # NOTE(review): values longer than 64 chars would presumably be
                # truncated -- confirm against the writer.
                dtype = np.dtype("U64")
            obs_meta_schema[col] = dtype

        # Dispatch to the memory-efficient streaming path or the batch path.
        if streaming:
            return self._materialize_streaming(uri, filtered_obs_meta, obs_meta_schema, ctx)
        return self._materialize_batch(uri, filtered_obs_meta, obs_meta_schema, ctx)
|
|
556
|
+
|
|
557
|
+
    def _materialize_streaming(
        self,
        uri: str,
        obs_meta_df: pd.DataFrame,
        obs_meta_schema: dict[str, np.dtype],
        ctx: tiledb.Ctx | None,
    ) -> RadiObject:
        """Materialize view to storage using streaming writer.

        Volumes are copied one at a time, so peak memory stays at roughly one
        volume regardless of dataset size.
        """
        # Local import avoids a circular dependency with streaming.py.
        from radiobject.streaming import RadiObjectWriter

        subject_ids = set(obs_meta_df["obs_subject_id"])

        with RadiObjectWriter(uri, obs_meta_schema=obs_meta_schema, ctx=ctx) as writer:
            writer.write_obs_meta(obs_meta_df)

            for coll_name in self.collection_names:
                src_collection = self.collection(coll_name)
                obs_df = src_collection.obs.read()
                # Keep only observations whose subject survived the view filter.
                filtered_obs = obs_df[obs_df["obs_subject_id"].isin(subject_ids)]

                # Skip collections that end up empty after filtering.
                if len(filtered_obs) == 0:
                    continue

                # Extract obs schema (per-volume attributes, minus the ID columns).
                obs_schema: dict[str, np.dtype] = {}
                for col in src_collection.obs.columns:
                    if col in ("obs_id", "obs_subject_id"):
                        continue
                    obs_schema[col] = src_collection.obs.dtypes[col]

                with writer.add_collection(
                    coll_name, src_collection.shape, obs_schema
                ) as coll_writer:
                    # Copy each surviving volume with its per-volume attributes.
                    for _, row in filtered_obs.iterrows():
                        obs_id = row["obs_id"]
                        vol = src_collection.loc[obs_id]
                        attrs = {
                            k: v for k, v in row.items() if k not in ("obs_id", "obs_subject_id")
                        }
                        coll_writer.write_volume(
                            data=vol.to_numpy(),
                            obs_id=obs_id,
                            obs_subject_id=row["obs_subject_id"],
                            **attrs,
                        )

        # Re-open the freshly written dataset as an attached RadiObject.
        return RadiObject(uri, ctx=ctx)
|
|
604
|
+
|
|
605
|
+
    def _materialize_batch(
        self,
        uri: str,
        obs_meta_df: pd.DataFrame,
        obs_meta_schema: dict[str, np.dtype],
        ctx: tiledb.Ctx | None,
    ) -> RadiObject:
        """Materialize view to storage using batch writer."""
        # NOTE(review): truthiness test on ctx -- a falsy-but-valid Ctx would be
        # replaced by the effective one; confirm `is not None` isn't needed here.
        effective_ctx = ctx if ctx else self._effective_ctx()
        subject_ids = list(obs_meta_df["obs_subject_id"])

        # Create the target group layout (obs_meta array + collections group).
        RadiObject._create(
            uri,
            obs_meta_schema=obs_meta_schema,
            n_subjects=len(subject_ids),
            ctx=ctx,
        )

        # Write obs_meta
        obs_meta_uri = f"{uri}/obs_meta"
        obs_subject_ids_arr = obs_meta_df["obs_subject_id"].astype(str).to_numpy()
        # Fall back to the subject ids when there is no separate obs_id column.
        obs_ids_arr = (
            obs_meta_df["obs_id"].astype(str).to_numpy()
            if "obs_id" in obs_meta_df.columns
            else obs_subject_ids_arr
        )
        with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
            # Non-ID columns become array attributes; the two IDs are dimensions.
            attr_data = {
                col: obs_meta_df[col].to_numpy()
                for col in obs_meta_df.columns
                if col not in ("obs_subject_id", "obs_id")
            }
            arr[obs_subject_ids_arr, obs_ids_arr] = attr_data

        # Copy collections, filtered to the surviving subjects, then register
        # each copy as a member of the collections group.
        collections_uri = f"{uri}/collections"
        for coll_name in self.collection_names:
            src_collection = self.collection(coll_name)
            new_vc_uri = f"{collections_uri}/{coll_name}"

            _copy_filtered_volume_collection(
                src_collection,
                new_vc_uri,
                obs_subject_ids=subject_ids,
                name=coll_name,
                ctx=ctx,
            )

            with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
                grp.add(new_vc_uri, name=coll_name)

        # Record top-level counts in the group metadata.
        with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
            grp.meta["n_collections"] = len(self.collection_names)
            grp.meta["subject_count"] = len(subject_ids)

        return RadiObject(uri, ctx=ctx)
|
|
661
|
+
|
|
662
|
+
def copy(self) -> RadiObject:
|
|
663
|
+
"""Create an independent in-memory copy, detached from the view chain.
|
|
664
|
+
|
|
665
|
+
Useful when you want to break the reference to the source RadiObject.
|
|
666
|
+
Note: This does NOT persist data. Call materialize(uri) to write to storage.
|
|
667
|
+
"""
|
|
668
|
+
if not self.is_view:
|
|
669
|
+
# For attached RadiObject, just return self (already independent)
|
|
670
|
+
return self
|
|
671
|
+
# Create a new view with the same filters but mark it as "detached"
|
|
672
|
+
# In practice, since we always point to _root, this is already independent
|
|
673
|
+
return RadiObject(
|
|
674
|
+
uri=None,
|
|
675
|
+
ctx=self._ctx,
|
|
676
|
+
_source=self._root,
|
|
677
|
+
_subject_ids=self._subject_ids,
|
|
678
|
+
_collection_names=self._collection_names_filter,
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
# ===== Append Operations (Mutations) =====
|
|
682
|
+
|
|
683
|
+
def append(
|
|
684
|
+
self,
|
|
685
|
+
niftis: Sequence[tuple[str | Path, str]] | None = None,
|
|
686
|
+
dicom_dirs: Sequence[tuple[str | Path, str]] | None = None,
|
|
687
|
+
obs_meta: pd.DataFrame | None = None,
|
|
688
|
+
reorient: bool | None = None,
|
|
689
|
+
progress: bool = False,
|
|
690
|
+
) -> None:
|
|
691
|
+
"""Append new subjects and their volumes atomically."""
|
|
692
|
+
self._check_not_view("append")
|
|
693
|
+
|
|
694
|
+
if niftis is None and dicom_dirs is None:
|
|
695
|
+
raise ValueError("Must provide either niftis or dicom_dirs")
|
|
696
|
+
if niftis is not None and dicom_dirs is not None:
|
|
697
|
+
raise ValueError("Cannot provide both niftis and dicom_dirs")
|
|
698
|
+
|
|
699
|
+
effective_ctx = self._effective_ctx()
|
|
700
|
+
uri = self._effective_uri()
|
|
701
|
+
|
|
702
|
+
# Collect all subject IDs from input
|
|
703
|
+
if niftis is not None:
|
|
704
|
+
input_subject_ids = {sid for _, sid in niftis}
|
|
705
|
+
else:
|
|
706
|
+
input_subject_ids = {sid for _, sid in dicom_dirs}
|
|
707
|
+
|
|
708
|
+
existing_subject_ids = set(self.obs_subject_ids)
|
|
709
|
+
new_subject_ids = input_subject_ids - existing_subject_ids
|
|
710
|
+
|
|
711
|
+
# Validate obs_meta
|
|
712
|
+
if new_subject_ids:
|
|
713
|
+
if obs_meta is None:
|
|
714
|
+
raise ValueError(
|
|
715
|
+
f"obs_meta required for new subjects: {sorted(new_subject_ids)[:5]}"
|
|
716
|
+
)
|
|
717
|
+
if "obs_subject_id" not in obs_meta.columns:
|
|
718
|
+
raise ValueError("obs_meta must contain 'obs_subject_id' column")
|
|
719
|
+
obs_meta_ids = set(obs_meta["obs_subject_id"])
|
|
720
|
+
missing = new_subject_ids - obs_meta_ids
|
|
721
|
+
if missing:
|
|
722
|
+
raise ValueError(f"obs_meta missing entries for: {sorted(missing)[:5]}")
|
|
723
|
+
obs_meta = obs_meta[obs_meta["obs_subject_id"].isin(new_subject_ids)]
|
|
724
|
+
|
|
725
|
+
# Append obs_meta for new subjects
|
|
726
|
+
if obs_meta is not None and len(obs_meta) > 0:
|
|
727
|
+
obs_meta_uri = f"{uri}/obs_meta"
|
|
728
|
+
obs_subject_ids_arr = obs_meta["obs_subject_id"].astype(str).to_numpy()
|
|
729
|
+
obs_ids_arr = (
|
|
730
|
+
obs_meta["obs_id"].astype(str).to_numpy()
|
|
731
|
+
if "obs_id" in obs_meta.columns
|
|
732
|
+
else obs_subject_ids_arr
|
|
733
|
+
)
|
|
734
|
+
existing_columns = set(self._root.obs_meta.columns)
|
|
735
|
+
with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
|
|
736
|
+
attr_data = {
|
|
737
|
+
col: obs_meta[col].to_numpy()
|
|
738
|
+
for col in obs_meta.columns
|
|
739
|
+
if col not in ("obs_subject_id", "obs_id") and col in existing_columns
|
|
740
|
+
}
|
|
741
|
+
arr[obs_subject_ids_arr, obs_ids_arr] = attr_data
|
|
742
|
+
|
|
743
|
+
new_count = len(self) + len(obs_meta)
|
|
744
|
+
with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
|
|
745
|
+
grp.meta["subject_count"] = new_count
|
|
746
|
+
|
|
747
|
+
# Process and group input files
|
|
748
|
+
if niftis is not None:
|
|
749
|
+
self._append_niftis(niftis, reorient, effective_ctx, progress)
|
|
750
|
+
else:
|
|
751
|
+
self._append_dicoms(dicom_dirs, reorient, effective_ctx, progress)
|
|
752
|
+
|
|
753
|
+
# Invalidate cached properties
|
|
754
|
+
for prop in ("_index", "_metadata", "collection_names"):
|
|
755
|
+
if prop in self.__dict__:
|
|
756
|
+
del self.__dict__[prop]
|
|
757
|
+
|
|
758
|
+
def _append_niftis(
|
|
759
|
+
self,
|
|
760
|
+
niftis: Sequence[tuple[str | Path, str]],
|
|
761
|
+
reorient: bool | None,
|
|
762
|
+
effective_ctx: tiledb.Ctx,
|
|
763
|
+
progress: bool = False,
|
|
764
|
+
) -> None:
|
|
765
|
+
"""Internal: append NIfTI files to existing collections or create new ones."""
|
|
766
|
+
uri = self._effective_uri()
|
|
767
|
+
file_info: list[tuple[Path, str, tuple[int, int, int], str]] = []
|
|
768
|
+
for nifti_path, obs_subject_id in niftis:
|
|
769
|
+
path = Path(nifti_path)
|
|
770
|
+
if not path.exists():
|
|
771
|
+
raise FileNotFoundError(f"NIfTI file not found: {path}")
|
|
772
|
+
metadata = extract_nifti_metadata(path)
|
|
773
|
+
series_type = infer_series_type(path)
|
|
774
|
+
file_info.append((path, obs_subject_id, metadata.dimensions, series_type))
|
|
775
|
+
|
|
776
|
+
groups: dict[tuple[tuple[int, int, int], str], list[tuple[Path, str]]] = defaultdict(list)
|
|
777
|
+
for path, subject_id, shape, series_type in file_info:
|
|
778
|
+
groups[(shape, series_type)].append((path, subject_id))
|
|
779
|
+
|
|
780
|
+
collections_uri = f"{uri}/collections"
|
|
781
|
+
existing_collections = set(self.collection_names)
|
|
782
|
+
|
|
783
|
+
groups_iter = groups.items()
|
|
784
|
+
if progress:
|
|
785
|
+
from tqdm.auto import tqdm
|
|
786
|
+
|
|
787
|
+
groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")
|
|
788
|
+
|
|
789
|
+
for (shape, series_type), items in groups_iter:
|
|
790
|
+
collection_name = series_type
|
|
791
|
+
if collection_name in existing_collections:
|
|
792
|
+
vc = self.collection(collection_name)
|
|
793
|
+
if vc.shape != shape:
|
|
794
|
+
collection_name = f"{series_type}_{shape[0]}x{shape[1]}x{shape[2]}"
|
|
795
|
+
|
|
796
|
+
if collection_name in existing_collections:
|
|
797
|
+
vc = self.collection(collection_name)
|
|
798
|
+
vc.append(niftis=items, reorient=reorient, progress=progress)
|
|
799
|
+
else:
|
|
800
|
+
vc_uri = f"{collections_uri}/{collection_name}"
|
|
801
|
+
VolumeCollection.from_niftis(
|
|
802
|
+
uri=vc_uri,
|
|
803
|
+
niftis=items,
|
|
804
|
+
reorient=reorient,
|
|
805
|
+
validate_dimensions=True,
|
|
806
|
+
name=collection_name,
|
|
807
|
+
ctx=self._ctx,
|
|
808
|
+
progress=progress,
|
|
809
|
+
)
|
|
810
|
+
with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
|
|
811
|
+
grp.add(vc_uri, name=collection_name)
|
|
812
|
+
with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
|
|
813
|
+
grp.meta["n_collections"] = self.n_collections + 1
|
|
814
|
+
existing_collections.add(collection_name)
|
|
815
|
+
|
|
816
|
+
def _append_dicoms(
|
|
817
|
+
self,
|
|
818
|
+
dicom_dirs: Sequence[tuple[str | Path, str]],
|
|
819
|
+
reorient: bool | None,
|
|
820
|
+
effective_ctx: tiledb.Ctx,
|
|
821
|
+
progress: bool = False,
|
|
822
|
+
) -> None:
|
|
823
|
+
"""Internal: append DICOM series to existing collections or create new ones."""
|
|
824
|
+
uri = self._effective_uri()
|
|
825
|
+
file_info: list[tuple[Path, str, tuple[int, int, int], str]] = []
|
|
826
|
+
for dicom_dir, obs_subject_id in dicom_dirs:
|
|
827
|
+
path = Path(dicom_dir)
|
|
828
|
+
if not path.exists():
|
|
829
|
+
raise FileNotFoundError(f"DICOM directory not found: {path}")
|
|
830
|
+
metadata = extract_dicom_metadata(path)
|
|
831
|
+
dims = metadata.dimensions
|
|
832
|
+
shape = (dims[1], dims[0], dims[2])
|
|
833
|
+
file_info.append((path, obs_subject_id, shape, metadata.modality))
|
|
834
|
+
|
|
835
|
+
groups: dict[tuple[tuple[int, int, int], str], list[tuple[Path, str]]] = defaultdict(list)
|
|
836
|
+
for path, subject_id, shape, modality in file_info:
|
|
837
|
+
groups[(shape, modality)].append((path, subject_id))
|
|
838
|
+
|
|
839
|
+
collections_uri = f"{uri}/collections"
|
|
840
|
+
existing_collections = set(self.collection_names)
|
|
841
|
+
|
|
842
|
+
groups_iter = groups.items()
|
|
843
|
+
if progress:
|
|
844
|
+
from tqdm.auto import tqdm
|
|
845
|
+
|
|
846
|
+
groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")
|
|
847
|
+
|
|
848
|
+
for (shape, modality), items in groups_iter:
|
|
849
|
+
collection_name = modality
|
|
850
|
+
if collection_name in existing_collections:
|
|
851
|
+
vc = self.collection(collection_name)
|
|
852
|
+
if vc.shape != shape:
|
|
853
|
+
collection_name = f"{modality}_{shape[0]}x{shape[1]}x{shape[2]}"
|
|
854
|
+
|
|
855
|
+
if collection_name in existing_collections:
|
|
856
|
+
vc = self.collection(collection_name)
|
|
857
|
+
vc.append(dicom_dirs=items, reorient=reorient, progress=progress)
|
|
858
|
+
else:
|
|
859
|
+
vc_uri = f"{collections_uri}/{collection_name}"
|
|
860
|
+
VolumeCollection.from_dicoms(
|
|
861
|
+
uri=vc_uri,
|
|
862
|
+
dicom_dirs=items,
|
|
863
|
+
reorient=reorient,
|
|
864
|
+
validate_dimensions=True,
|
|
865
|
+
name=collection_name,
|
|
866
|
+
ctx=self._ctx,
|
|
867
|
+
progress=progress,
|
|
868
|
+
)
|
|
869
|
+
with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
|
|
870
|
+
grp.add(vc_uri, name=collection_name)
|
|
871
|
+
with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
|
|
872
|
+
grp.meta["n_collections"] = self.n_collections + 1
|
|
873
|
+
existing_collections.add(collection_name)
|
|
874
|
+
|
|
875
|
+
# ===== Validation =====
|
|
876
|
+
|
|
877
|
+
def validate(self) -> None:
|
|
878
|
+
"""Validate internal consistency of the RadiObject and all collections."""
|
|
879
|
+
self._check_not_view("validate")
|
|
880
|
+
obs_meta_data = self.obs_meta.read()
|
|
881
|
+
actual_subject_count = len(obs_meta_data)
|
|
882
|
+
stored_subject_count = self._metadata.get("subject_count", 0)
|
|
883
|
+
if actual_subject_count != stored_subject_count:
|
|
884
|
+
raise ValueError(
|
|
885
|
+
f"subject_count mismatch: metadata={stored_subject_count}, actual={actual_subject_count}"
|
|
886
|
+
)
|
|
887
|
+
|
|
888
|
+
actual_n_collections = len(self.collection_names)
|
|
889
|
+
stored_n_collections = self._metadata.get("n_collections", 0)
|
|
890
|
+
if actual_n_collections != stored_n_collections:
|
|
891
|
+
raise ValueError(
|
|
892
|
+
f"n_collections mismatch: metadata={stored_n_collections}, actual={actual_n_collections}"
|
|
893
|
+
)
|
|
894
|
+
|
|
895
|
+
for name in self.collection_names:
|
|
896
|
+
self.collection(name).validate()
|
|
897
|
+
|
|
898
|
+
obs_meta_subject_ids = set(obs_meta_data["obs_subject_id"])
|
|
899
|
+
for name in self.collection_names:
|
|
900
|
+
vc = self.collection(name)
|
|
901
|
+
vc_obs = vc.obs.read()
|
|
902
|
+
vc_subject_ids = set(vc_obs["obs_subject_id"])
|
|
903
|
+
orphan_subjects = vc_subject_ids - obs_meta_subject_ids
|
|
904
|
+
if orphan_subjects:
|
|
905
|
+
raise ValueError(
|
|
906
|
+
f"Collection '{name}' has obs_subject_ids not in obs_meta: "
|
|
907
|
+
f"{sorted(orphan_subjects)[:5]}"
|
|
908
|
+
)
|
|
909
|
+
|
|
910
|
+
seen_obs_ids: dict[str, str] = {}
|
|
911
|
+
for name in self.collection_names:
|
|
912
|
+
vc = self.collection(name)
|
|
913
|
+
for obs_id in vc.obs_ids:
|
|
914
|
+
if obs_id in seen_obs_ids:
|
|
915
|
+
raise ValueError(
|
|
916
|
+
f"obs_id '{obs_id}' is duplicated across collections: "
|
|
917
|
+
f"'{seen_obs_ids[obs_id]}' and '{name}'"
|
|
918
|
+
)
|
|
919
|
+
seen_obs_ids[obs_id] = name
|
|
920
|
+
|
|
921
|
+
# ===== Factory Methods =====
|
|
922
|
+
|
|
923
|
+
@classmethod
|
|
924
|
+
def _create(
|
|
925
|
+
cls,
|
|
926
|
+
uri: str,
|
|
927
|
+
obs_meta_schema: dict[str, np.dtype] | None = None,
|
|
928
|
+
n_subjects: int = 0,
|
|
929
|
+
ctx: tiledb.Ctx | None = None,
|
|
930
|
+
) -> RadiObject:
|
|
931
|
+
"""Internal: create an empty RadiObject with optional obs_meta schema."""
|
|
932
|
+
effective_ctx = ctx if ctx else global_ctx()
|
|
933
|
+
|
|
934
|
+
tiledb.Group.create(uri, ctx=effective_ctx)
|
|
935
|
+
|
|
936
|
+
obs_meta_uri = f"{uri}/obs_meta"
|
|
937
|
+
Dataframe.create(obs_meta_uri, schema=obs_meta_schema or {}, ctx=ctx)
|
|
938
|
+
|
|
939
|
+
collections_uri = f"{uri}/collections"
|
|
940
|
+
tiledb.Group.create(collections_uri, ctx=effective_ctx)
|
|
941
|
+
|
|
942
|
+
with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
|
|
943
|
+
grp.meta["subject_count"] = n_subjects
|
|
944
|
+
grp.meta["n_collections"] = 0
|
|
945
|
+
grp.add(obs_meta_uri, name="obs_meta")
|
|
946
|
+
grp.add(collections_uri, name="collections")
|
|
947
|
+
|
|
948
|
+
return cls(uri, ctx=ctx)
|
|
949
|
+
|
|
950
|
+
@classmethod
|
|
951
|
+
def _from_volume_collections(
|
|
952
|
+
cls,
|
|
953
|
+
uri: str,
|
|
954
|
+
collections: dict[str, VolumeCollection],
|
|
955
|
+
obs_meta: pd.DataFrame | None = None,
|
|
956
|
+
ctx: tiledb.Ctx | None = None,
|
|
957
|
+
) -> RadiObject:
|
|
958
|
+
"""Internal: create RadiObject from existing VolumeCollections."""
|
|
959
|
+
if not collections:
|
|
960
|
+
raise ValueError("At least one VolumeCollection is required")
|
|
961
|
+
|
|
962
|
+
effective_ctx = ctx if ctx else global_ctx()
|
|
963
|
+
|
|
964
|
+
n_subjects = len(obs_meta) if obs_meta is not None else 0
|
|
965
|
+
|
|
966
|
+
obs_meta_schema = None
|
|
967
|
+
if obs_meta is not None:
|
|
968
|
+
obs_meta_schema = {}
|
|
969
|
+
for col in obs_meta.columns:
|
|
970
|
+
if col in ("obs_subject_id", "obs_id"):
|
|
971
|
+
continue
|
|
972
|
+
dtype = obs_meta[col].to_numpy().dtype
|
|
973
|
+
if dtype == np.dtype("O"):
|
|
974
|
+
dtype = np.dtype("U64")
|
|
975
|
+
obs_meta_schema[col] = dtype
|
|
976
|
+
|
|
977
|
+
cls._create(uri, obs_meta_schema=obs_meta_schema, n_subjects=n_subjects, ctx=ctx)
|
|
978
|
+
|
|
979
|
+
if obs_meta is not None and len(obs_meta) > 0:
|
|
980
|
+
obs_meta_uri = f"{uri}/obs_meta"
|
|
981
|
+
obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
|
|
982
|
+
obs_ids = (
|
|
983
|
+
obs_meta["obs_id"].astype(str).to_numpy()
|
|
984
|
+
if "obs_id" in obs_meta.columns
|
|
985
|
+
else obs_subject_ids
|
|
986
|
+
)
|
|
987
|
+
with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
|
|
988
|
+
attr_data = {}
|
|
989
|
+
for col in obs_meta.columns:
|
|
990
|
+
if col not in ("obs_subject_id", "obs_id"):
|
|
991
|
+
attr_data[col] = obs_meta[col].to_numpy()
|
|
992
|
+
arr[obs_subject_ids, obs_ids] = attr_data
|
|
993
|
+
|
|
994
|
+
collections_uri = f"{uri}/collections"
|
|
995
|
+
for coll_name, vc in collections.items():
|
|
996
|
+
new_vc_uri = f"{collections_uri}/{coll_name}"
|
|
997
|
+
_copy_volume_collection(vc, new_vc_uri, name=coll_name, ctx=ctx)
|
|
998
|
+
|
|
999
|
+
with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
|
|
1000
|
+
grp.add(new_vc_uri, name=coll_name)
|
|
1001
|
+
|
|
1002
|
+
with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
|
|
1003
|
+
grp.meta["n_collections"] = len(collections)
|
|
1004
|
+
grp.meta["subject_count"] = n_subjects
|
|
1005
|
+
|
|
1006
|
+
radi_result = cls(uri, ctx=ctx)
|
|
1007
|
+
return radi_result
|
|
1008
|
+
|
|
1009
|
+
    @classmethod
    def from_collections(
        cls,
        uri: str,
        collections: dict[str, VolumeCollection | str],
        obs_meta: pd.DataFrame | None = None,
        ctx: tiledb.Ctx | None = None,
    ) -> RadiObject:
        """Create RadiObject from existing VolumeCollections.

        Links collections without copying when they're already at expected URIs
        ({uri}/collections/{name}). Copies collections that are elsewhere.

        Args:
            uri: Target URI for RadiObject
            collections: Dict mapping collection names to VolumeCollection objects or URIs
            obs_meta: Optional subject-level metadata. If None, derived from collections.
            ctx: TileDB context

        Example:
            # Collections already at expected locations (no copy)
            ct_vc = radi.CT.lazy().map(transform).materialize(uri=f"{URI}/collections/CT")
            seg_vc = radi.seg.lazy().map(transform).materialize(uri=f"{URI}/collections/seg")
            radi = RadiObject.from_collections(
                uri=URI,
                collections={"CT": ct_vc, "seg": seg_vc},
            )

            # Collections from elsewhere (will be copied)
            radi = RadiObject.from_collections(
                uri="./new_dataset",
                collections={"T1w": existing_t1w_collection},
            )
        """
        if not collections:
            raise ValueError("At least one collection is required")

        effective_ctx = ctx if ctx else global_ctx()
        collections_uri = f"{uri}/collections"

        # Resolve string URIs to VolumeCollection objects
        resolved: dict[str, VolumeCollection] = {}
        for name, vc_or_uri in collections.items():
            if isinstance(vc_or_uri, str):
                resolved[name] = VolumeCollection(vc_or_uri, ctx=ctx)
            else:
                resolved[name] = vc_or_uri

        # Determine which collections need copying vs linking.
        # A collection already living at {uri}/collections/{name} can be
        # linked as-is; anything else must be copied into place.
        in_place: dict[str, VolumeCollection] = {}
        to_copy: dict[str, VolumeCollection] = {}

        for name, vc in resolved.items():
            expected_uri = f"{collections_uri}/{name}"
            if vc.uri == expected_uri:
                in_place[name] = vc
            else:
                to_copy[name] = vc

        # Check if collections group already exists (from materialize).
        # NOTE: probed BEFORE creating the root group so a pre-existing
        # collections subtree (written by materialize) is reused, not recreated.
        vfs = tiledb.VFS(ctx=effective_ctx)
        collections_group_exists = vfs.is_dir(collections_uri)

        # Create root group
        tiledb.Group.create(uri, ctx=effective_ctx)

        # Create or use existing collections group
        if not collections_group_exists:
            tiledb.Group.create(collections_uri, ctx=effective_ctx)

        with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
            grp.add(collections_uri, name="collections")

        # Link in-place collections (no copy needed)
        with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
            for name, vc in in_place.items():
                grp.add(vc.uri, name=name)

        # Copy external collections
        for name, vc in to_copy.items():
            new_uri = f"{collections_uri}/{name}"
            _copy_volume_collection(vc, new_uri, name=name, ctx=ctx)
            with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
                grp.add(new_uri, name=name)

        # Derive obs_meta if not provided: union of subject IDs from every
        # collection, with obs_id defaulting to the subject ID itself.
        if obs_meta is None:
            all_subject_ids: set[str] = set()
            for vc in resolved.values():
                obs_df = vc.obs.read()
                all_subject_ids.update(obs_df["obs_subject_id"].tolist())
            sorted_ids = sorted(all_subject_ids)
            obs_meta = pd.DataFrame(
                {
                    "obs_subject_id": sorted_ids,
                    "obs_id": sorted_ids,
                }
            )

        # Build obs_meta schema (object columns stored as fixed U64 strings)
        n_subjects = len(obs_meta)
        obs_meta_schema: dict[str, np.dtype] = {}
        for col in obs_meta.columns:
            if col in ("obs_subject_id", "obs_id"):
                continue
            dtype = obs_meta[col].to_numpy().dtype
            if dtype == np.dtype("O"):
                dtype = np.dtype("U64")
            obs_meta_schema[col] = dtype

        # Create obs_meta
        obs_meta_uri = f"{uri}/obs_meta"
        Dataframe.create(obs_meta_uri, schema=obs_meta_schema, ctx=ctx)

        if len(obs_meta) > 0:
            obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
            obs_ids = (
                obs_meta["obs_id"].astype(str).to_numpy()
                if "obs_id" in obs_meta.columns
                else obs_subject_ids
            )
            # Sparse write: (obs_subject_id, obs_id) are coordinates, all
            # remaining columns become attributes.
            with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
                attr_data = {
                    col: obs_meta[col].to_numpy()
                    for col in obs_meta.columns
                    if col not in ("obs_subject_id", "obs_id")
                }
                arr[obs_subject_ids, obs_ids] = attr_data

        # Link obs_meta to root and set metadata
        with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
            grp.add(obs_meta_uri, name="obs_meta")
            grp.meta["n_collections"] = len(resolved)
            grp.meta["subject_count"] = n_subjects

        return cls(uri, ctx=ctx)
|
|
1145
|
+
|
|
1146
|
+
@classmethod
|
|
1147
|
+
def from_niftis(
|
|
1148
|
+
cls,
|
|
1149
|
+
uri: str,
|
|
1150
|
+
niftis: Sequence[tuple[str | Path, str]] | None = None,
|
|
1151
|
+
image_dir: str | Path | None = None,
|
|
1152
|
+
collection_name: str | None = None,
|
|
1153
|
+
images: dict[str, str | Path | Sequence[tuple[str | Path, str]]] | None = None,
|
|
1154
|
+
validate_alignment: bool = False,
|
|
1155
|
+
obs_meta: pd.DataFrame | None = None,
|
|
1156
|
+
reorient: bool | None = None,
|
|
1157
|
+
ctx: tiledb.Ctx | None = None,
|
|
1158
|
+
progress: bool = False,
|
|
1159
|
+
) -> RadiObject:
|
|
1160
|
+
"""Create RadiObject from NIfTI files with raw data storage.
|
|
1161
|
+
|
|
1162
|
+
Ingestion stores volumes in their original dimensions without any
|
|
1163
|
+
preprocessing. Use `collection.lazy().map()` for post-hoc transformations.
|
|
1164
|
+
|
|
1165
|
+
Three input modes:
|
|
1166
|
+
1. images: Dict mapping collection names to paths/globs/lists (recommended)
|
|
1167
|
+
2. niftis: List of (path, subject_id) tuples (legacy)
|
|
1168
|
+
3. image_dir: Directory-based discovery (legacy)
|
|
1169
|
+
|
|
1170
|
+
Collection organization:
|
|
1171
|
+
- With images dict: each key becomes a collection
|
|
1172
|
+
- With collection_name: all volumes go to that single collection
|
|
1173
|
+
- Otherwise: auto-group by inferred modality (T1w, FLAIR, CT, etc.)
|
|
1174
|
+
|
|
1175
|
+
Args:
|
|
1176
|
+
uri: Target URI for RadiObject
|
|
1177
|
+
images: Dict mapping collection names to NIfTI sources. Sources can be:
|
|
1178
|
+
- Glob pattern: "./imagesTr/*.nii.gz"
|
|
1179
|
+
- Directory path: "./imagesTr"
|
|
1180
|
+
- Pre-resolved list: [(path, subject_id), ...]
|
|
1181
|
+
niftis: List of (nifti_path, obs_subject_id) tuples (legacy)
|
|
1182
|
+
image_dir: Directory containing image NIfTIs (legacy, mutually exclusive with niftis)
|
|
1183
|
+
collection_name: Explicit name for collection (legacy, all volumes go here)
|
|
1184
|
+
validate_alignment: If True, verify all collections have same subject IDs
|
|
1185
|
+
obs_meta: Subject-level metadata. Must contain obs_subject_id column.
|
|
1186
|
+
reorient: Reorient to canonical orientation (None uses config default)
|
|
1187
|
+
ctx: TileDB context
|
|
1188
|
+
progress: Show tqdm progress bar
|
|
1189
|
+
|
|
1190
|
+
Example (images dict with globs):
|
|
1191
|
+
radi = RadiObject.from_niftis(
|
|
1192
|
+
uri="./dataset",
|
|
1193
|
+
images={
|
|
1194
|
+
"CT": "./imagesTr/*.nii.gz",
|
|
1195
|
+
"seg": "./labelsTr/*.nii.gz",
|
|
1196
|
+
},
|
|
1197
|
+
)
|
|
1198
|
+
|
|
1199
|
+
Example (images dict with directories):
|
|
1200
|
+
radi = RadiObject.from_niftis(
|
|
1201
|
+
uri="./dataset",
|
|
1202
|
+
images={"CT": "./imagesTr", "seg": "./labelsTr"},
|
|
1203
|
+
)
|
|
1204
|
+
|
|
1205
|
+
Example (legacy explicit collection name):
|
|
1206
|
+
radi = RadiObject.from_niftis(
|
|
1207
|
+
uri="s3://bucket/raw",
|
|
1208
|
+
image_dir="./imagesTr",
|
|
1209
|
+
collection_name="lung_ct",
|
|
1210
|
+
)
|
|
1211
|
+
|
|
1212
|
+
Example (legacy auto-group by modality):
|
|
1213
|
+
radi = RadiObject.from_niftis(
|
|
1214
|
+
uri="s3://bucket/raw",
|
|
1215
|
+
niftis=[
|
|
1216
|
+
("sub01_T1w.nii.gz", "sub-01"),
|
|
1217
|
+
("sub01_FLAIR.nii.gz", "sub-01"),
|
|
1218
|
+
],
|
|
1219
|
+
)
|
|
1220
|
+
# Result: radi.T1w, radi.FLAIR collections
|
|
1221
|
+
"""
|
|
1222
|
+
from radiobject.ingest import resolve_nifti_source
|
|
1223
|
+
|
|
1224
|
+
# --- NORMALIZE ALL INPUTS TO images DICT ---
|
|
1225
|
+
|
|
1226
|
+
if images is not None:
|
|
1227
|
+
if niftis is not None or image_dir is not None or collection_name is not None:
|
|
1228
|
+
raise ValueError("Cannot use 'images' with legacy parameters")
|
|
1229
|
+
if not images:
|
|
1230
|
+
raise ValueError("images dict cannot be empty")
|
|
1231
|
+
|
|
1232
|
+
elif image_dir is not None:
|
|
1233
|
+
if niftis is not None:
|
|
1234
|
+
raise ValueError("Cannot specify both 'niftis' and 'image_dir'")
|
|
1235
|
+
from radiobject.ingest import discover_nifti_pairs
|
|
1236
|
+
|
|
1237
|
+
sources = discover_nifti_pairs(image_dir)
|
|
1238
|
+
niftis = [(s.image_path, s.subject_id) for s in sources]
|
|
1239
|
+
|
|
1240
|
+
if collection_name:
|
|
1241
|
+
images = {collection_name: niftis}
|
|
1242
|
+
else:
|
|
1243
|
+
modality_groups: dict[str, list[tuple[str | Path, str]]] = defaultdict(list)
|
|
1244
|
+
for path, sid in niftis:
|
|
1245
|
+
series_type = infer_series_type(Path(path))
|
|
1246
|
+
modality_groups[series_type].append((path, sid))
|
|
1247
|
+
images = dict(modality_groups)
|
|
1248
|
+
|
|
1249
|
+
elif niftis is not None:
|
|
1250
|
+
if collection_name:
|
|
1251
|
+
images = {collection_name: niftis}
|
|
1252
|
+
else:
|
|
1253
|
+
modality_groups: dict[str, list[tuple[str | Path, str]]] = defaultdict(list)
|
|
1254
|
+
for path, sid in niftis:
|
|
1255
|
+
series_type = infer_series_type(Path(path))
|
|
1256
|
+
modality_groups[series_type].append((path, sid))
|
|
1257
|
+
images = dict(modality_groups)
|
|
1258
|
+
else:
|
|
1259
|
+
raise ValueError("Must specify 'images', 'niftis', or 'image_dir'")
|
|
1260
|
+
|
|
1261
|
+
# --- SINGLE CODE PATH: Resolve images dict ---
|
|
1262
|
+
|
|
1263
|
+
groups: dict[str, list[tuple[Path, str]]] = {}
|
|
1264
|
+
for coll_name, source in images.items():
|
|
1265
|
+
groups[coll_name] = resolve_nifti_source(source)
|
|
1266
|
+
|
|
1267
|
+
# Optional alignment validation
|
|
1268
|
+
if validate_alignment and len(groups) > 1:
|
|
1269
|
+
subject_sets = {
|
|
1270
|
+
name: {sid for _, sid in nifti_list} for name, nifti_list in groups.items()
|
|
1271
|
+
}
|
|
1272
|
+
first_name, first_set = next(iter(subject_sets.items()))
|
|
1273
|
+
for name, sid_set in subject_sets.items():
|
|
1274
|
+
if sid_set != first_set:
|
|
1275
|
+
missing_in_first = sid_set - first_set
|
|
1276
|
+
missing_in_other = first_set - sid_set
|
|
1277
|
+
raise ValueError(
|
|
1278
|
+
f"Subject ID mismatch between '{first_name}' and '{name}': "
|
|
1279
|
+
f"missing in '{first_name}': {sorted(missing_in_first)[:3]}, "
|
|
1280
|
+
f"missing in '{name}': {sorted(missing_in_other)[:3]}"
|
|
1281
|
+
)
|
|
1282
|
+
|
|
1283
|
+
# Validate all files exist
|
|
1284
|
+
for coll_name, nifti_list in groups.items():
|
|
1285
|
+
for path, _ in nifti_list:
|
|
1286
|
+
if not path.exists():
|
|
1287
|
+
raise FileNotFoundError(f"NIfTI file not found: {path}")
|
|
1288
|
+
|
|
1289
|
+
# Collect all subject IDs
|
|
1290
|
+
all_subject_ids: set[str] = set()
|
|
1291
|
+
for nifti_list in groups.values():
|
|
1292
|
+
all_subject_ids.update(sid for _, sid in nifti_list)
|
|
1293
|
+
|
|
1294
|
+
# Validate FK constraint if obs_meta provided
|
|
1295
|
+
if obs_meta is not None:
|
|
1296
|
+
if "obs_subject_id" not in obs_meta.columns:
|
|
1297
|
+
raise ValueError("obs_meta must contain 'obs_subject_id' column")
|
|
1298
|
+
obs_meta_subject_ids = set(obs_meta["obs_subject_id"])
|
|
1299
|
+
missing = all_subject_ids - obs_meta_subject_ids
|
|
1300
|
+
if missing:
|
|
1301
|
+
raise ValueError(
|
|
1302
|
+
f"obs_subject_ids in niftis not found in obs_meta: {sorted(missing)[:5]}"
|
|
1303
|
+
)
|
|
1304
|
+
else:
|
|
1305
|
+
sorted_ids = sorted(all_subject_ids)
|
|
1306
|
+
obs_meta = pd.DataFrame(
|
|
1307
|
+
{
|
|
1308
|
+
"obs_subject_id": sorted_ids,
|
|
1309
|
+
"obs_id": sorted_ids,
|
|
1310
|
+
}
|
|
1311
|
+
)
|
|
1312
|
+
|
|
1313
|
+
if not groups or all(len(nifti_list) == 0 for nifti_list in groups.values()):
|
|
1314
|
+
raise ValueError("No NIfTI files found")
|
|
1315
|
+
|
|
1316
|
+
effective_ctx = ctx if ctx else global_ctx()
|
|
1317
|
+
|
|
1318
|
+
tiledb.Group.create(uri, ctx=effective_ctx)
|
|
1319
|
+
collections_uri = f"{uri}/collections"
|
|
1320
|
+
tiledb.Group.create(collections_uri, ctx=effective_ctx)
|
|
1321
|
+
|
|
1322
|
+
collections: dict[str, VolumeCollection] = {}
|
|
1323
|
+
|
|
1324
|
+
groups_iter = list(groups.items())
|
|
1325
|
+
if progress:
|
|
1326
|
+
from tqdm.auto import tqdm
|
|
1327
|
+
|
|
1328
|
+
groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")
|
|
1329
|
+
|
|
1330
|
+
for coll_name, items in groups_iter:
|
|
1331
|
+
vc_uri = f"{collections_uri}/{coll_name}"
|
|
1332
|
+
nifti_list = [(path, subject_id) for path, subject_id in items]
|
|
1333
|
+
|
|
1334
|
+
vc = VolumeCollection.from_niftis(
|
|
1335
|
+
uri=vc_uri,
|
|
1336
|
+
niftis=nifti_list,
|
|
1337
|
+
reorient=reorient,
|
|
1338
|
+
validate_dimensions=False,
|
|
1339
|
+
name=coll_name,
|
|
1340
|
+
ctx=ctx,
|
|
1341
|
+
progress=progress,
|
|
1342
|
+
)
|
|
1343
|
+
collections[coll_name] = vc
|
|
1344
|
+
|
|
1345
|
+
with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
|
|
1346
|
+
grp.add(vc_uri, name=coll_name)
|
|
1347
|
+
|
|
1348
|
+
n_subjects = len(obs_meta)
|
|
1349
|
+
obs_meta_schema: dict[str, np.dtype] = {}
|
|
1350
|
+
for col in obs_meta.columns:
|
|
1351
|
+
if col in ("obs_subject_id", "obs_id"):
|
|
1352
|
+
continue
|
|
1353
|
+
dtype = obs_meta[col].to_numpy().dtype
|
|
1354
|
+
if dtype == np.dtype("O"):
|
|
1355
|
+
dtype = np.dtype("U64")
|
|
1356
|
+
obs_meta_schema[col] = dtype
|
|
1357
|
+
|
|
1358
|
+
obs_meta_uri = f"{uri}/obs_meta"
|
|
1359
|
+
Dataframe.create(obs_meta_uri, schema=obs_meta_schema, ctx=ctx)
|
|
1360
|
+
|
|
1361
|
+
if len(obs_meta) > 0:
|
|
1362
|
+
obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
|
|
1363
|
+
obs_ids = (
|
|
1364
|
+
obs_meta["obs_id"].astype(str).to_numpy()
|
|
1365
|
+
if "obs_id" in obs_meta.columns
|
|
1366
|
+
else obs_subject_ids
|
|
1367
|
+
)
|
|
1368
|
+
with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
|
|
1369
|
+
attr_data = {}
|
|
1370
|
+
for col in obs_meta.columns:
|
|
1371
|
+
if col not in ("obs_subject_id", "obs_id"):
|
|
1372
|
+
attr_data[col] = obs_meta[col].to_numpy()
|
|
1373
|
+
arr[obs_subject_ids, obs_ids] = attr_data
|
|
1374
|
+
|
|
1375
|
+
with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
|
|
1376
|
+
grp.meta["n_collections"] = len(collections)
|
|
1377
|
+
grp.meta["subject_count"] = n_subjects
|
|
1378
|
+
grp.add(obs_meta_uri, name="obs_meta")
|
|
1379
|
+
grp.add(collections_uri, name="collections")
|
|
1380
|
+
|
|
1381
|
+
return cls(uri, ctx=ctx)
|
|
1382
|
+
|
|
1383
|
+
    @classmethod
    def from_dicoms(
        cls,
        uri: str,
        dicom_dirs: Sequence[tuple[str | Path, str]],
        obs_meta: pd.DataFrame | None = None,
        reorient: bool | None = None,
        ctx: tiledb.Ctx | None = None,
        progress: bool = False,
    ) -> RadiObject:
        """Create RadiObject from DICOM series with automatic grouping.

        Series are automatically grouped into VolumeCollections by:
        1. Dimensions (rows, columns, n_slices)
        2. Modality tag (CT, MR)

        Args:
            uri: Target URI for RadiObject
            dicom_dirs: List of (dicom_dir, obs_subject_id) tuples
            obs_meta: Subject-level metadata (user-provided). Must contain obs_subject_id column.
            reorient: Reorient to canonical orientation (None uses config default)
            ctx: TileDB context
            progress: Show tqdm progress bar during volume writes

        Returns:
            A RadiObject handle opened at ``uri``.

        Raises:
            ValueError: If ``dicom_dirs`` is empty, or ``obs_meta`` lacks the
                ``obs_subject_id`` column, or a subject in ``dicom_dirs`` is
                missing from ``obs_meta``.
            FileNotFoundError: If any DICOM directory does not exist.

        Example:
            radi = RadiObject.from_dicoms(
                uri="/storage/ct_study",
                dicom_dirs=[
                    ("/dicom/sub01/CT_HEAD", "sub-01"),
                    ("/dicom/sub01/CT_CHEST", "sub-01"),
                    ("/dicom/sub02/CT_HEAD", "sub-02"),
                ],
                obs_meta=obs_meta_df,
            )
        """
        if not dicom_dirs:
            raise ValueError("At least one DICOM directory is required")

        # Distinct subject ids appearing in the input list.
        all_subject_ids = {sid for _, sid in dicom_dirs}

        if obs_meta is not None:
            # User-provided metadata must cover every subject referenced by dicom_dirs.
            if "obs_subject_id" not in obs_meta.columns:
                raise ValueError("obs_meta must contain 'obs_subject_id' column")
            obs_meta_subject_ids = set(obs_meta["obs_subject_id"])
            missing = all_subject_ids - obs_meta_subject_ids
            if missing:
                # Cap the error message at five ids to keep it readable.
                raise ValueError(
                    f"obs_subject_ids in dicom_dirs not found in obs_meta: {sorted(missing)[:5]}"
                )
        else:
            # No metadata supplied: synthesize a minimal frame with obs_id == obs_subject_id.
            sorted_ids = sorted(all_subject_ids)
            obs_meta = pd.DataFrame(
                {
                    "obs_subject_id": sorted_ids,
                    "obs_id": sorted_ids,
                }
            )

        # Scan every directory once to collect (path, subject, shape, modality).
        file_info: list[tuple[Path, str, tuple[int, int, int], str]] = []
        for dicom_dir, obs_subject_id in dicom_dirs:
            path = Path(dicom_dir)
            if not path.exists():
                raise FileNotFoundError(f"DICOM directory not found: {path}")

            metadata = extract_dicom_metadata(path)
            dims = metadata.dimensions
            # dims is presumably (rows, columns, n_slices); the first two axes are
            # swapped to build the volume shape — TODO confirm against
            # extract_dicom_metadata.
            shape = (dims[1], dims[0], dims[2])
            group_key = metadata.modality
            file_info.append((path, obs_subject_id, shape, group_key))

        # Group series by (shape, modality); each group becomes one VolumeCollection.
        groups: dict[tuple[tuple[int, int, int], str], list[tuple[Path, str]]] = defaultdict(list)
        for path, subject_id, shape, group_key in file_info:
            key = (shape, group_key)
            groups[key].append((path, subject_id))

        effective_ctx = ctx if ctx else global_ctx()

        # Root group plus a nested "collections" group holding one member per modality group.
        tiledb.Group.create(uri, ctx=effective_ctx)
        collections_uri = f"{uri}/collections"
        tiledb.Group.create(collections_uri, ctx=effective_ctx)

        collections: dict[str, VolumeCollection] = {}
        used_names: set[str] = set()

        groups_iter = groups.items()
        if progress:
            from tqdm.auto import tqdm

            groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")

        for (shape, modality), items in groups_iter:
            # First group of a modality keeps the bare modality name; later groups
            # with the same modality (necessarily a different shape) get a
            # shape-suffixed name to stay unique.
            coll_name = modality
            if coll_name in used_names:
                coll_name = f"{modality}_{shape[0]}x{shape[1]}x{shape[2]}"
            used_names.add(coll_name)

            vc_uri = f"{collections_uri}/{coll_name}"
            dicom_list = [(path, subject_id) for path, subject_id in items]

            # Dimensions were already used as the grouping key, so per-volume
            # validation inside the collection is kept on as a safety net.
            # NOTE(review): ctx (possibly None) is passed here while group writes
            # below use effective_ctx — presumably the callee resolves its own
            # default; confirm.
            vc = VolumeCollection.from_dicoms(
                uri=vc_uri,
                dicom_dirs=dicom_list,
                reorient=reorient,
                validate_dimensions=True,
                name=coll_name,
                ctx=ctx,
                progress=progress,
            )
            collections[coll_name] = vc

            # Register the finished collection as a named member of the collections group.
            with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
                grp.add(vc_uri, name=coll_name)

        # Derive the obs_meta attribute schema; the two id columns become array
        # dimensions rather than attributes, so they are skipped here.
        n_subjects = len(obs_meta)
        obs_meta_schema: dict[str, np.dtype] = {}
        for col in obs_meta.columns:
            if col in ("obs_subject_id", "obs_id"):
                continue
            dtype = obs_meta[col].to_numpy().dtype
            if dtype == np.dtype("O"):
                # Object (string) columns are stored as fixed-width 64-char unicode.
                dtype = np.dtype("U64")
            obs_meta_schema[col] = dtype

        obs_meta_uri = f"{uri}/obs_meta"
        Dataframe.create(obs_meta_uri, schema=obs_meta_schema, ctx=ctx)

        if len(obs_meta) > 0:
            obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
            # Fall back to the subject id when no explicit obs_id column exists.
            obs_ids = (
                obs_meta["obs_id"].astype(str).to_numpy()
                if "obs_id" in obs_meta.columns
                else obs_subject_ids
            )
            with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
                attr_data = {}
                for col in obs_meta.columns:
                    if col not in ("obs_subject_id", "obs_id"):
                        attr_data[col] = obs_meta[col].to_numpy()
                # Write attributes keyed by the (obs_subject_id, obs_id) coordinates.
                arr[obs_subject_ids, obs_ids] = attr_data

        # Finalize the root group: counts in metadata, members for obs_meta and collections.
        with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
            grp.meta["n_collections"] = len(collections)
            grp.meta["subject_count"] = n_subjects
            grp.add(obs_meta_uri, name="obs_meta")
            grp.add(collections_uri, name="collections")

        return cls(uri, ctx=ctx)
|
|
1530
|
+
|
|
1531
|
+
|
|
1532
|
+
# ===== Helper Functions =====
|
|
1533
|
+
|
|
1534
|
+
|
|
1535
|
+
def _extract_obs_schema(obs: Dataframe) -> dict[str, np.dtype]:
    """Build a column -> dtype mapping for an obs Dataframe.

    The id columns ``obs_id`` and ``obs_subject_id`` act as coordinates rather
    than attributes, so they are omitted from the returned schema.
    """
    index_cols = ("obs_id", "obs_subject_id")
    return {col: obs.dtypes[col] for col in obs.columns if col not in index_cols}
|
|
1543
|
+
|
|
1544
|
+
|
|
1545
|
+
def _copy_volume_collection(
    src: VolumeCollection,
    dst_uri: str,
    name: str | None = None,
    ctx: tiledb.Ctx | None = None,
) -> None:
    """Copy a VolumeCollection to a new URI.

    Creates the destination collection with the same shape and obs schema,
    copies the obs table, then rewrites every volume (optionally renamed via
    ``name``). The destination volumes group is populated from the write results.
    """
    effective_ctx = ctx if ctx else global_ctx()

    collection_name = name if name is not None else src.name

    # Create the destination skeleton mirroring the source's shape and obs schema.
    VolumeCollection._create(
        dst_uri,
        shape=src.shape,
        obs_schema=_extract_obs_schema(src.obs),
        n_volumes=len(src),
        name=collection_name,
        ctx=ctx,
    )

    # Copy the obs table: id columns become the write coordinates, the rest
    # become attribute data.
    obs_df = src.obs.read()
    obs_uri = f"{dst_uri}/obs"
    obs_subject_ids = obs_df["obs_subject_id"].astype(str).to_numpy()
    obs_ids = obs_df["obs_id"].astype(str).to_numpy()
    with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
        attr_data = {
            col: obs_df[col].to_numpy()
            for col in obs_df.columns
            if col not in ("obs_subject_id", "obs_id")
        }
        arr[obs_subject_ids, obs_ids] = attr_data

    def write_volume(args: tuple[int, str, Volume]) -> WriteResult:
        # Copy one volume; errors are captured into the WriteResult instead of
        # raised, so a single failure does not abort the whole batch.
        idx, obs_id, vol = args
        # Fresh context per task — presumably required because tasks may run in
        # worker processes/threads; TODO confirm against create_worker_ctx.
        worker_ctx = create_worker_ctx(ctx)
        new_vol_uri = f"{dst_uri}/volumes/{idx}"
        try:
            data = vol.to_numpy()
            new_vol = Volume.from_numpy(new_vol_uri, data, ctx=worker_ctx)
            new_vol.set_obs_id(obs_id)
            return WriteResult(idx, new_vol_uri, obs_id, success=True)
        except Exception as e:
            return WriteResult(idx, new_vol_uri, obs_id, success=False, error=e)

    write_args = [(idx, obs_id, src.iloc[idx]) for idx, obs_id in enumerate(src.obs_ids)]
    results = _write_volumes_parallel(
        write_volume, write_args, progress=False, desc="Copying volumes"
    )

    # Register every written volume under its index name.
    # NOTE(review): results' success flag is not checked here — presumably
    # _write_volumes_parallel raises on failure; verify, otherwise failed
    # volume URIs get registered.
    with tiledb.Group(f"{dst_uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
        for result in results:
            vol_grp.add(result.uri, name=str(result.index))
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
def _copy_filtered_volume_collection(
    src: VolumeCollection,
    dst_uri: str,
    obs_subject_ids: list[str],
    name: str | None = None,
    ctx: tiledb.Ctx | None = None,
) -> None:
    """Copy a VolumeCollection, filtering to volumes matching obs_subject_ids.

    Only obs rows (and their volumes) whose ``obs_subject_id`` is in
    ``obs_subject_ids`` are copied; surviving volumes are re-indexed 0..n-1 in
    the destination.

    Raises:
        ValueError: If no obs rows match the requested subject ids.
    """
    effective_ctx = ctx if ctx else global_ctx()

    collection_name = name if name is not None else src.name

    obs_df = src.obs.read()
    # Set membership keeps the isin() filter O(1) per row.
    subject_id_set = set(obs_subject_ids)

    filtered_obs = obs_df[obs_df["obs_subject_id"].isin(subject_id_set)].reset_index(drop=True)

    if len(filtered_obs) == 0:
        raise ValueError("No volumes match the specified obs_subject_ids")

    # Destination skeleton sized to the filtered subset, same shape/schema as source.
    VolumeCollection._create(
        dst_uri,
        shape=src.shape,
        obs_schema=_extract_obs_schema(src.obs),
        n_volumes=len(filtered_obs),
        name=collection_name,
        ctx=ctx,
    )

    # Write the filtered obs rows; id columns are coordinates, the rest attributes.
    obs_uri = f"{dst_uri}/obs"
    obs_subject_ids_arr = filtered_obs["obs_subject_id"].astype(str).to_numpy()
    obs_ids_arr = filtered_obs["obs_id"].astype(str).to_numpy()
    with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
        attr_data = {
            col: filtered_obs[col].to_numpy()
            for col in filtered_obs.columns
            if col not in ("obs_subject_id", "obs_id")
        }
        arr[obs_subject_ids_arr, obs_ids_arr] = attr_data

    # Map surviving obs_ids back to their positional indices in the source.
    # NOTE(review): this assumes obs_ids are unique within the collection —
    # duplicates would all be selected; confirm the invariant.
    selected_obs_ids = set(filtered_obs["obs_id"])
    selected_indices = [i for i, oid in enumerate(src.obs_ids) if oid in selected_obs_ids]

    def write_volume(args: tuple[int, int, str]) -> WriteResult:
        # Copy one source volume (orig_idx) to its compacted slot (new_idx);
        # errors are captured into the WriteResult rather than raised.
        new_idx, orig_idx, obs_id = args
        # Fresh context per task — presumably for worker-process safety; TODO
        # confirm against create_worker_ctx.
        worker_ctx = create_worker_ctx(ctx)
        new_vol_uri = f"{dst_uri}/volumes/{new_idx}"
        try:
            vol = src.iloc[orig_idx]
            data = vol.to_numpy()
            new_vol = Volume.from_numpy(new_vol_uri, data, ctx=worker_ctx)
            new_vol.set_obs_id(obs_id)
            return WriteResult(new_idx, new_vol_uri, obs_id, success=True)
        except Exception as e:
            return WriteResult(new_idx, new_vol_uri, obs_id, success=False, error=e)

    write_args = [
        (new_idx, orig_idx, src.obs_ids[orig_idx])
        for new_idx, orig_idx in enumerate(selected_indices)
    ]
    results = _write_volumes_parallel(
        write_volume, write_args, progress=False, desc="Filtering volumes"
    )

    # Register each copied volume under its new compacted index.
    # NOTE(review): success flags are not checked here — presumably
    # _write_volumes_parallel raises on failure; verify.
    with tiledb.Group(f"{dst_uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
        for result in results:
            vol_grp.add(result.uri, name=str(result.index))
|