radiobject 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
radiobject/query.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
1
|
+
"""Lazy query builder pattern for RadiObject and VolumeCollection filtering.
|
|
2
|
+
|
|
3
|
+
Query Design:
|
|
4
|
+
=============
|
|
5
|
+
Queries work by building up filter conditions and resolving them to masks (sets of IDs).
|
|
6
|
+
The flow is:
|
|
7
|
+
|
|
8
|
+
1. Start with RadiObject.lazy() or VolumeCollection.lazy()
|
|
9
|
+
2. Add filter conditions (obs_meta filters, collection filters)
|
|
10
|
+
3. Add transforms via .map() for compute-intensive operations
|
|
11
|
+
4. Resolve to masks:
|
|
12
|
+
- Subject mask: set of obs_subject_ids matching criteria
|
|
13
|
+
- Volume mask: set of obs_ids within each collection matching criteria
|
|
14
|
+
5. Apply masks via materialization (iter_volumes, materialize, etc.)
|
|
15
|
+
|
|
16
|
+
Key Concepts:
|
|
17
|
+
- obs_meta: Subject-level metadata (indexed by obs_subject_id)
|
|
18
|
+
- obs: Volume-level metadata per collection (indexed by obs_id, contains obs_subject_id FK)
|
|
19
|
+
- A subject matches if it passes obs_meta filter AND has at least one volume passing collection filters
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
from dataclasses import dataclass
|
|
25
|
+
from typing import TYPE_CHECKING, Iterator, Sequence
|
|
26
|
+
|
|
27
|
+
import numpy as np
|
|
28
|
+
import numpy.typing as npt
|
|
29
|
+
import pandas as pd
|
|
30
|
+
import tiledb
|
|
31
|
+
|
|
32
|
+
from radiobject._types import TransformFn
|
|
33
|
+
from radiobject.volume import Volume
|
|
34
|
+
|
|
35
|
+
if TYPE_CHECKING:
|
|
36
|
+
from radiobject.radi_object import RadiObject
|
|
37
|
+
from radiobject.volume_collection import VolumeCollection
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True)
class VolumeBatch:
    """Batch of volumes for ML training with stacked numpy arrays."""

    # Stacked voxel data per collection; each array stacks that collection's
    # volumes along axis 0 for this batch.
    volumes: dict[str, npt.NDArray[np.floating]]  # collection_name -> (N, X, Y, Z)
    # Subject IDs included in this batch, in batch order.
    subject_ids: tuple[str, ...]
    # Per-collection obs_ids, aligned with axis 0 of the matching `volumes` entry.
    obs_ids: dict[str, tuple[str, ...]]  # collection_name -> obs_ids
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass(frozen=True)
class QueryCount:
    """Count results for a query."""

    # Number of subjects passing all subject- and collection-level filters.
    n_subjects: int
    # Matching volume count per output collection.
    n_volumes: dict[str, int]  # collection_name -> volume count
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Query:
|
|
58
|
+
"""Lazy filter builder for RadiObject with explicit materialization.
|
|
59
|
+
|
|
60
|
+
Filters accumulate without data access. Call iter_volumes(), materialize(),
|
|
61
|
+
or count() to materialize results.
|
|
62
|
+
|
|
63
|
+
Example:
|
|
64
|
+
result = (
|
|
65
|
+
radi.lazy()
|
|
66
|
+
.filter("age > 40")
|
|
67
|
+
.filter_collection("T1w", "voxel_spacing == '1.0x1.0x1.0'")
|
|
68
|
+
.map(normalize_intensity)
|
|
69
|
+
.materialize("s3://bucket/subset")
|
|
70
|
+
)
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
    def __init__(
        self,
        source: RadiObject,
        *,
        # Subject-level filters (on obs_meta)
        subject_ids: frozenset[str] | None = None,
        subject_query: str | None = None,
        # Collection-level filters (on each collection's obs)
        collection_filters: dict[str, str] | None = None,  # collection_name -> query expr
        # Output scope
        output_collections: frozenset[str] | None = None,
        # Transform function applied during materialization
        transform_fn: TransformFn | None = None,
    ):
        """Initialize a lazy query over `source`.

        Args:
            source: RadiObject this query reads from; never mutated.
            subject_ids: Explicit allow-list of obs_subject_ids (None = unrestricted).
            subject_query: TileDB QueryCondition expression applied to obs_meta.
            collection_filters: Per-collection QueryCondition expressions on obs.
            output_collections: Collections included in output (None = all).
            transform_fn: Optional volume transform applied at materialization time.
        """
        self._source = source
        self._subject_ids = subject_ids
        self._subject_query = subject_query
        # Normalize None to {} so downstream lookups never need a None check.
        self._collection_filters = collection_filters or {}
        self._output_collections = output_collections
        self._transform_fn = transform_fn
|
|
93
|
+
|
|
94
|
+
def _copy(self, **kwargs) -> Query:
|
|
95
|
+
"""Create a copy with modified fields."""
|
|
96
|
+
return Query(
|
|
97
|
+
self._source,
|
|
98
|
+
subject_ids=kwargs.get("subject_ids", self._subject_ids),
|
|
99
|
+
subject_query=kwargs.get("subject_query", self._subject_query),
|
|
100
|
+
collection_filters=kwargs.get("collection_filters", self._collection_filters),
|
|
101
|
+
output_collections=kwargs.get("output_collections", self._output_collections),
|
|
102
|
+
transform_fn=kwargs.get("transform_fn", self._transform_fn),
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
# =========================================================================
|
|
106
|
+
# SUBJECT-LEVEL FILTERS (operate on obs_meta)
|
|
107
|
+
# =========================================================================
|
|
108
|
+
|
|
109
|
+
def filter(self, expr: str) -> Query:
|
|
110
|
+
"""Filter subjects using TileDB QueryCondition on obs_meta.
|
|
111
|
+
|
|
112
|
+
Args:
|
|
113
|
+
expr: TileDB query expression (e.g., "age > 40 and diagnosis == 'tumor'")
|
|
114
|
+
|
|
115
|
+
Returns:
|
|
116
|
+
New Query with filter applied
|
|
117
|
+
"""
|
|
118
|
+
new_query = self._subject_query
|
|
119
|
+
if new_query is None:
|
|
120
|
+
new_query = expr
|
|
121
|
+
else:
|
|
122
|
+
new_query = f"({new_query}) and ({expr})"
|
|
123
|
+
return self._copy(subject_query=new_query)
|
|
124
|
+
|
|
125
|
+
def filter_subjects(self, ids: Sequence[str]) -> Query:
|
|
126
|
+
"""Filter to specific subject IDs.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
ids: List of obs_subject_ids to include
|
|
130
|
+
|
|
131
|
+
Returns:
|
|
132
|
+
New Query with filter applied
|
|
133
|
+
"""
|
|
134
|
+
new_ids = frozenset(ids)
|
|
135
|
+
if self._subject_ids is not None:
|
|
136
|
+
new_ids = self._subject_ids & new_ids
|
|
137
|
+
return self._copy(subject_ids=new_ids)
|
|
138
|
+
|
|
139
|
+
def iloc(self, key: int | slice | list[int] | npt.NDArray[np.bool_]) -> Query:
|
|
140
|
+
"""Filter subjects by integer position(s)."""
|
|
141
|
+
n = len(self._source)
|
|
142
|
+
if isinstance(key, int):
|
|
143
|
+
idx = key if key >= 0 else n + key
|
|
144
|
+
indices = [idx]
|
|
145
|
+
elif isinstance(key, slice):
|
|
146
|
+
indices = list(range(*key.indices(n)))
|
|
147
|
+
elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
|
|
148
|
+
indices = list(np.where(key)[0])
|
|
149
|
+
elif isinstance(key, list):
|
|
150
|
+
indices = [i if i >= 0 else n + i for i in key]
|
|
151
|
+
else:
|
|
152
|
+
raise TypeError("iloc key must be int, slice, list[int], or bool array")
|
|
153
|
+
|
|
154
|
+
ids = [self._source._index.get_key(i) for i in indices]
|
|
155
|
+
return self.filter_subjects(ids)
|
|
156
|
+
|
|
157
|
+
def loc(self, key: str | Sequence[str]) -> Query:
|
|
158
|
+
"""Filter subjects by obs_subject_id(s)."""
|
|
159
|
+
if isinstance(key, str):
|
|
160
|
+
return self.filter_subjects([key])
|
|
161
|
+
return self.filter_subjects(key)
|
|
162
|
+
|
|
163
|
+
def head(self, n: int = 5) -> Query:
|
|
164
|
+
"""Filter to first n subjects."""
|
|
165
|
+
return self.iloc(slice(0, n))
|
|
166
|
+
|
|
167
|
+
def tail(self, n: int = 5) -> Query:
|
|
168
|
+
"""Filter to last n subjects."""
|
|
169
|
+
total = len(self._source)
|
|
170
|
+
return self.iloc(slice(max(0, total - n), total))
|
|
171
|
+
|
|
172
|
+
    def sample(self, n: int = 5, seed: int | None = None) -> Query:
        """Filter to n randomly sampled subjects.

        Note: unlike the other filters, this resolves the subject mask
        eagerly (data access) so the sample is drawn from the subjects
        currently matching the query.

        Args:
            n: Number of subjects to keep; clamped to the number available.
            seed: RNG seed for reproducible sampling (None = nondeterministic).
        """
        rng = np.random.default_rng(seed)
        resolved = self._resolve_subject_mask()
        # NOTE(review): list(frozenset) order is hash-based; with string hash
        # randomization (PYTHONHASHSEED), a fixed seed is reproducible within
        # a process but not necessarily across runs — confirm whether cross-run
        # reproducibility is required here.
        subject_list = list(resolved)
        n = min(n, len(subject_list))
        sampled = rng.choice(subject_list, size=n, replace=False)
        return self.filter_subjects(sampled)
|
|
180
|
+
|
|
181
|
+
# =========================================================================
|
|
182
|
+
# COLLECTION-LEVEL FILTERS (operate on collection obs)
|
|
183
|
+
# =========================================================================
|
|
184
|
+
|
|
185
|
+
def filter_collection(self, collection_name: str, expr: str) -> Query:
|
|
186
|
+
"""Filter volumes in a specific collection using TileDB QueryCondition on obs.
|
|
187
|
+
|
|
188
|
+
Only subjects that have at least one volume matching the filter will be included.
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
collection_name: Name of the collection to filter
|
|
192
|
+
expr: TileDB query expression on the collection's obs dataframe
|
|
193
|
+
|
|
194
|
+
Returns:
|
|
195
|
+
New Query with filter applied
|
|
196
|
+
|
|
197
|
+
Example:
|
|
198
|
+
# Only include subjects whose T1w scans have 1mm resolution
|
|
199
|
+
query.filter_collection("T1w", "voxel_spacing == '1.0x1.0x1.0'")
|
|
200
|
+
"""
|
|
201
|
+
new_filters = dict(self._collection_filters)
|
|
202
|
+
if collection_name in new_filters:
|
|
203
|
+
new_filters[collection_name] = f"({new_filters[collection_name]}) and ({expr})"
|
|
204
|
+
else:
|
|
205
|
+
new_filters[collection_name] = expr
|
|
206
|
+
return self._copy(collection_filters=new_filters)
|
|
207
|
+
|
|
208
|
+
def select_collections(self, names: Sequence[str]) -> Query:
|
|
209
|
+
"""Limit output to specific collections.
|
|
210
|
+
|
|
211
|
+
This doesn't affect filtering - it only limits which collections
|
|
212
|
+
appear in the output.
|
|
213
|
+
|
|
214
|
+
Args:
|
|
215
|
+
names: Collection names to include in output
|
|
216
|
+
|
|
217
|
+
Returns:
|
|
218
|
+
New Query with output scope set
|
|
219
|
+
"""
|
|
220
|
+
new_collections = frozenset(names)
|
|
221
|
+
if self._output_collections is not None:
|
|
222
|
+
new_collections = self._output_collections & new_collections
|
|
223
|
+
return self._copy(output_collections=new_collections)
|
|
224
|
+
|
|
225
|
+
# =========================================================================
|
|
226
|
+
# TRANSFORM (applied during materialization)
|
|
227
|
+
# =========================================================================
|
|
228
|
+
|
|
229
|
+
def map(self, fn: TransformFn) -> Query:
|
|
230
|
+
"""Apply transform to each volume during materialization.
|
|
231
|
+
|
|
232
|
+
Multiple map() calls compose: query.map(f1).map(f2) applies f1 then f2.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
fn: Function (X, Y, Z) -> (X', Y', Z'). Can change shape.
|
|
236
|
+
"""
|
|
237
|
+
if self._transform_fn is not None:
|
|
238
|
+
# Compose transforms: apply previous transform, then new one
|
|
239
|
+
prev_fn = self._transform_fn
|
|
240
|
+
|
|
241
|
+
def composed_fn(v: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
|
|
242
|
+
return fn(prev_fn(v))
|
|
243
|
+
|
|
244
|
+
return self._copy(transform_fn=composed_fn)
|
|
245
|
+
return self._copy(transform_fn=fn)
|
|
246
|
+
|
|
247
|
+
# =========================================================================
|
|
248
|
+
# MASK RESOLUTION
|
|
249
|
+
# =========================================================================
|
|
250
|
+
|
|
251
|
+
    def _resolve_subject_mask(self) -> frozenset[str]:
        """Resolve subject-level filters to a set of obs_subject_ids.

        This applies:
        1. Explicit subject_ids filter
        2. Query expression on obs_meta

        Returns:
            frozenset of obs_subject_ids passing both subject-level filters.
        """
        # Start from every subject in the source, then narrow by intersection.
        all_ids = set(self._source.obs_subject_ids)

        # Apply explicit subject_ids filter
        if self._subject_ids is not None:
            all_ids &= self._subject_ids

        # Apply query expression on obs_meta
        if self._subject_query is not None:
            # Push the filter down into the obs_meta read so only matching
            # rows are transferred.
            filtered = self._source.obs_meta.read(value_filter=self._subject_query)
            query_ids = set(filtered["obs_subject_id"])
            all_ids &= query_ids

        return frozenset(all_ids)
|
|
271
|
+
|
|
272
|
+
def _resolve_volume_mask(
|
|
273
|
+
self, collection_name: str, subject_mask: frozenset[str]
|
|
274
|
+
) -> frozenset[str]:
|
|
275
|
+
"""Resolve volume-level filters to a set of obs_ids for a collection.
|
|
276
|
+
|
|
277
|
+
Uses TileDB dimension slicing for efficient subject filtering,
|
|
278
|
+
combined with QueryCondition for attribute filters.
|
|
279
|
+
"""
|
|
280
|
+
vc = self._source.collection(collection_name)
|
|
281
|
+
effective_ctx = vc.obs._effective_ctx()
|
|
282
|
+
|
|
283
|
+
# Build query with dimension slicing for subject_mask
|
|
284
|
+
with tiledb.open(vc.obs.uri, "r", ctx=effective_ctx) as arr:
|
|
285
|
+
# Use multi-index for efficient dimension-based filtering
|
|
286
|
+
subject_list = list(subject_mask) if subject_mask else None
|
|
287
|
+
|
|
288
|
+
if subject_list:
|
|
289
|
+
# Query only rows matching subject_mask using dimension slicing
|
|
290
|
+
query = arr.query(dims=["obs_subject_id", "obs_id"])
|
|
291
|
+
if collection_name in self._collection_filters:
|
|
292
|
+
query = query.cond(self._collection_filters[collection_name])
|
|
293
|
+
result = query.multi_index[subject_list, :]
|
|
294
|
+
else:
|
|
295
|
+
# No subject filter - apply attribute filter only
|
|
296
|
+
if collection_name in self._collection_filters:
|
|
297
|
+
result = arr.query(cond=self._collection_filters[collection_name])[:]
|
|
298
|
+
else:
|
|
299
|
+
result = arr.query(attrs=[])[:]["obs_id"]
|
|
300
|
+
return frozenset(v.decode() if isinstance(v, bytes) else str(v) for v in result)
|
|
301
|
+
|
|
302
|
+
obs_ids = result["obs_id"]
|
|
303
|
+
return frozenset(v.decode() if isinstance(v, bytes) else str(v) for v in obs_ids)
|
|
304
|
+
|
|
305
|
+
    def _resolve_final_subject_mask(self) -> frozenset[str]:
        """Resolve to final subject mask after applying all filters.

        Uses TileDB dimension queries to efficiently find subjects with matching volumes.

        NOTE(review): when multiple collections are filtered, a subject
        survives if it has a matching volume in ANY filtered collection
        (the per-collection hits are unioned) — confirm that AND semantics
        across collections are not intended.
        """
        subject_mask = self._resolve_subject_mask()

        # If there are collection filters, further filter subjects
        # to only those with matching volumes
        if self._collection_filters:
            subjects_with_matching_volumes = set()

            for coll_name in self._collection_filters:
                # Filters naming unknown collections are silently ignored.
                if coll_name not in self._source.collection_names:
                    continue

                vc = self._source.collection(coll_name)
                effective_ctx = vc.obs._effective_ctx()
                expr = self._collection_filters[coll_name]

                # Query with attribute filter, only request obs_subject_id dimension
                with tiledb.open(vc.obs.uri, "r", ctx=effective_ctx) as arr:
                    result = arr.query(cond=expr, dims=["obs_subject_id"], attrs=[])[:]
                    subject_ids = result["obs_subject_id"]
                    for sid in subject_ids:
                        # Normalize bytes/str before membership testing.
                        s = sid.decode() if isinstance(sid, bytes) else str(sid)
                        if s in subject_mask:
                            subjects_with_matching_volumes.add(s)

            subject_mask = subject_mask & frozenset(subjects_with_matching_volumes)

        return subject_mask
|
|
337
|
+
|
|
338
|
+
def _resolve_output_collections(self) -> tuple[str, ...]:
|
|
339
|
+
"""Resolve which collections to include in output."""
|
|
340
|
+
if self._output_collections is not None:
|
|
341
|
+
return tuple(
|
|
342
|
+
name for name in self._source.collection_names if name in self._output_collections
|
|
343
|
+
)
|
|
344
|
+
return self._source.collection_names
|
|
345
|
+
|
|
346
|
+
# =========================================================================
|
|
347
|
+
# MATERIALIZATION (triggers data access)
|
|
348
|
+
# =========================================================================
|
|
349
|
+
|
|
350
|
+
def count(self) -> QueryCount:
|
|
351
|
+
"""Count subjects and volumes matching the query without loading volume data."""
|
|
352
|
+
subject_mask = self._resolve_final_subject_mask()
|
|
353
|
+
output_collections = self._resolve_output_collections()
|
|
354
|
+
|
|
355
|
+
volume_counts: dict[str, int] = {}
|
|
356
|
+
for name in output_collections:
|
|
357
|
+
volume_mask = self._resolve_volume_mask(name, subject_mask)
|
|
358
|
+
volume_counts[name] = len(volume_mask)
|
|
359
|
+
|
|
360
|
+
return QueryCount(n_subjects=len(subject_mask), n_volumes=volume_counts)
|
|
361
|
+
|
|
362
|
+
    def to_obs_meta(self) -> pd.DataFrame:
        """Return filtered obs_meta DataFrame.

        Reads the full obs_meta table, then keeps only rows whose
        obs_subject_id passes both subject- and collection-level filters.
        """
        subject_mask = self._resolve_final_subject_mask()
        obs_meta = self._source.obs_meta.read()
        return obs_meta[obs_meta["obs_subject_id"].isin(subject_mask)].reset_index(drop=True)
|
|
367
|
+
|
|
368
|
+
def iter_volumes(self, collection_name: str | None = None) -> Iterator[Volume]:
|
|
369
|
+
"""Iterate over volumes matching the query.
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
collection_name: Specific collection to iterate. If None, iterates
|
|
373
|
+
over all output collections.
|
|
374
|
+
"""
|
|
375
|
+
subject_mask = self._resolve_final_subject_mask()
|
|
376
|
+
output_collections = self._resolve_output_collections()
|
|
377
|
+
|
|
378
|
+
if collection_name is not None:
|
|
379
|
+
if collection_name not in output_collections:
|
|
380
|
+
raise ValueError(f"Collection '{collection_name}' not in query scope")
|
|
381
|
+
output_collections = (collection_name,)
|
|
382
|
+
|
|
383
|
+
for coll_name in output_collections:
|
|
384
|
+
vc = self._source.collection(coll_name)
|
|
385
|
+
volume_mask = self._resolve_volume_mask(coll_name, subject_mask)
|
|
386
|
+
|
|
387
|
+
for obs_id in volume_mask:
|
|
388
|
+
yield vc.loc[obs_id]
|
|
389
|
+
|
|
390
|
+
    def iter_batches(self, batch_size: int = 32) -> Iterator[VolumeBatch]:
        """Iterate over batched volumes for ML training.

        Yields VolumeBatch with stacked numpy arrays for each collection.
        Batches are grouped by subject.

        Args:
            batch_size: Number of subjects per batch (the last batch may be
                smaller). Note volumes are batched per subject group, so a
                collection with multiple volumes per subject can yield more
                than `batch_size` volumes in a batch.

        Yields:
            VolumeBatch per group of subjects; collections with no matching
            volumes in a group are omitted from that batch's dicts.
        """
        subject_mask = self._resolve_final_subject_mask()
        # Sorted for deterministic batch composition across runs.
        subject_list = sorted(subject_mask)
        output_collections = self._resolve_output_collections()

        for i in range(0, len(subject_list), batch_size):
            batch_subjects = subject_list[i : i + batch_size]
            batch_subject_set = frozenset(batch_subjects)

            volumes: dict[str, npt.NDArray[np.floating]] = {}
            obs_ids: dict[str, tuple[str, ...]] = {}

            for coll_name in output_collections:
                # Resolve only this batch's subjects to keep reads small.
                volume_mask = self._resolve_volume_mask(coll_name, batch_subject_set)
                vc = self._source.collection(coll_name)

                batch_arrays = []
                batch_obs_ids = []
                # Sorted obs_id order keeps array rows aligned with obs_ids.
                for obs_id in sorted(volume_mask):
                    vol = vc.loc[obs_id]
                    batch_arrays.append(vol.to_numpy())
                    batch_obs_ids.append(obs_id)

                if batch_arrays:
                    # NOTE(review): np.stack requires equal shapes — assumes
                    # volumes within a collection are homogeneous; confirm for
                    # heterogeneous collections.
                    volumes[coll_name] = np.stack(batch_arrays, axis=0)
                    obs_ids[coll_name] = tuple(batch_obs_ids)

            yield VolumeBatch(
                volumes=volumes,
                subject_ids=tuple(batch_subjects),
                obs_ids=obs_ids,
            )
|
|
427
|
+
|
|
428
|
+
    def materialize(
        self,
        uri: str,
        streaming: bool = True,
        ctx: tiledb.Ctx | None = None,
    ) -> RadiObject:
        """Materialize query results as a new RadiObject.

        If a transform was set via map(), it is applied to each volume
        during materialization. If the transform changes volume shapes,
        the output collection becomes heterogeneous.

        Args:
            uri: Target URI for the new RadiObject
            streaming: Use streaming writer for memory efficiency (default: True)
            ctx: TileDB context
        """
        # Deferred imports avoid a circular dependency with radi_object/streaming.
        from radiobject.radi_object import RadiObject
        from radiobject.streaming import RadiObjectWriter

        subject_mask = self._resolve_final_subject_mask()
        output_collections = self._resolve_output_collections()

        # If no transform, create a RadiObject view and materialize it
        if self._transform_fn is None:
            view = RadiObject(
                uri=None,
                ctx=ctx,
                _source=self._source,
                _subject_ids=subject_mask,
                _collection_names=frozenset(output_collections),
            )
            return view.materialize(uri, streaming=streaming, ctx=ctx)

        # With transform: use streaming writer and apply transform to each volume
        # NOTE(review): the `streaming` flag is not consulted on this path —
        # the streaming writer is always used; confirm that is intended.
        obs_meta_df = self._source.obs_meta.read()
        filtered_obs_meta = obs_meta_df[
            obs_meta_df["obs_subject_id"].isin(subject_mask)
        ].reset_index(drop=True)

        # Derive the target schema from the filtered frame; object columns
        # (typically strings) are stored as fixed-width U64.
        obs_meta_schema: dict[str, np.dtype] = {}
        for col in filtered_obs_meta.columns:
            if col in ("obs_subject_id", "obs_id"):
                continue
            dtype = filtered_obs_meta[col].to_numpy().dtype
            if dtype == np.dtype("O"):
                dtype = np.dtype("U64")
            obs_meta_schema[col] = dtype

        with RadiObjectWriter(uri, obs_meta_schema=obs_meta_schema, ctx=ctx) as writer:
            writer.write_obs_meta(filtered_obs_meta)

            for coll_name in output_collections:
                src_collection = self._source.collection(coll_name)
                volume_mask = self._resolve_volume_mask(coll_name, subject_mask)

                # Skip collections with no matching volumes entirely.
                if not volume_mask:
                    continue

                obs_df = src_collection.obs.read()
                filtered_obs = obs_df[obs_df["obs_id"].isin(volume_mask)]

                # Copy the source obs schema minus the two ID dimensions.
                obs_schema: dict[str, np.dtype] = {}
                for col in src_collection.obs.columns:
                    if col in ("obs_id", "obs_subject_id"):
                        continue
                    obs_schema[col] = src_collection.obs.dtypes[col]

                # Transform may change shape, so output is heterogeneous (shape=None)
                with writer.add_collection(
                    coll_name, shape=None, obs_schema=obs_schema
                ) as coll_writer:
                    for _, row in filtered_obs.iterrows():
                        obs_id = row["obs_id"]
                        vol = src_collection.loc[obs_id]
                        data = vol.to_numpy()

                        # Apply the composed map() pipeline per volume.
                        data = self._transform_fn(data)

                        attrs = {
                            k: v for k, v in row.items() if k not in ("obs_id", "obs_subject_id")
                        }
                        coll_writer.write_volume(
                            data=data,
                            obs_id=obs_id,
                            obs_subject_id=row["obs_subject_id"],
                            **attrs,
                        )

        # Re-open the freshly written object.
        return RadiObject(uri, ctx=ctx)
|
|
518
|
+
|
|
519
|
+
def __len__(self) -> int:
|
|
520
|
+
"""Number of subjects matching the query."""
|
|
521
|
+
return len(self._resolve_final_subject_mask())
|
|
522
|
+
|
|
523
|
+
def __repr__(self) -> str:
|
|
524
|
+
"""Concise representation of the Query."""
|
|
525
|
+
count = self.count()
|
|
526
|
+
collections = ", ".join(count.n_volumes.keys())
|
|
527
|
+
return (
|
|
528
|
+
f"Query({count.n_subjects} subjects, "
|
|
529
|
+
f"{sum(count.n_volumes.values())} volumes across [{collections}])"
|
|
530
|
+
)
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
class CollectionQuery:
|
|
534
|
+
"""Lazy filter builder for VolumeCollection with explicit materialization.
|
|
535
|
+
|
|
536
|
+
Example:
|
|
537
|
+
high_res = (
|
|
538
|
+
radi.T1w.lazy()
|
|
539
|
+
.filter("voxel_spacing == '1.0x1.0x1.0'")
|
|
540
|
+
.head(100)
|
|
541
|
+
.map(normalize)
|
|
542
|
+
.materialize("./output")
|
|
543
|
+
)
|
|
544
|
+
"""
|
|
545
|
+
|
|
546
|
+
    def __init__(
        self,
        source: VolumeCollection,
        *,
        volume_ids: frozenset[str] | None = None,
        volume_query: str | None = None,
        subject_ids: frozenset[str] | None = None,
        transform_fn: TransformFn | None = None,
    ):
        """Initialize a lazy query over `source`.

        Args:
            source: VolumeCollection this query reads from; never mutated.
            volume_ids: Explicit allow-list of obs_ids (None = unrestricted).
            volume_query: TileDB QueryCondition expression applied to obs.
            subject_ids: Restrict to volumes of these obs_subject_ids.
            transform_fn: Optional volume transform applied at materialization time.
        """
        self._source = source
        self._volume_ids = volume_ids
        self._volume_query = volume_query
        self._subject_ids = subject_ids
        self._transform_fn = transform_fn
|
|
560
|
+
|
|
561
|
+
def _copy(self, **kwargs) -> CollectionQuery:
|
|
562
|
+
"""Create a copy with modified fields."""
|
|
563
|
+
return CollectionQuery(
|
|
564
|
+
self._source,
|
|
565
|
+
volume_ids=kwargs.get("volume_ids", self._volume_ids),
|
|
566
|
+
volume_query=kwargs.get("volume_query", self._volume_query),
|
|
567
|
+
subject_ids=kwargs.get("subject_ids", self._subject_ids),
|
|
568
|
+
transform_fn=kwargs.get("transform_fn", self._transform_fn),
|
|
569
|
+
)
|
|
570
|
+
|
|
571
|
+
# =========================================================================
|
|
572
|
+
# FILTER METHODS
|
|
573
|
+
# =========================================================================
|
|
574
|
+
|
|
575
|
+
def filter(self, expr: str) -> CollectionQuery:
|
|
576
|
+
"""Filter volumes using TileDB QueryCondition on obs."""
|
|
577
|
+
new_query = self._volume_query
|
|
578
|
+
if new_query is None:
|
|
579
|
+
new_query = expr
|
|
580
|
+
else:
|
|
581
|
+
new_query = f"({new_query}) and ({expr})"
|
|
582
|
+
return self._copy(volume_query=new_query)
|
|
583
|
+
|
|
584
|
+
def filter_subjects(self, ids: Sequence[str]) -> CollectionQuery:
|
|
585
|
+
"""Filter to volumes belonging to specific subject IDs."""
|
|
586
|
+
new_ids = frozenset(ids)
|
|
587
|
+
if self._subject_ids is not None:
|
|
588
|
+
new_ids = self._subject_ids & new_ids
|
|
589
|
+
return self._copy(subject_ids=new_ids)
|
|
590
|
+
|
|
591
|
+
def iloc(self, key: int | slice | list[int] | npt.NDArray[np.bool_]) -> CollectionQuery:
|
|
592
|
+
"""Filter volumes by integer position(s)."""
|
|
593
|
+
n = len(self._source)
|
|
594
|
+
if isinstance(key, int):
|
|
595
|
+
idx = key if key >= 0 else n + key
|
|
596
|
+
obs_id = self._source._index.get_key(idx)
|
|
597
|
+
return self._copy(volume_ids=frozenset([obs_id]))
|
|
598
|
+
elif isinstance(key, slice):
|
|
599
|
+
indices = list(range(*key.indices(n)))
|
|
600
|
+
elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
|
|
601
|
+
indices = list(np.where(key)[0])
|
|
602
|
+
elif isinstance(key, list):
|
|
603
|
+
indices = [i if i >= 0 else n + i for i in key]
|
|
604
|
+
else:
|
|
605
|
+
raise TypeError("iloc key must be int, slice, list[int], or bool array")
|
|
606
|
+
|
|
607
|
+
obs_ids = frozenset(self._source._index.get_key(i) for i in indices)
|
|
608
|
+
new_ids = obs_ids
|
|
609
|
+
if self._volume_ids is not None:
|
|
610
|
+
new_ids = self._volume_ids & obs_ids
|
|
611
|
+
return self._copy(volume_ids=new_ids)
|
|
612
|
+
|
|
613
|
+
def loc(self, key: str | Sequence[str]) -> CollectionQuery:
|
|
614
|
+
"""Filter volumes by obs_id(s)."""
|
|
615
|
+
if isinstance(key, str):
|
|
616
|
+
ids = frozenset([key])
|
|
617
|
+
else:
|
|
618
|
+
ids = frozenset(key)
|
|
619
|
+
new_ids = ids
|
|
620
|
+
if self._volume_ids is not None:
|
|
621
|
+
new_ids = self._volume_ids & ids
|
|
622
|
+
return self._copy(volume_ids=new_ids)
|
|
623
|
+
|
|
624
|
+
def head(self, n: int = 5) -> CollectionQuery:
|
|
625
|
+
"""Filter to first n volumes."""
|
|
626
|
+
return self.iloc(slice(0, n))
|
|
627
|
+
|
|
628
|
+
def tail(self, n: int = 5) -> CollectionQuery:
|
|
629
|
+
"""Filter to last n volumes."""
|
|
630
|
+
total = len(self._source)
|
|
631
|
+
return self.iloc(slice(max(0, total - n), total))
|
|
632
|
+
|
|
633
|
+
    def sample(self, n: int = 5, seed: int | None = None) -> CollectionQuery:
        """Filter to n randomly sampled volumes.

        Note: resolves the volume mask eagerly (data access) so the sample is
        drawn from the currently-matching volumes.

        Args:
            n: Number of volumes to keep; clamped to the number available.
            seed: RNG seed for reproducible sampling (None = nondeterministic).
        """
        rng = np.random.default_rng(seed)
        # NOTE(review): list(frozenset) order is hash-based, so with string
        # hash randomization a fixed seed is reproducible within a process
        # but not necessarily across runs — confirm if that matters here.
        resolved = list(self._resolve_volume_mask())
        n = min(n, len(resolved))
        sampled = rng.choice(resolved, size=n, replace=False)
        return self._copy(volume_ids=frozenset(sampled))
|
|
640
|
+
|
|
641
|
+
# =========================================================================
|
|
642
|
+
# TRANSFORM (applied during materialization)
|
|
643
|
+
# =========================================================================
|
|
644
|
+
|
|
645
|
+
def map(self, fn: TransformFn) -> CollectionQuery:
|
|
646
|
+
"""Apply transform to each volume during materialization.
|
|
647
|
+
|
|
648
|
+
Multiple map() calls compose: query.map(f1).map(f2) applies f1 then f2.
|
|
649
|
+
|
|
650
|
+
Args:
|
|
651
|
+
fn: Function (X, Y, Z) -> (X', Y', Z'). Can change shape.
|
|
652
|
+
"""
|
|
653
|
+
if self._transform_fn is not None:
|
|
654
|
+
# Compose transforms: apply previous transform, then new one
|
|
655
|
+
prev_fn = self._transform_fn
|
|
656
|
+
|
|
657
|
+
def composed_fn(v: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
|
|
658
|
+
return fn(prev_fn(v))
|
|
659
|
+
|
|
660
|
+
return self._copy(transform_fn=composed_fn)
|
|
661
|
+
return self._copy(transform_fn=fn)
|
|
662
|
+
|
|
663
|
+
# =========================================================================
|
|
664
|
+
# MASK RESOLUTION
|
|
665
|
+
# =========================================================================
|
|
666
|
+
|
|
667
|
+
def _resolve_volume_mask(self) -> frozenset[str]:
|
|
668
|
+
"""Resolve all filters to a set of obs_ids using TileDB-native queries."""
|
|
669
|
+
effective_ctx = self._source.obs._effective_ctx()
|
|
670
|
+
|
|
671
|
+
with tiledb.open(self._source.obs.uri, "r", ctx=effective_ctx) as arr:
|
|
672
|
+
# Build query based on filters
|
|
673
|
+
if self._subject_ids is not None and self._volume_query is not None:
|
|
674
|
+
# Both subject and attribute filters - use dimension slicing + QueryCondition
|
|
675
|
+
query = arr.query(cond=self._volume_query, dims=["obs_id"])
|
|
676
|
+
result = query.multi_index[list(self._subject_ids), :]
|
|
677
|
+
elif self._subject_ids is not None:
|
|
678
|
+
# Subject filter only - use dimension slicing
|
|
679
|
+
result = arr.query(dims=["obs_id"]).multi_index[list(self._subject_ids), :]
|
|
680
|
+
elif self._volume_query is not None:
|
|
681
|
+
# Attribute filter only - use QueryCondition
|
|
682
|
+
result = arr.query(cond=self._volume_query, dims=["obs_id"])[:]
|
|
683
|
+
else:
|
|
684
|
+
# No filters - return all obs_ids
|
|
685
|
+
result = arr.query(dims=["obs_id"], attrs=[])[:]
|
|
686
|
+
|
|
687
|
+
obs_ids = result["obs_id"]
|
|
688
|
+
all_ids = frozenset(v.decode() if isinstance(v, bytes) else str(v) for v in obs_ids)
|
|
689
|
+
|
|
690
|
+
# Apply explicit volume_ids filter (intersection with pre-specified IDs)
|
|
691
|
+
if self._volume_ids is not None:
|
|
692
|
+
all_ids &= self._volume_ids
|
|
693
|
+
|
|
694
|
+
return all_ids
|
|
695
|
+
|
|
696
|
+
# =========================================================================
|
|
697
|
+
# MATERIALIZATION
|
|
698
|
+
# =========================================================================
|
|
699
|
+
|
|
700
|
+
def count(self) -> int:
    """Return the number of volumes that satisfy every active filter."""
    matched = self._resolve_volume_mask()
    return len(matched)
|
|
703
|
+
|
|
704
|
+
def to_obs(self) -> pd.DataFrame:
    """Return the obs table restricted to rows whose obs_id matches the query."""
    matched = self._resolve_volume_mask()
    full = self._source.obs.read()
    keep = full["obs_id"].isin(matched)
    return full[keep].reset_index(drop=True)
|
|
709
|
+
|
|
710
|
+
def iter_volumes(self) -> Iterator[Volume]:
    """Yield each matching volume, in ascending obs_id order."""
    for obs_id in sorted(self._resolve_volume_mask()):
        yield self._source.loc[obs_id]
|
|
715
|
+
|
|
716
|
+
def to_numpy_stack(self) -> npt.NDArray[np.floating]:
    """Stack every matching volume into a single (N, X, Y, Z) array."""
    matched = self._resolve_volume_mask()
    if not matched:
        raise ValueError("No volumes match the query")

    # Load in deterministic (sorted) order so the stack axis is reproducible.
    loaded = []
    for obs_id in sorted(matched):
        loaded.append(self._source.loc[obs_id].to_numpy())
    return np.stack(loaded, axis=0)
|
|
724
|
+
|
|
725
|
+
def materialize(
    self,
    uri: str,
    name: str | None = None,
    ctx: tiledb.Ctx | None = None,
) -> VolumeCollection:
    """Materialize query results as a new VolumeCollection.

    If a transform was set via map(), it is applied to each volume
    during materialization. If the transform changes volume shapes,
    the output collection becomes heterogeneous (shape=None).

    Args:
        uri: Destination URI for the new collection.
        name: Name for the new collection; defaults to the source's name.
        ctx: Optional TileDB context, used both for writing and for
            re-opening the result.

    Returns:
        A new collection (same class as the source) opened from ``uri``.

    Raises:
        ValueError: If no volumes match the query.
    """
    # Deferred import — presumably to avoid a circular import with
    # radiobject.streaming; confirm before moving to module level.
    from radiobject.streaming import StreamingWriter

    volume_mask = self._resolve_volume_mask()
    if not volume_mask:
        raise ValueError("No volumes match the query")

    obs_df = self.to_obs()
    collection_name = name or self._source.name

    # Copy the source obs schema, minus the two ID columns that
    # write_volume() receives as explicit arguments below.
    obs_schema: dict[str, np.dtype] = {}
    for col in self._source.obs.columns:
        if col in ("obs_id", "obs_subject_id"):
            continue
        obs_schema[col] = self._source.obs.dtypes[col]

    # A transform may change per-volume shape, so the output is declared
    # heterogeneous (shape=None) whenever one is set.
    output_shape = None if self._transform_fn is not None else self._source.shape

    with StreamingWriter(
        uri=uri,
        shape=output_shape,
        obs_schema=obs_schema,
        name=collection_name,
        ctx=ctx,
    ) as writer:
        # Sorted iteration gives a deterministic write order.
        for obs_id in sorted(volume_mask):
            vol = self._source.loc[obs_id]
            data = vol.to_numpy()

            if self._transform_fn is not None:
                data = self._transform_fn(data)

            # Forward every non-ID column of this volume's metadata row
            # as a per-volume attribute.
            obs_row = obs_df[obs_df["obs_id"] == obs_id].iloc[0]
            attrs = {k: v for k, v in obs_row.items() if k not in ("obs_id", "obs_subject_id")}
            writer.write_volume(
                data=data,
                obs_id=obs_id,
                obs_subject_id=obs_row["obs_subject_id"],
                **attrs,
            )

    # Re-open the freshly written collection with the caller's context.
    return self._source.__class__(uri, ctx=ctx)
|
|
778
|
+
|
|
779
|
+
def __len__(self) -> int:
    """Size of the query result set (count of matching volumes)."""
    matching = self._resolve_volume_mask()
    return len(matching)
|
|
782
|
+
|
|
783
|
+
def __repr__(self) -> str:
    """Concise representation of the CollectionQuery.

    Tolerates heterogeneous sources whose ``shape`` is None (e.g. a
    collection produced by materialize() with a shape-changing
    transform) instead of raising TypeError on iteration.
    """
    n = len(self)
    name = self._source.name or "unnamed"
    dims = self._source.shape
    # shape is None for heterogeneous collections; joining None would crash.
    shape = "x".join(str(d) for d in dims) if dims is not None else "None"
    return f"CollectionQuery('{name}', {n} volumes, shape={shape})"
|