radiobject 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
radiobject/query.py ADDED
@@ -0,0 +1,788 @@
1
+ """Lazy query builder pattern for RadiObject and VolumeCollection filtering.
2
+
3
+ Query Design:
4
+ =============
5
+ Queries work by building up filter conditions and resolving them to masks (sets of IDs).
6
+ The flow is:
7
+
8
+ 1. Start with RadiObject.lazy() or VolumeCollection.lazy()
9
+ 2. Add filter conditions (obs_meta filters, collection filters)
10
+ 3. Add transforms via .map() for compute-intensive operations
11
+ 4. Resolve to masks:
12
+ - Subject mask: set of obs_subject_ids matching criteria
13
+ - Volume mask: set of obs_ids within each collection matching criteria
14
+ 5. Apply masks via materialization (iter_volumes, materialize, etc.)
15
+
16
+ Key Concepts:
17
+ - obs_meta: Subject-level metadata (indexed by obs_subject_id)
18
+ - obs: Volume-level metadata per collection (indexed by obs_id, contains obs_subject_id FK)
19
+ - A subject matches if it passes obs_meta filter AND has at least one volume passing collection filters
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ from dataclasses import dataclass
25
+ from typing import TYPE_CHECKING, Iterator, Sequence
26
+
27
+ import numpy as np
28
+ import numpy.typing as npt
29
+ import pandas as pd
30
+ import tiledb
31
+
32
+ from radiobject._types import TransformFn
33
+ from radiobject.volume import Volume
34
+
35
+ if TYPE_CHECKING:
36
+ from radiobject.radi_object import RadiObject
37
+ from radiobject.volume_collection import VolumeCollection
38
+
39
+
40
@dataclass(frozen=True)
class VolumeBatch:
    """Batch of volumes for ML training with stacked numpy arrays.

    Produced by Query.iter_batches(); one instance per chunk of subjects.
    """

    # collection_name -> stacked array of shape (N, X, Y, Z)
    volumes: dict[str, npt.NDArray[np.floating]]
    # Subject IDs in this batch, in batch order.
    subject_ids: tuple[str, ...]
    # collection_name -> obs_ids backing each row of the stacked array.
    obs_ids: dict[str, tuple[str, ...]]
47
+
48
+
49
@dataclass(frozen=True)
class QueryCount:
    """Count results for a query, as returned by Query.count()."""

    # Number of subjects matching all filters.
    n_subjects: int
    # collection_name -> number of matching volumes in that collection.
    n_volumes: dict[str, int]
55
+
56
+
57
+ class Query:
58
+ """Lazy filter builder for RadiObject with explicit materialization.
59
+
60
+ Filters accumulate without data access. Call iter_volumes(), materialize(),
61
+ or count() to materialize results.
62
+
63
+ Example:
64
+ result = (
65
+ radi.lazy()
66
+ .filter("age > 40")
67
+ .filter_collection("T1w", "voxel_spacing == '1.0x1.0x1.0'")
68
+ .map(normalize_intensity)
69
+ .materialize("s3://bucket/subset")
70
+ )
71
+ """
72
+
73
    def __init__(
        self,
        source: RadiObject,
        *,
        # Subject-level filters (on obs_meta)
        subject_ids: frozenset[str] | None = None,
        subject_query: str | None = None,
        # Collection-level filters (on each collection's obs)
        collection_filters: dict[str, str] | None = None,  # collection_name -> query expr
        # Output scope
        output_collections: frozenset[str] | None = None,
        # Transform function applied during materialization
        transform_fn: TransformFn | None = None,
    ):
        """Initialize a lazy query over *source*.

        Args:
            source: RadiObject the query reads from; never mutated.
            subject_ids: Explicit allow-list of obs_subject_ids, or None for all.
            subject_query: TileDB QueryCondition expression on obs_meta, or None.
            collection_filters: Per-collection QueryCondition expressions on obs.
            output_collections: Collections to include in output; None means all.
            transform_fn: Transform applied to each volume during materialization.
        """
        self._source = source
        self._subject_ids = subject_ids
        self._subject_query = subject_query
        # Normalize None to {} so lookups never need a None check.
        self._collection_filters = collection_filters or {}
        self._output_collections = output_collections
        self._transform_fn = transform_fn
93
+
94
+ def _copy(self, **kwargs) -> Query:
95
+ """Create a copy with modified fields."""
96
+ return Query(
97
+ self._source,
98
+ subject_ids=kwargs.get("subject_ids", self._subject_ids),
99
+ subject_query=kwargs.get("subject_query", self._subject_query),
100
+ collection_filters=kwargs.get("collection_filters", self._collection_filters),
101
+ output_collections=kwargs.get("output_collections", self._output_collections),
102
+ transform_fn=kwargs.get("transform_fn", self._transform_fn),
103
+ )
104
+
105
+ # =========================================================================
106
+ # SUBJECT-LEVEL FILTERS (operate on obs_meta)
107
+ # =========================================================================
108
+
109
+ def filter(self, expr: str) -> Query:
110
+ """Filter subjects using TileDB QueryCondition on obs_meta.
111
+
112
+ Args:
113
+ expr: TileDB query expression (e.g., "age > 40 and diagnosis == 'tumor'")
114
+
115
+ Returns:
116
+ New Query with filter applied
117
+ """
118
+ new_query = self._subject_query
119
+ if new_query is None:
120
+ new_query = expr
121
+ else:
122
+ new_query = f"({new_query}) and ({expr})"
123
+ return self._copy(subject_query=new_query)
124
+
125
+ def filter_subjects(self, ids: Sequence[str]) -> Query:
126
+ """Filter to specific subject IDs.
127
+
128
+ Args:
129
+ ids: List of obs_subject_ids to include
130
+
131
+ Returns:
132
+ New Query with filter applied
133
+ """
134
+ new_ids = frozenset(ids)
135
+ if self._subject_ids is not None:
136
+ new_ids = self._subject_ids & new_ids
137
+ return self._copy(subject_ids=new_ids)
138
+
139
+ def iloc(self, key: int | slice | list[int] | npt.NDArray[np.bool_]) -> Query:
140
+ """Filter subjects by integer position(s)."""
141
+ n = len(self._source)
142
+ if isinstance(key, int):
143
+ idx = key if key >= 0 else n + key
144
+ indices = [idx]
145
+ elif isinstance(key, slice):
146
+ indices = list(range(*key.indices(n)))
147
+ elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
148
+ indices = list(np.where(key)[0])
149
+ elif isinstance(key, list):
150
+ indices = [i if i >= 0 else n + i for i in key]
151
+ else:
152
+ raise TypeError("iloc key must be int, slice, list[int], or bool array")
153
+
154
+ ids = [self._source._index.get_key(i) for i in indices]
155
+ return self.filter_subjects(ids)
156
+
157
+ def loc(self, key: str | Sequence[str]) -> Query:
158
+ """Filter subjects by obs_subject_id(s)."""
159
+ if isinstance(key, str):
160
+ return self.filter_subjects([key])
161
+ return self.filter_subjects(key)
162
+
163
+ def head(self, n: int = 5) -> Query:
164
+ """Filter to first n subjects."""
165
+ return self.iloc(slice(0, n))
166
+
167
+ def tail(self, n: int = 5) -> Query:
168
+ """Filter to last n subjects."""
169
+ total = len(self._source)
170
+ return self.iloc(slice(max(0, total - n), total))
171
+
172
+ def sample(self, n: int = 5, seed: int | None = None) -> Query:
173
+ """Filter to n randomly sampled subjects."""
174
+ rng = np.random.default_rng(seed)
175
+ resolved = self._resolve_subject_mask()
176
+ subject_list = list(resolved)
177
+ n = min(n, len(subject_list))
178
+ sampled = rng.choice(subject_list, size=n, replace=False)
179
+ return self.filter_subjects(sampled)
180
+
181
+ # =========================================================================
182
+ # COLLECTION-LEVEL FILTERS (operate on collection obs)
183
+ # =========================================================================
184
+
185
+ def filter_collection(self, collection_name: str, expr: str) -> Query:
186
+ """Filter volumes in a specific collection using TileDB QueryCondition on obs.
187
+
188
+ Only subjects that have at least one volume matching the filter will be included.
189
+
190
+ Args:
191
+ collection_name: Name of the collection to filter
192
+ expr: TileDB query expression on the collection's obs dataframe
193
+
194
+ Returns:
195
+ New Query with filter applied
196
+
197
+ Example:
198
+ # Only include subjects whose T1w scans have 1mm resolution
199
+ query.filter_collection("T1w", "voxel_spacing == '1.0x1.0x1.0'")
200
+ """
201
+ new_filters = dict(self._collection_filters)
202
+ if collection_name in new_filters:
203
+ new_filters[collection_name] = f"({new_filters[collection_name]}) and ({expr})"
204
+ else:
205
+ new_filters[collection_name] = expr
206
+ return self._copy(collection_filters=new_filters)
207
+
208
+ def select_collections(self, names: Sequence[str]) -> Query:
209
+ """Limit output to specific collections.
210
+
211
+ This doesn't affect filtering - it only limits which collections
212
+ appear in the output.
213
+
214
+ Args:
215
+ names: Collection names to include in output
216
+
217
+ Returns:
218
+ New Query with output scope set
219
+ """
220
+ new_collections = frozenset(names)
221
+ if self._output_collections is not None:
222
+ new_collections = self._output_collections & new_collections
223
+ return self._copy(output_collections=new_collections)
224
+
225
+ # =========================================================================
226
+ # TRANSFORM (applied during materialization)
227
+ # =========================================================================
228
+
229
+ def map(self, fn: TransformFn) -> Query:
230
+ """Apply transform to each volume during materialization.
231
+
232
+ Multiple map() calls compose: query.map(f1).map(f2) applies f1 then f2.
233
+
234
+ Args:
235
+ fn: Function (X, Y, Z) -> (X', Y', Z'). Can change shape.
236
+ """
237
+ if self._transform_fn is not None:
238
+ # Compose transforms: apply previous transform, then new one
239
+ prev_fn = self._transform_fn
240
+
241
+ def composed_fn(v: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
242
+ return fn(prev_fn(v))
243
+
244
+ return self._copy(transform_fn=composed_fn)
245
+ return self._copy(transform_fn=fn)
246
+
247
+ # =========================================================================
248
+ # MASK RESOLUTION
249
+ # =========================================================================
250
+
251
+ def _resolve_subject_mask(self) -> frozenset[str]:
252
+ """Resolve subject-level filters to a set of obs_subject_ids.
253
+
254
+ This applies:
255
+ 1. Explicit subject_ids filter
256
+ 2. Query expression on obs_meta
257
+ """
258
+ all_ids = set(self._source.obs_subject_ids)
259
+
260
+ # Apply explicit subject_ids filter
261
+ if self._subject_ids is not None:
262
+ all_ids &= self._subject_ids
263
+
264
+ # Apply query expression on obs_meta
265
+ if self._subject_query is not None:
266
+ filtered = self._source.obs_meta.read(value_filter=self._subject_query)
267
+ query_ids = set(filtered["obs_subject_id"])
268
+ all_ids &= query_ids
269
+
270
+ return frozenset(all_ids)
271
+
272
+ def _resolve_volume_mask(
273
+ self, collection_name: str, subject_mask: frozenset[str]
274
+ ) -> frozenset[str]:
275
+ """Resolve volume-level filters to a set of obs_ids for a collection.
276
+
277
+ Uses TileDB dimension slicing for efficient subject filtering,
278
+ combined with QueryCondition for attribute filters.
279
+ """
280
+ vc = self._source.collection(collection_name)
281
+ effective_ctx = vc.obs._effective_ctx()
282
+
283
+ # Build query with dimension slicing for subject_mask
284
+ with tiledb.open(vc.obs.uri, "r", ctx=effective_ctx) as arr:
285
+ # Use multi-index for efficient dimension-based filtering
286
+ subject_list = list(subject_mask) if subject_mask else None
287
+
288
+ if subject_list:
289
+ # Query only rows matching subject_mask using dimension slicing
290
+ query = arr.query(dims=["obs_subject_id", "obs_id"])
291
+ if collection_name in self._collection_filters:
292
+ query = query.cond(self._collection_filters[collection_name])
293
+ result = query.multi_index[subject_list, :]
294
+ else:
295
+ # No subject filter - apply attribute filter only
296
+ if collection_name in self._collection_filters:
297
+ result = arr.query(cond=self._collection_filters[collection_name])[:]
298
+ else:
299
+ result = arr.query(attrs=[])[:]["obs_id"]
300
+ return frozenset(v.decode() if isinstance(v, bytes) else str(v) for v in result)
301
+
302
+ obs_ids = result["obs_id"]
303
+ return frozenset(v.decode() if isinstance(v, bytes) else str(v) for v in obs_ids)
304
+
305
    def _resolve_final_subject_mask(self) -> frozenset[str]:
        """Resolve the final subject mask after applying all filters.

        Starts from the obs_meta-level mask, then — when collection filters
        exist — keeps only subjects owning at least one matching volume in
        at least ONE filtered collection (the loop below unions matches
        across collections before intersecting; it is OR, not AND, across
        filtered collections).
        """
        subject_mask = self._resolve_subject_mask()

        # If there are collection filters, further filter subjects
        # to only those with matching volumes
        if self._collection_filters:
            subjects_with_matching_volumes = set()

            for coll_name in self._collection_filters:
                # Filters naming collections absent from the source are
                # silently skipped.
                if coll_name not in self._source.collection_names:
                    continue

                vc = self._source.collection(coll_name)
                effective_ctx = vc.obs._effective_ctx()
                expr = self._collection_filters[coll_name]

                # Query with attribute filter, only request obs_subject_id dimension
                with tiledb.open(vc.obs.uri, "r", ctx=effective_ctx) as arr:
                    result = arr.query(cond=expr, dims=["obs_subject_id"], attrs=[])[:]
                    subject_ids = result["obs_subject_id"]
                    for sid in subject_ids:
                        # TileDB may return dimension values as bytes; normalize to str.
                        s = sid.decode() if isinstance(sid, bytes) else str(sid)
                        if s in subject_mask:
                            subjects_with_matching_volumes.add(s)

            subject_mask = subject_mask & frozenset(subjects_with_matching_volumes)

        return subject_mask
337
+
338
+ def _resolve_output_collections(self) -> tuple[str, ...]:
339
+ """Resolve which collections to include in output."""
340
+ if self._output_collections is not None:
341
+ return tuple(
342
+ name for name in self._source.collection_names if name in self._output_collections
343
+ )
344
+ return self._source.collection_names
345
+
346
+ # =========================================================================
347
+ # MATERIALIZATION (triggers data access)
348
+ # =========================================================================
349
+
350
+ def count(self) -> QueryCount:
351
+ """Count subjects and volumes matching the query without loading volume data."""
352
+ subject_mask = self._resolve_final_subject_mask()
353
+ output_collections = self._resolve_output_collections()
354
+
355
+ volume_counts: dict[str, int] = {}
356
+ for name in output_collections:
357
+ volume_mask = self._resolve_volume_mask(name, subject_mask)
358
+ volume_counts[name] = len(volume_mask)
359
+
360
+ return QueryCount(n_subjects=len(subject_mask), n_volumes=volume_counts)
361
+
362
+ def to_obs_meta(self) -> pd.DataFrame:
363
+ """Return filtered obs_meta DataFrame."""
364
+ subject_mask = self._resolve_final_subject_mask()
365
+ obs_meta = self._source.obs_meta.read()
366
+ return obs_meta[obs_meta["obs_subject_id"].isin(subject_mask)].reset_index(drop=True)
367
+
368
+ def iter_volumes(self, collection_name: str | None = None) -> Iterator[Volume]:
369
+ """Iterate over volumes matching the query.
370
+
371
+ Args:
372
+ collection_name: Specific collection to iterate. If None, iterates
373
+ over all output collections.
374
+ """
375
+ subject_mask = self._resolve_final_subject_mask()
376
+ output_collections = self._resolve_output_collections()
377
+
378
+ if collection_name is not None:
379
+ if collection_name not in output_collections:
380
+ raise ValueError(f"Collection '{collection_name}' not in query scope")
381
+ output_collections = (collection_name,)
382
+
383
+ for coll_name in output_collections:
384
+ vc = self._source.collection(coll_name)
385
+ volume_mask = self._resolve_volume_mask(coll_name, subject_mask)
386
+
387
+ for obs_id in volume_mask:
388
+ yield vc.loc[obs_id]
389
+
390
    def iter_batches(self, batch_size: int = 32) -> Iterator[VolumeBatch]:
        """Iterate over batched volumes for ML training.

        Subjects are sorted, then chunked into groups of ``batch_size``;
        for each chunk every output collection is loaded and stacked into
        an (N, X, Y, Z) array. N may differ per collection (and from the
        number of subjects in the chunk) when subjects have zero or
        multiple volumes in a collection.

        Args:
            batch_size: Subjects per batch; the last batch may be smaller.

        Yields:
            One VolumeBatch per chunk. Collections with no matching volumes
            in a chunk are omitted from its ``volumes``/``obs_ids`` dicts.
        """
        subject_mask = self._resolve_final_subject_mask()
        subject_list = sorted(subject_mask)
        output_collections = self._resolve_output_collections()

        for i in range(0, len(subject_list), batch_size):
            batch_subjects = subject_list[i : i + batch_size]
            batch_subject_set = frozenset(batch_subjects)

            volumes: dict[str, npt.NDArray[np.floating]] = {}
            obs_ids: dict[str, tuple[str, ...]] = {}

            for coll_name in output_collections:
                # Re-resolve the volume mask against just this chunk's subjects.
                volume_mask = self._resolve_volume_mask(coll_name, batch_subject_set)
                vc = self._source.collection(coll_name)

                batch_arrays = []
                batch_obs_ids = []
                # Sorted for a deterministic stacking order.
                for obs_id in sorted(volume_mask):
                    vol = vc.loc[obs_id]
                    batch_arrays.append(vol.to_numpy())
                    batch_obs_ids.append(obs_id)

                if batch_arrays:
                    # np.stack requires uniform shapes — assumes volumes within
                    # a collection share a shape; TODO confirm for
                    # heterogeneous collections.
                    volumes[coll_name] = np.stack(batch_arrays, axis=0)
                    obs_ids[coll_name] = tuple(batch_obs_ids)

            yield VolumeBatch(
                volumes=volumes,
                subject_ids=tuple(batch_subjects),
                obs_ids=obs_ids,
            )
427
+
428
    def materialize(
        self,
        uri: str,
        streaming: bool = True,
        ctx: tiledb.Ctx | None = None,
    ) -> RadiObject:
        """Materialize query results as a new RadiObject.

        If a transform was set via map(), it is applied to each volume
        during materialization. If the transform changes volume shapes,
        the output collection becomes heterogeneous.

        Args:
            uri: Target URI for the new RadiObject
            streaming: Use streaming writer for memory efficiency (default: True).
                NOTE(review): only honored on the no-transform path below;
                the transform path always streams via RadiObjectWriter.
            ctx: TileDB context
        """
        # Imported here to avoid a circular import at module load time.
        from radiobject.radi_object import RadiObject
        from radiobject.streaming import RadiObjectWriter

        subject_mask = self._resolve_final_subject_mask()
        output_collections = self._resolve_output_collections()

        # If no transform, create a RadiObject view and materialize it
        if self._transform_fn is None:
            view = RadiObject(
                uri=None,
                ctx=ctx,
                _source=self._source,
                _subject_ids=subject_mask,
                _collection_names=frozenset(output_collections),
            )
            return view.materialize(uri, streaming=streaming, ctx=ctx)

        # With transform: use streaming writer and apply transform to each volume
        obs_meta_df = self._source.obs_meta.read()
        filtered_obs_meta = obs_meta_df[
            obs_meta_df["obs_subject_id"].isin(subject_mask)
        ].reset_index(drop=True)

        # Derive an attribute schema from the filtered frame, skipping the
        # two index columns.
        obs_meta_schema: dict[str, np.dtype] = {}
        for col in filtered_obs_meta.columns:
            if col in ("obs_subject_id", "obs_id"):
                continue
            dtype = filtered_obs_meta[col].to_numpy().dtype
            if dtype == np.dtype("O"):
                # Object columns are coerced to fixed-width unicode.
                # NOTE(review): values longer than 64 chars would be
                # truncated — confirm this bound is safe for obs_meta.
                dtype = np.dtype("U64")
            obs_meta_schema[col] = dtype

        with RadiObjectWriter(uri, obs_meta_schema=obs_meta_schema, ctx=ctx) as writer:
            writer.write_obs_meta(filtered_obs_meta)

            for coll_name in output_collections:
                src_collection = self._source.collection(coll_name)
                volume_mask = self._resolve_volume_mask(coll_name, subject_mask)

                # Collections with no matching volumes are omitted entirely.
                if not volume_mask:
                    continue

                obs_df = src_collection.obs.read()
                filtered_obs = obs_df[obs_df["obs_id"].isin(volume_mask)]

                obs_schema: dict[str, np.dtype] = {}
                for col in src_collection.obs.columns:
                    if col in ("obs_id", "obs_subject_id"):
                        continue
                    obs_schema[col] = src_collection.obs.dtypes[col]

                # Transform may change shape, so output is heterogeneous (shape=None)
                with writer.add_collection(
                    coll_name, shape=None, obs_schema=obs_schema
                ) as coll_writer:
                    for _, row in filtered_obs.iterrows():
                        obs_id = row["obs_id"]
                        vol = src_collection.loc[obs_id]
                        data = vol.to_numpy()

                        # Apply the composed map() transform to each volume.
                        data = self._transform_fn(data)

                        attrs = {
                            k: v for k, v in row.items() if k not in ("obs_id", "obs_subject_id")
                        }
                        coll_writer.write_volume(
                            data=data,
                            obs_id=obs_id,
                            obs_subject_id=row["obs_subject_id"],
                            **attrs,
                        )

        # Re-open the freshly written object from disk.
        return RadiObject(uri, ctx=ctx)
518
+
519
+ def __len__(self) -> int:
520
+ """Number of subjects matching the query."""
521
+ return len(self._resolve_final_subject_mask())
522
+
523
+ def __repr__(self) -> str:
524
+ """Concise representation of the Query."""
525
+ count = self.count()
526
+ collections = ", ".join(count.n_volumes.keys())
527
+ return (
528
+ f"Query({count.n_subjects} subjects, "
529
+ f"{sum(count.n_volumes.values())} volumes across [{collections}])"
530
+ )
531
+
532
+
533
+ class CollectionQuery:
534
+ """Lazy filter builder for VolumeCollection with explicit materialization.
535
+
536
+ Example:
537
+ high_res = (
538
+ radi.T1w.lazy()
539
+ .filter("voxel_spacing == '1.0x1.0x1.0'")
540
+ .head(100)
541
+ .map(normalize)
542
+ .materialize("./output")
543
+ )
544
+ """
545
+
546
    def __init__(
        self,
        source: VolumeCollection,
        *,
        volume_ids: frozenset[str] | None = None,
        volume_query: str | None = None,
        subject_ids: frozenset[str] | None = None,
        transform_fn: TransformFn | None = None,
    ):
        """Initialize a lazy query over *source*.

        Args:
            source: VolumeCollection the query reads from; never mutated.
            volume_ids: Explicit allow-list of obs_ids, or None for all.
            volume_query: TileDB QueryCondition expression on obs, or None.
            subject_ids: Restrict to volumes owned by these subjects, or None.
            transform_fn: Transform applied to each volume during materialization.
        """
        self._source = source
        self._volume_ids = volume_ids
        self._volume_query = volume_query
        self._subject_ids = subject_ids
        self._transform_fn = transform_fn
560
+
561
+ def _copy(self, **kwargs) -> CollectionQuery:
562
+ """Create a copy with modified fields."""
563
+ return CollectionQuery(
564
+ self._source,
565
+ volume_ids=kwargs.get("volume_ids", self._volume_ids),
566
+ volume_query=kwargs.get("volume_query", self._volume_query),
567
+ subject_ids=kwargs.get("subject_ids", self._subject_ids),
568
+ transform_fn=kwargs.get("transform_fn", self._transform_fn),
569
+ )
570
+
571
+ # =========================================================================
572
+ # FILTER METHODS
573
+ # =========================================================================
574
+
575
+ def filter(self, expr: str) -> CollectionQuery:
576
+ """Filter volumes using TileDB QueryCondition on obs."""
577
+ new_query = self._volume_query
578
+ if new_query is None:
579
+ new_query = expr
580
+ else:
581
+ new_query = f"({new_query}) and ({expr})"
582
+ return self._copy(volume_query=new_query)
583
+
584
+ def filter_subjects(self, ids: Sequence[str]) -> CollectionQuery:
585
+ """Filter to volumes belonging to specific subject IDs."""
586
+ new_ids = frozenset(ids)
587
+ if self._subject_ids is not None:
588
+ new_ids = self._subject_ids & new_ids
589
+ return self._copy(subject_ids=new_ids)
590
+
591
+ def iloc(self, key: int | slice | list[int] | npt.NDArray[np.bool_]) -> CollectionQuery:
592
+ """Filter volumes by integer position(s)."""
593
+ n = len(self._source)
594
+ if isinstance(key, int):
595
+ idx = key if key >= 0 else n + key
596
+ obs_id = self._source._index.get_key(idx)
597
+ return self._copy(volume_ids=frozenset([obs_id]))
598
+ elif isinstance(key, slice):
599
+ indices = list(range(*key.indices(n)))
600
+ elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
601
+ indices = list(np.where(key)[0])
602
+ elif isinstance(key, list):
603
+ indices = [i if i >= 0 else n + i for i in key]
604
+ else:
605
+ raise TypeError("iloc key must be int, slice, list[int], or bool array")
606
+
607
+ obs_ids = frozenset(self._source._index.get_key(i) for i in indices)
608
+ new_ids = obs_ids
609
+ if self._volume_ids is not None:
610
+ new_ids = self._volume_ids & obs_ids
611
+ return self._copy(volume_ids=new_ids)
612
+
613
+ def loc(self, key: str | Sequence[str]) -> CollectionQuery:
614
+ """Filter volumes by obs_id(s)."""
615
+ if isinstance(key, str):
616
+ ids = frozenset([key])
617
+ else:
618
+ ids = frozenset(key)
619
+ new_ids = ids
620
+ if self._volume_ids is not None:
621
+ new_ids = self._volume_ids & ids
622
+ return self._copy(volume_ids=new_ids)
623
+
624
+ def head(self, n: int = 5) -> CollectionQuery:
625
+ """Filter to first n volumes."""
626
+ return self.iloc(slice(0, n))
627
+
628
+ def tail(self, n: int = 5) -> CollectionQuery:
629
+ """Filter to last n volumes."""
630
+ total = len(self._source)
631
+ return self.iloc(slice(max(0, total - n), total))
632
+
633
+ def sample(self, n: int = 5, seed: int | None = None) -> CollectionQuery:
634
+ """Filter to n randomly sampled volumes."""
635
+ rng = np.random.default_rng(seed)
636
+ resolved = list(self._resolve_volume_mask())
637
+ n = min(n, len(resolved))
638
+ sampled = rng.choice(resolved, size=n, replace=False)
639
+ return self._copy(volume_ids=frozenset(sampled))
640
+
641
+ # =========================================================================
642
+ # TRANSFORM (applied during materialization)
643
+ # =========================================================================
644
+
645
+ def map(self, fn: TransformFn) -> CollectionQuery:
646
+ """Apply transform to each volume during materialization.
647
+
648
+ Multiple map() calls compose: query.map(f1).map(f2) applies f1 then f2.
649
+
650
+ Args:
651
+ fn: Function (X, Y, Z) -> (X', Y', Z'). Can change shape.
652
+ """
653
+ if self._transform_fn is not None:
654
+ # Compose transforms: apply previous transform, then new one
655
+ prev_fn = self._transform_fn
656
+
657
+ def composed_fn(v: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
658
+ return fn(prev_fn(v))
659
+
660
+ return self._copy(transform_fn=composed_fn)
661
+ return self._copy(transform_fn=fn)
662
+
663
+ # =========================================================================
664
+ # MASK RESOLUTION
665
+ # =========================================================================
666
+
667
    def _resolve_volume_mask(self) -> frozenset[str]:
        """Resolve all filters to a set of obs_ids using TileDB-native queries.

        Filter application:
        - subject_ids: dimension slicing via multi_index (assumes the obs
          array's first dimension is obs_subject_id — confirm against the
          obs schema)
        - volume_query: TileDB QueryCondition on attributes
        - volume_ids: plain set intersection after the TileDB read
        """
        effective_ctx = self._source.obs._effective_ctx()

        with tiledb.open(self._source.obs.uri, "r", ctx=effective_ctx) as arr:
            # Build query based on filters
            if self._subject_ids is not None and self._volume_query is not None:
                # Both subject and attribute filters - use dimension slicing + QueryCondition
                query = arr.query(cond=self._volume_query, dims=["obs_id"])
                result = query.multi_index[list(self._subject_ids), :]
            elif self._subject_ids is not None:
                # Subject filter only - use dimension slicing
                result = arr.query(dims=["obs_id"]).multi_index[list(self._subject_ids), :]
            elif self._volume_query is not None:
                # Attribute filter only - use QueryCondition
                result = arr.query(cond=self._volume_query, dims=["obs_id"])[:]
            else:
                # No filters - return all obs_ids
                result = arr.query(dims=["obs_id"], attrs=[])[:]

            obs_ids = result["obs_id"]
            # TileDB may return dimension values as bytes; normalize to str.
            all_ids = frozenset(v.decode() if isinstance(v, bytes) else str(v) for v in obs_ids)

            # Apply explicit volume_ids filter (intersection with pre-specified IDs)
            if self._volume_ids is not None:
                all_ids &= self._volume_ids

            return all_ids
695
+
696
+ # =========================================================================
697
+ # MATERIALIZATION
698
+ # =========================================================================
699
+
700
    def count(self) -> int:
        """Count volumes matching the query without loading volume data."""
        return len(self._resolve_volume_mask())
703
+
704
+ def to_obs(self) -> pd.DataFrame:
705
+ """Return filtered obs DataFrame."""
706
+ volume_mask = self._resolve_volume_mask()
707
+ obs_df = self._source.obs.read()
708
+ return obs_df[obs_df["obs_id"].isin(volume_mask)].reset_index(drop=True)
709
+
710
+ def iter_volumes(self) -> Iterator[Volume]:
711
+ """Iterate over volumes matching the query."""
712
+ volume_mask = self._resolve_volume_mask()
713
+ for obs_id in sorted(volume_mask):
714
+ yield self._source.loc[obs_id]
715
+
716
+ def to_numpy_stack(self) -> npt.NDArray[np.floating]:
717
+ """Load all matching volumes as stacked numpy array (N, X, Y, Z)."""
718
+ volume_mask = self._resolve_volume_mask()
719
+ if not volume_mask:
720
+ raise ValueError("No volumes match the query")
721
+
722
+ arrays = [self._source.loc[obs_id].to_numpy() for obs_id in sorted(volume_mask)]
723
+ return np.stack(arrays, axis=0)
724
+
725
    def materialize(
        self,
        uri: str,
        name: str | None = None,
        ctx: tiledb.Ctx | None = None,
    ) -> VolumeCollection:
        """Materialize query results as a new VolumeCollection.

        If a transform was set via map(), it is applied to each volume
        during materialization. If the transform changes volume shapes,
        the output collection becomes heterogeneous (shape=None).

        Args:
            uri: Target URI for the new collection.
            name: Output collection name; defaults to the source's name.
            ctx: TileDB context.

        Raises:
            ValueError: If no volumes match the query.
        """
        # Imported here to avoid a circular import at module load time.
        from radiobject.streaming import StreamingWriter

        volume_mask = self._resolve_volume_mask()
        if not volume_mask:
            raise ValueError("No volumes match the query")

        obs_df = self.to_obs()
        collection_name = name or self._source.name

        # Copy the obs attribute schema, excluding the two index columns.
        obs_schema: dict[str, np.dtype] = {}
        for col in self._source.obs.columns:
            if col in ("obs_id", "obs_subject_id"):
                continue
            obs_schema[col] = self._source.obs.dtypes[col]

        # A transform may change shapes, so declare the output heterogeneous.
        output_shape = None if self._transform_fn is not None else self._source.shape

        with StreamingWriter(
            uri=uri,
            shape=output_shape,
            obs_schema=obs_schema,
            name=collection_name,
            ctx=ctx,
        ) as writer:
            # Sorted for a deterministic write order.
            for obs_id in sorted(volume_mask):
                vol = self._source.loc[obs_id]
                data = vol.to_numpy()

                if self._transform_fn is not None:
                    data = self._transform_fn(data)

                # NOTE(review): linear scan of obs_df per volume — O(n^2)
                # overall; fine for modest collections.
                obs_row = obs_df[obs_df["obs_id"] == obs_id].iloc[0]
                attrs = {k: v for k, v in obs_row.items() if k not in ("obs_id", "obs_subject_id")}
                writer.write_volume(
                    data=data,
                    obs_id=obs_id,
                    obs_subject_id=obs_row["obs_subject_id"],
                    **attrs,
                )

        # Re-open the freshly written collection from disk.
        return self._source.__class__(uri, ctx=ctx)
778
+
779
+ def __len__(self) -> int:
780
+ """Number of volumes matching the query."""
781
+ return len(self._resolve_volume_mask())
782
+
783
+ def __repr__(self) -> str:
784
+ """Concise representation of the CollectionQuery."""
785
+ n = len(self)
786
+ name = self._source.name or "unnamed"
787
+ shape = "x".join(str(d) for d in self._source.shape)
788
+ return f"CollectionQuery('{name}', {n} volumes, shape={shape})"