radiobject-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1665 @@
1
+ """RadiObject - top-level container for multi-collection radiology data."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import defaultdict
6
+ from collections.abc import Iterator, Sequence
7
+ from functools import cached_property
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING, overload
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+ import pandas as pd
14
+ import tiledb
15
+
16
+ from radiobject.ctx import ctx as global_ctx
17
+ from radiobject.dataframe import Dataframe
18
+ from radiobject.imaging_metadata import (
19
+ extract_dicom_metadata,
20
+ extract_nifti_metadata,
21
+ infer_series_type,
22
+ )
23
+ from radiobject.indexing import Index
24
+ from radiobject.parallel import WriteResult, create_worker_ctx
25
+ from radiobject.volume import Volume
26
+ from radiobject.volume_collection import (
27
+ VolumeCollection,
28
+ _normalize_index,
29
+ _write_volumes_parallel,
30
+ )
31
+
32
+ if TYPE_CHECKING:
33
+ from radiobject.query import Query
34
+
35
+
36
+ class _SubjectILocIndexer:
37
+ """Integer-location based indexer for RadiObject subjects."""
38
+
39
+ def __init__(self, radi_object: RadiObject):
40
+ self._radi_object = radi_object
41
+
42
+ @overload
43
+ def __getitem__(self, key: int) -> RadiObject: ...
44
+ @overload
45
+ def __getitem__(self, key: slice) -> RadiObject: ...
46
+ @overload
47
+ def __getitem__(self, key: list[int]) -> RadiObject: ...
48
+ @overload
49
+ def __getitem__(self, key: npt.NDArray[np.bool_]) -> RadiObject: ...
50
+
51
+ def __getitem__(self, key: int | slice | list[int] | npt.NDArray[np.bool_]) -> RadiObject:
52
+ """Returns a RadiObject view filtered to selected subject indices."""
53
+ n = len(self._radi_object)
54
+ if isinstance(key, int):
55
+ idx = _normalize_index(key, n)
56
+ return self._radi_object._filter_by_indices([idx])
57
+ elif isinstance(key, slice):
58
+ indices = list(range(*key.indices(n)))
59
+ return self._radi_object._filter_by_indices(indices)
60
+ elif isinstance(key, np.ndarray) and key.dtype == np.bool_:
61
+ if len(key) != n:
62
+ raise ValueError(f"Boolean mask length {len(key)} != subject count {n}")
63
+ indices = list(np.where(key)[0])
64
+ return self._radi_object._filter_by_indices(indices)
65
+ elif isinstance(key, list):
66
+ indices = [_normalize_index(i, n) for i in key]
67
+ return self._radi_object._filter_by_indices(indices)
68
+ raise TypeError(
69
+ f"iloc indices must be int, slice, list[int], or boolean array, got {type(key)}"
70
+ )
71
+
72
+
73
+ class _SubjectLocIndexer:
74
+ """Label-based indexer for RadiObject subjects."""
75
+
76
+ def __init__(self, radi_object: RadiObject):
77
+ self._radi_object = radi_object
78
+
79
+ @overload
80
+ def __getitem__(self, key: str) -> RadiObject: ...
81
+ @overload
82
+ def __getitem__(self, key: list[str]) -> RadiObject: ...
83
+
84
+ def __getitem__(self, key: str | list[str]) -> RadiObject:
85
+ """Returns a RadiObject view filtered to selected obs_subject_ids."""
86
+ if isinstance(key, str):
87
+ return self._radi_object._filter_by_subject_ids([key])
88
+ elif isinstance(key, list):
89
+ return self._radi_object._filter_by_subject_ids(key)
90
+ raise TypeError(f"loc indices must be str or list[str], got {type(key)}")
91
+
92
+
93
+ class RadiObject:
94
+ """Top-level container for multi-collection radiology data with subject metadata.
95
+
96
+ RadiObject can be either "attached" (backed by storage at a URI) or a "view"
97
+ (filtered subset referencing a source RadiObject). Views are created by
98
+ filtering operations and read data from their source with filters applied.
99
+
100
+ Attached (has URI):
101
+ radi = RadiObject("s3://bucket/dataset")
102
+ radi.is_view # False
103
+ radi.uri # "s3://bucket/dataset"
104
+
105
+ View (filtered, no URI):
106
+ subset = radi.filter("age > 40")
107
+ subset.is_view # True
108
+ subset.uri # None
109
+ subset._root # Original RadiObject
110
+
111
+ Views are immutable. To persist a view, use materialize(uri).
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ uri: str | None,
117
+ ctx: tiledb.Ctx | None = None,
118
+ *,
119
+ # View state (internal use only)
120
+ _source: RadiObject | None = None,
121
+ _subject_ids: frozenset[str] | None = None,
122
+ _collection_names: frozenset[str] | None = None,
123
+ ):
124
+ self._uri: str | None = uri
125
+ self._ctx: tiledb.Ctx | None = ctx
126
+ # View state
127
+ self._source: RadiObject | None = _source
128
+ self._subject_ids: frozenset[str] | None = _subject_ids
129
+ self._collection_names_filter: frozenset[str] | None = _collection_names
130
+
131
+ @property
132
+ def uri(self) -> str | None:
133
+ """URI of this RadiObject, or None if this is a view."""
134
+ return self._uri
135
+
136
+ @property
137
+ def is_view(self) -> bool:
138
+ """True if this RadiObject is a filtered view of another."""
139
+ return self._source is not None
140
+
141
+ @property
142
+ def _root(self) -> RadiObject:
143
+ """The original attached RadiObject (follows source chain)."""
144
+ if self._source is None:
145
+ return self
146
+ return self._source._root
147
+
148
+ def _effective_ctx(self) -> tiledb.Ctx:
149
+ if self._source is not None:
150
+ return self._source._effective_ctx()
151
+ return self._ctx if self._ctx else global_ctx()
152
+
153
+ def _effective_uri(self) -> str:
154
+ """Get the storage URI (from root if this is a view)."""
155
+ if self._source is not None:
156
+ return self._source._effective_uri()
157
+ if self._uri is None:
158
+ raise ValueError("RadiObject has no URI")
159
+ return self._uri
160
+
161
+ # ===== View Factory =====
162
+
163
+ def _create_view(
164
+ self,
165
+ subject_ids: frozenset[str] | None = None,
166
+ collection_names: frozenset[str] | None = None,
167
+ ) -> RadiObject:
168
+ """Create a view with specified filters, intersecting with current filters."""
169
+ # Intersect subject_ids with current filter
170
+ if self._subject_ids is not None and subject_ids is not None:
171
+ subject_ids = self._subject_ids & subject_ids
172
+ elif self._subject_ids is not None:
173
+ subject_ids = self._subject_ids
174
+ # subject_ids stays as passed if self._subject_ids is None
175
+
176
+ # Intersect collection_names with current filter
177
+ if self._collection_names_filter is not None and collection_names is not None:
178
+ collection_names = self._collection_names_filter & collection_names
179
+ elif self._collection_names_filter is not None:
180
+ collection_names = self._collection_names_filter
181
+ # collection_names stays as passed if self._collection_names_filter is None
182
+
183
+ return RadiObject(
184
+ uri=None,
185
+ ctx=self._ctx,
186
+ _source=self._root, # Always point to root to avoid deep chains
187
+ _subject_ids=subject_ids,
188
+ _collection_names=collection_names,
189
+ )
190
+
191
+ # ===== Subject Indexing =====
192
+
193
+ @cached_property
194
+ def iloc(self) -> _SubjectILocIndexer:
195
+ """Integer-location based indexing for selecting subjects by position."""
196
+ return _SubjectILocIndexer(self)
197
+
198
+ @cached_property
199
+ def loc(self) -> _SubjectLocIndexer:
200
+ """Label-based indexing for selecting subjects by obs_subject_id."""
201
+ return _SubjectLocIndexer(self)
202
+
203
+ # ===== ObsMeta (Subject Metadata) =====
204
+
205
+ @property
206
+ def obs_meta(self) -> pd.DataFrame | Dataframe:
207
+ """Subject-level observational metadata.
208
+
209
+ Returns Dataframe for attached RadiObject, pd.DataFrame for views.
210
+ """
211
+ if self.is_view:
212
+ # Return filtered DataFrame
213
+ full_obs_meta = self._root.obs_meta.read()
214
+ if self._subject_ids is not None:
215
+ return full_obs_meta[
216
+ full_obs_meta["obs_subject_id"].isin(self._subject_ids)
217
+ ].reset_index(drop=True)
218
+ return full_obs_meta
219
+ obs_meta_uri = f"{self._effective_uri()}/obs_meta"
220
+ return Dataframe(uri=obs_meta_uri, ctx=self._ctx)
221
+
222
+ @cached_property
223
+ def _index(self) -> Index:
224
+ """Cached bidirectional index for obs_subject_id lookups."""
225
+ if self.is_view:
226
+ # Build index from filtered subject_ids
227
+ if self._subject_ids is not None:
228
+ # Preserve order from root
229
+ root_ids = self._root.obs_subject_ids
230
+ filtered = [sid for sid in root_ids if sid in self._subject_ids]
231
+ return Index.build(filtered)
232
+ return self._root._index
233
+ n = self._metadata.get("subject_count", 0)
234
+ if n == 0:
235
+ return Index.build([])
236
+ # Only load the index column for efficiency
237
+ data = self.obs_meta.read(columns=["obs_subject_id"])
238
+ return Index.build(list(data["obs_subject_id"]))
239
+
240
+ @property
241
+ def index(self) -> Index:
242
+ """Subject index for bidirectional ID/position lookups."""
243
+ return self._index
244
+
245
+ @property
246
+ def obs_subject_ids(self) -> list[str]:
247
+ """All obs_subject_id values in index order."""
248
+ return list(self._index.keys)
249
+
250
+ def get_obs_row_by_obs_subject_id(self, obs_subject_id: str) -> pd.DataFrame:
251
+ """Get obs_meta row by obs_subject_id string identifier."""
252
+ if self.is_view:
253
+ obs_meta_df = self.obs_meta
254
+ return obs_meta_df[obs_meta_df["obs_subject_id"] == obs_subject_id].reset_index(
255
+ drop=True
256
+ )
257
+ df = self.obs_meta.read()
258
+ filtered = df[df["obs_subject_id"] == obs_subject_id].reset_index(drop=True)
259
+ return filtered
260
+
261
+ # ===== Volume Access Across Collections =====
262
+
263
+ @cached_property
264
+ def all_obs_ids(self) -> list[str]:
265
+ """All obs_ids across all collections (for uniqueness checks)."""
266
+ obs_ids = []
267
+ for name in self.collection_names:
268
+ obs_ids.extend(self.collection(name).obs_ids)
269
+ return obs_ids
270
+
271
+ def get_volume(self, obs_id: str) -> Volume:
272
+ """Get a volume by obs_id from any collection."""
273
+ for name in self.collection_names:
274
+ coll = self.collection(name)
275
+ if obs_id in coll.index:
276
+ return coll.loc[obs_id]
277
+ raise KeyError(f"obs_id '{obs_id}' not found in any collection")
278
+
279
+ # ===== VolumeCollections =====
280
+
281
+ @cached_property
282
+ def _metadata(self) -> dict:
283
+ """Cached group metadata."""
284
+ uri = self._effective_uri()
285
+ with tiledb.Group(uri, "r", ctx=self._effective_ctx()) as grp:
286
+ return dict(grp.meta)
287
+
288
+ @cached_property
289
+ def collection_names(self) -> tuple[str, ...]:
290
+ """Names of all VolumeCollections."""
291
+ if self.is_view and self._collection_names_filter is not None:
292
+ # Return filtered collection names (preserving root order)
293
+ root_names = self._root.collection_names
294
+ return tuple(name for name in root_names if name in self._collection_names_filter)
295
+ uri = self._effective_uri()
296
+ collections_uri = f"{uri}/collections"
297
+ with tiledb.Group(collections_uri, "r", ctx=self._effective_ctx()) as grp:
298
+ return tuple(obj.name for obj in grp)
299
+
300
+ def collection(self, name: str) -> VolumeCollection:
301
+ """Get a VolumeCollection by name."""
302
+ if name not in self.collection_names:
303
+ raise KeyError(f"Collection '{name}' not found. Available: {self.collection_names}")
304
+ uri = self._effective_uri()
305
+ collection_uri = f"{uri}/collections/{name}"
306
+ return VolumeCollection(collection_uri, ctx=self._ctx)
307
+
308
+ def __getattr__(self, name: str) -> VolumeCollection:
309
+ """Attribute access to collections (e.g., radi.T1w)."""
310
+ if name.startswith("_"):
311
+ raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")
312
+ try:
313
+ return self.collection(name)
314
+ except KeyError:
315
+ raise AttributeError(f"'{type(self).__name__}' has no collection '{name}'")
316
+
317
+ def rename_collection(self, old_name: str, new_name: str) -> None:
318
+ """Rename a collection."""
319
+ self._check_not_view("rename_collection")
320
+ if old_name not in self.collection_names:
321
+ raise KeyError(f"Collection '{old_name}' not found")
322
+ if new_name in self.collection_names:
323
+ raise ValueError(f"Collection '{new_name}' already exists")
324
+
325
+ effective_ctx = self._effective_ctx()
326
+ uri = self._effective_uri()
327
+ collections_uri = f"{uri}/collections"
328
+ old_uri = f"{collections_uri}/{old_name}"
329
+
330
+ with tiledb.Group(old_uri, "w", ctx=effective_ctx) as grp:
331
+ grp.meta["name"] = new_name
332
+
333
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
334
+ grp.remove(old_name)
335
+ grp.add(old_uri, name=new_name)
336
+
337
+ if "collection_names" in self.__dict__:
338
+ del self.__dict__["collection_names"]
339
+
340
+ # ===== Length / Iteration =====
341
+
342
+ def __len__(self) -> int:
343
+ """Number of subjects."""
344
+ if self.is_view:
345
+ return len(self._index)
346
+ return int(self._metadata.get("subject_count", 0))
347
+
348
+ @property
349
+ def n_collections(self) -> int:
350
+ """Number of VolumeCollections."""
351
+ return len(self.collection_names)
352
+
353
+ def __iter__(self) -> Iterator[str]:
354
+ """Iterate over collection names."""
355
+ return iter(self.collection_names)
356
+
357
+ @overload
358
+ def __getitem__(self, key: str) -> RadiObject: ...
359
+ @overload
360
+ def __getitem__(self, key: list[str]) -> RadiObject: ...
361
+
362
+ def __getitem__(self, key: str | list[str]) -> RadiObject:
363
+ """Bracket indexing for subjects by obs_subject_id.
364
+
365
+ Alias for .loc[] - allows `radi["BraTS001"]` as shorthand for `radi.loc["BraTS001"]`.
366
+ """
367
+ return self.loc[key]
368
+
369
+ def __repr__(self) -> str:
370
+ """Concise representation of the RadiObject."""
371
+ collections = ", ".join(self.collection_names) if self.collection_names else "none"
372
+ view_indicator = " (view)" if self.is_view else ""
373
+ return (
374
+ f"RadiObject({len(self)} subjects, {self.n_collections} collections: "
375
+ f"[{collections}]){view_indicator}"
376
+ )
377
+
378
+ def describe(self) -> str:
379
+ """Return a summary: subjects, collections, shapes, and label distributions."""
380
+ lines = [
381
+ "RadiObject Summary",
382
+ "==================",
383
+ f"URI: {self.uri or '(view)'}",
384
+ f"Subjects: {len(self)}",
385
+ f"Collections: {self.n_collections}",
386
+ "",
387
+ "Collections:",
388
+ ]
389
+
390
+ for name in self.collection_names:
391
+ coll = self.collection(name)
392
+ shape = coll.shape
393
+ shape_str = "x".join(str(d) for d in shape) if shape else "heterogeneous"
394
+ uniform_str = "" if coll.is_uniform else " (mixed shapes)"
395
+ lines.append(f" - {name}: {len(coll)} volumes, shape={shape_str}{uniform_str}")
396
+
397
+ # Find label columns
398
+ obs_meta_df = self.obs_meta if self.is_view else self.obs_meta.read()
399
+ label_cols = []
400
+ for col in obs_meta_df.columns:
401
+ if col in ("obs_subject_id", "obs_id"):
402
+ continue
403
+ dtype = obs_meta_df[col].dtype
404
+ if dtype in (np.int64, np.int32, np.float64, np.float32, object):
405
+ n_unique = obs_meta_df[col].nunique()
406
+ if n_unique <= 10:
407
+ label_cols.append(col)
408
+
409
+ if label_cols:
410
+ lines.append("")
411
+ lines.append("Label Columns:")
412
+ for col in label_cols:
413
+ value_counts = obs_meta_df[col].value_counts().to_dict()
414
+ lines.append(f" - {col}: {value_counts}")
415
+
416
+ return "\n".join(lines)
417
+
418
+ # ===== Lazy Mode (returns Query for transform pipelines) =====
419
+
420
+ def lazy(self) -> Query:
421
+ """Enter lazy mode for transform pipelines.
422
+
423
+ Returns a Query that accumulates transforms without executing them.
424
+ Use this when you need to apply transforms via .map().
425
+
426
+ Example:
427
+ normalized = (
428
+ radi
429
+ .lazy()
430
+ .filter("quality == 'good'")
431
+ .map(normalize_intensity)
432
+ .materialize("./normalized")
433
+ )
434
+ """
435
+ from radiobject.query import Query
436
+
437
+ return Query(
438
+ self._root,
439
+ subject_ids=self._subject_ids,
440
+ output_collections=self._collection_names_filter,
441
+ )
442
+
443
+ # ===== Immutability Check =====
444
+
445
+ def _check_not_view(self, operation: str) -> None:
446
+ """Raise if attempting to modify a view."""
447
+ if self.is_view:
448
+ raise ValueError(
449
+ f"Cannot {operation} on a view. Call materialize(uri) first to create "
450
+ "an attached RadiObject."
451
+ )
452
+
453
+ # ===== Filtering (returns RadiObject view) =====
454
+
455
+ def _filter_by_indices(self, indices: list[int]) -> RadiObject:
456
+ """Create a view filtered to specific subject indices."""
457
+ subject_ids = frozenset(self._index.get_key(i) for i in indices)
458
+ return self._create_view(subject_ids=subject_ids)
459
+
460
+ def _filter_by_subject_ids(self, obs_subject_ids: list[str]) -> RadiObject:
461
+ """Create a view filtered to specific obs_subject_ids."""
462
+ current_ids = set(self._index.keys)
463
+ for sid in obs_subject_ids:
464
+ if sid not in current_ids:
465
+ raise KeyError(f"obs_subject_id '{sid}' not found")
466
+ return self._create_view(subject_ids=frozenset(obs_subject_ids))
467
+
468
+ def select_collections(self, names: list[str]) -> RadiObject:
469
+ """Create a view with only specified collections."""
470
+ current_names = set(self.collection_names)
471
+ for name in names:
472
+ if name not in current_names:
473
+ raise KeyError(f"Collection '{name}' not found")
474
+ return self._create_view(collection_names=frozenset(names))
475
+
476
+ def filter(self, expr: str) -> RadiObject:
477
+ """Filter subjects using a query expression on obs_meta.
478
+
479
+ Args:
480
+ expr: TileDB QueryCondition string (e.g., "tumor_grade == 'HGG' and age > 40")
481
+
482
+ Returns:
483
+ RadiObject view filtered to matching subjects
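+
+ Note: attached RadiObjects evaluate expr as a TileDB QueryCondition;
+ views evaluate it with pandas DataFrame.query.
+
+ Example:
+ hgg = radi.filter("tumor_grade == 'HGG' and age > 40")
+ hgg.is_view  # True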
484
+ """
485
+ if self.is_view:
486
+ # Filter from the obs_meta DataFrame
487
+ obs_meta_df = self.obs_meta
488
+ # Use pandas query for view filtering
489
+ filtered = obs_meta_df.query(expr)
490
+ subject_ids = frozenset(filtered["obs_subject_id"])
491
+ else:
492
+ # Use TileDB QueryCondition for attached RadiObject
493
+ filtered = self.obs_meta.read(value_filter=expr)
494
+ subject_ids = frozenset(filtered["obs_subject_id"])
495
+ return self._create_view(subject_ids=subject_ids)
496
+
497
+ def head(self, n: int = 5) -> RadiObject:
498
+ """Return view of first n subjects."""
499
+ n = min(n, len(self))
500
+ return self._filter_by_indices(list(range(n)))
501
+
502
+ def tail(self, n: int = 5) -> RadiObject:
503
+ """Return view of last n subjects."""
504
+ total = len(self)
505
+ n = min(n, total)
506
+ return self._filter_by_indices(list(range(total - n, total)))
507
+
508
+ def sample(self, n: int = 5, seed: int | None = None) -> RadiObject:
509
+ """Return view of n randomly sampled subjects."""
510
+ rng = np.random.default_rng(seed)
511
+ total = len(self)
512
+ n = min(n, total)
513
+ indices = list(rng.choice(total, size=n, replace=False))
514
+ return self._filter_by_indices(sorted(indices))
515
+
516
+ # ===== Materialization =====
517
+
518
+ def materialize(
519
+ self,
520
+ uri: str,
521
+ streaming: bool = True,
522
+ ctx: tiledb.Ctx | None = None,
523
+ ) -> RadiObject:
524
+ """Write this RadiObject (or view) to storage.
525
+
526
+ For attached RadiObjects, this copies the entire dataset.
527
+ For views, this writes only the filtered subset.
528
+
529
+ Args:
530
+ uri: Target URI for the new RadiObject
531
+ streaming: Use streaming writer for memory efficiency (default: True)
532
+ ctx: TileDB context
533
+
534
+ Returns:
535
+ New attached RadiObject at the target URI
536
+ """
537
+ # Get filtered obs_meta
538
+ if self.is_view:
539
+ filtered_obs_meta = self.obs_meta # Already filtered DataFrame
540
+ else:
541
+ filtered_obs_meta = self.obs_meta.read()
542
+
543
+ # Build obs_meta schema
544
+ obs_meta_schema: dict[str, np.dtype] = {}
545
+ for col in filtered_obs_meta.columns:
546
+ if col in ("obs_subject_id", "obs_id"):
547
+ continue
548
+ dtype = filtered_obs_meta[col].to_numpy().dtype
549
+ if dtype == np.dtype("O"):
550
+ dtype = np.dtype("U64")
551
+ obs_meta_schema[col] = dtype
552
+
553
+ if streaming:
554
+ return self._materialize_streaming(uri, filtered_obs_meta, obs_meta_schema, ctx)
555
+ return self._materialize_batch(uri, filtered_obs_meta, obs_meta_schema, ctx)
556
+
557
+ def _materialize_streaming(
558
+ self,
559
+ uri: str,
560
+ obs_meta_df: pd.DataFrame,
561
+ obs_meta_schema: dict[str, np.dtype],
562
+ ctx: tiledb.Ctx | None,
563
+ ) -> RadiObject:
564
+ """Materialize view to storage using streaming writer."""
565
+ from radiobject.streaming import RadiObjectWriter
566
+
567
+ subject_ids = set(obs_meta_df["obs_subject_id"])
568
+
569
+ with RadiObjectWriter(uri, obs_meta_schema=obs_meta_schema, ctx=ctx) as writer:
570
+ writer.write_obs_meta(obs_meta_df)
571
+
572
+ for coll_name in self.collection_names:
573
+ src_collection = self.collection(coll_name)
574
+ obs_df = src_collection.obs.read()
575
+ filtered_obs = obs_df[obs_df["obs_subject_id"].isin(subject_ids)]
576
+
577
+ if len(filtered_obs) == 0:
578
+ continue
579
+
580
+ # Extract obs schema via the module-level helper
+ obs_schema = _extract_obs_schema(src_collection.obs)
586
+
587
+ with writer.add_collection(
588
+ coll_name, src_collection.shape, obs_schema
589
+ ) as coll_writer:
590
+ for _, row in filtered_obs.iterrows():
591
+ obs_id = row["obs_id"]
592
+ vol = src_collection.loc[obs_id]
593
+ attrs = {
594
+ k: v for k, v in row.items() if k not in ("obs_id", "obs_subject_id")
595
+ }
596
+ coll_writer.write_volume(
597
+ data=vol.to_numpy(),
598
+ obs_id=obs_id,
599
+ obs_subject_id=row["obs_subject_id"],
600
+ **attrs,
601
+ )
602
+
603
+ return RadiObject(uri, ctx=ctx)
604
+
605
+ def _materialize_batch(
606
+ self,
607
+ uri: str,
608
+ obs_meta_df: pd.DataFrame,
609
+ obs_meta_schema: dict[str, np.dtype],
610
+ ctx: tiledb.Ctx | None,
611
+ ) -> RadiObject:
612
+ """Materialize view to storage using batch writer."""
613
+ effective_ctx = ctx if ctx else self._effective_ctx()
614
+ subject_ids = list(obs_meta_df["obs_subject_id"])
615
+
616
+ RadiObject._create(
617
+ uri,
618
+ obs_meta_schema=obs_meta_schema,
619
+ n_subjects=len(subject_ids),
620
+ ctx=ctx,
621
+ )
622
+
623
+ # Write obs_meta
624
+ obs_meta_uri = f"{uri}/obs_meta"
625
+ obs_subject_ids_arr = obs_meta_df["obs_subject_id"].astype(str).to_numpy()
626
+ obs_ids_arr = (
627
+ obs_meta_df["obs_id"].astype(str).to_numpy()
628
+ if "obs_id" in obs_meta_df.columns
629
+ else obs_subject_ids_arr
630
+ )
631
+ with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
632
+ attr_data = {
633
+ col: obs_meta_df[col].to_numpy()
634
+ for col in obs_meta_df.columns
635
+ if col not in ("obs_subject_id", "obs_id")
636
+ }
637
+ arr[obs_subject_ids_arr, obs_ids_arr] = attr_data
638
+
639
+ # Copy collections
640
+ collections_uri = f"{uri}/collections"
641
+ for coll_name in self.collection_names:
642
+ src_collection = self.collection(coll_name)
643
+ new_vc_uri = f"{collections_uri}/{coll_name}"
644
+
645
+ _copy_filtered_volume_collection(
646
+ src_collection,
647
+ new_vc_uri,
648
+ obs_subject_ids=subject_ids,
649
+ name=coll_name,
650
+ ctx=ctx,
651
+ )
652
+
653
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
654
+ grp.add(new_vc_uri, name=coll_name)
655
+
656
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
657
+ grp.meta["n_collections"] = len(self.collection_names)
658
+ grp.meta["subject_count"] = len(subject_ids)
659
+
660
+ return RadiObject(uri, ctx=ctx)
661
+
662
+ def copy(self) -> RadiObject:
663
+ """Create an independent in-memory copy, detached from the view chain.
664
+
665
+ Useful when you want to break the reference to the source RadiObject.
666
+ Note: This does NOT persist data. Call materialize(uri) to write to storage.
667
+ """
668
+ if not self.is_view:
669
+ # For attached RadiObject, just return self (already independent)
670
+ return self
671
+ # Return a new view object carrying the same filters
672
+ # In practice, since we always point to _root, this is already independent
673
+ return RadiObject(
674
+ uri=None,
675
+ ctx=self._ctx,
676
+ _source=self._root,
677
+ _subject_ids=self._subject_ids,
678
+ _collection_names=self._collection_names_filter,
679
+ )
680
+
681
+ # ===== Append Operations (Mutations) =====
682
+
683
+ def append(
684
+ self,
685
+ niftis: Sequence[tuple[str | Path, str]] | None = None,
686
+ dicom_dirs: Sequence[tuple[str | Path, str]] | None = None,
687
+ obs_meta: pd.DataFrame | None = None,
688
+ reorient: bool | None = None,
689
+ progress: bool = False,
690
+ ) -> None:
691
+ """Append new subjects and their volumes atomically."""
692
+ self._check_not_view("append")
693
+
694
+ if niftis is None and dicom_dirs is None:
695
+ raise ValueError("Must provide either niftis or dicom_dirs")
696
+ if niftis is not None and dicom_dirs is not None:
697
+ raise ValueError("Cannot provide both niftis and dicom_dirs")
698
+
699
+ effective_ctx = self._effective_ctx()
700
+ uri = self._effective_uri()
701
+
702
+ # Collect all subject IDs from input
703
+ if niftis is not None:
704
+ input_subject_ids = {sid for _, sid in niftis}
705
+ else:
706
+ input_subject_ids = {sid for _, sid in dicom_dirs}
707
+
708
+ existing_subject_ids = set(self.obs_subject_ids)
709
+ new_subject_ids = input_subject_ids - existing_subject_ids
710
+
711
+ # Validate obs_meta
712
+ if new_subject_ids:
713
+ if obs_meta is None:
714
+ raise ValueError(
715
+ f"obs_meta required for new subjects: {sorted(new_subject_ids)[:5]}"
716
+ )
717
+ if "obs_subject_id" not in obs_meta.columns:
718
+ raise ValueError("obs_meta must contain 'obs_subject_id' column")
719
+ obs_meta_ids = set(obs_meta["obs_subject_id"])
720
+ missing = new_subject_ids - obs_meta_ids
721
+ if missing:
722
+ raise ValueError(f"obs_meta missing entries for: {sorted(missing)[:5]}")
723
+ obs_meta = obs_meta[obs_meta["obs_subject_id"].isin(new_subject_ids)]
724
+
725
+ # Append obs_meta for new subjects
726
+ if obs_meta is not None and len(obs_meta) > 0:
727
+ obs_meta_uri = f"{uri}/obs_meta"
728
+ obs_subject_ids_arr = obs_meta["obs_subject_id"].astype(str).to_numpy()
729
+ obs_ids_arr = (
730
+ obs_meta["obs_id"].astype(str).to_numpy()
731
+ if "obs_id" in obs_meta.columns
732
+ else obs_subject_ids_arr
733
+ )
734
+ existing_columns = set(self._root.obs_meta.columns)
735
+ with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
736
+ attr_data = {
737
+ col: obs_meta[col].to_numpy()
738
+ for col in obs_meta.columns
739
+ if col not in ("obs_subject_id", "obs_id") and col in existing_columns
740
+ }
741
+ arr[obs_subject_ids_arr, obs_ids_arr] = attr_data
742
+
743
+ new_count = len(self) + len(obs_meta)
744
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
745
+ grp.meta["subject_count"] = new_count
746
+
747
+ # Process and group input files
748
+ if niftis is not None:
749
+ self._append_niftis(niftis, reorient, effective_ctx, progress)
750
+ else:
751
+ self._append_dicoms(dicom_dirs, reorient, effective_ctx, progress)
752
+
753
+ # Invalidate cached properties
754
+ for prop in ("_index", "_metadata", "collection_names"):
755
+ if prop in self.__dict__:
756
+ del self.__dict__[prop]
757
+
758
+ def _append_niftis(
759
+ self,
760
+ niftis: Sequence[tuple[str | Path, str]],
761
+ reorient: bool | None,
762
+ effective_ctx: tiledb.Ctx,
763
+ progress: bool = False,
764
+ ) -> None:
765
+ """Internal: append NIfTI files to existing collections or create new ones."""
766
+ uri = self._effective_uri()
767
+ file_info: list[tuple[Path, str, tuple[int, int, int], str]] = []
768
+ for nifti_path, obs_subject_id in niftis:
769
+ path = Path(nifti_path)
770
+ if not path.exists():
771
+ raise FileNotFoundError(f"NIfTI file not found: {path}")
772
+ metadata = extract_nifti_metadata(path)
773
+ series_type = infer_series_type(path)
774
+ file_info.append((path, obs_subject_id, metadata.dimensions, series_type))
775
+
776
+ groups: dict[tuple[tuple[int, int, int], str], list[tuple[Path, str]]] = defaultdict(list)
777
+ for path, subject_id, shape, series_type in file_info:
778
+ groups[(shape, series_type)].append((path, subject_id))
779
+
780
+ collections_uri = f"{uri}/collections"
781
+ existing_collections = set(self.collection_names)
782
+
783
+ groups_iter = groups.items()
784
+ if progress:
785
+ from tqdm.auto import tqdm
786
+
787
+ groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")
788
+
789
+ for (shape, series_type), items in groups_iter:
790
+ collection_name = series_type
791
+ if collection_name in existing_collections:
792
+ vc = self.collection(collection_name)
793
+ if vc.shape != shape:
794
+ collection_name = f"{series_type}_{shape[0]}x{shape[1]}x{shape[2]}"
795
+
796
+ if collection_name in existing_collections:
797
+ vc = self.collection(collection_name)
798
+ vc.append(niftis=items, reorient=reorient, progress=progress)
799
+ else:
800
+ vc_uri = f"{collections_uri}/{collection_name}"
801
+ VolumeCollection.from_niftis(
802
+ uri=vc_uri,
803
+ niftis=items,
804
+ reorient=reorient,
805
+ validate_dimensions=True,
806
+ name=collection_name,
807
+ ctx=self._ctx,
808
+ progress=progress,
809
+ )
810
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
811
+ grp.add(vc_uri, name=collection_name)
812
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
813
+ grp.meta["n_collections"] = self.n_collections + 1
814
+ existing_collections.add(collection_name)
815
+
816
+ def _append_dicoms(
817
+ self,
818
+ dicom_dirs: Sequence[tuple[str | Path, str]],
819
+ reorient: bool | None,
820
+ effective_ctx: tiledb.Ctx,
821
+ progress: bool = False,
822
+ ) -> None:
823
+ """Internal: append DICOM series to existing collections or create new ones."""
824
+ uri = self._effective_uri()
825
+ file_info: list[tuple[Path, str, tuple[int, int, int], str]] = []
826
+ for dicom_dir, obs_subject_id in dicom_dirs:
827
+ path = Path(dicom_dir)
828
+ if not path.exists():
829
+ raise FileNotFoundError(f"DICOM directory not found: {path}")
830
+ metadata = extract_dicom_metadata(path)
831
+ dims = metadata.dimensions
832
+ # metadata.dimensions is (rows, columns, n_slices); reorder to the
+ # (cols, rows, slices) shape used for grouping
+ shape = (dims[1], dims[0], dims[2])
833
+ file_info.append((path, obs_subject_id, shape, metadata.modality))
834
+
835
+ groups: dict[tuple[tuple[int, int, int], str], list[tuple[Path, str]]] = defaultdict(list)
836
+ for path, subject_id, shape, modality in file_info:
837
+ groups[(shape, modality)].append((path, subject_id))
838
+
839
+ collections_uri = f"{uri}/collections"
840
+ existing_collections = set(self.collection_names)
841
+
842
+ groups_iter = groups.items()
843
+ if progress:
844
+ from tqdm.auto import tqdm
845
+
846
+ groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")
847
+
848
+ for (shape, modality), items in groups_iter:
849
+ collection_name = modality
850
+ if collection_name in existing_collections:
851
+ vc = self.collection(collection_name)
852
+ if vc.shape != shape:
853
+ collection_name = f"{modality}_{shape[0]}x{shape[1]}x{shape[2]}"
854
+
855
+ if collection_name in existing_collections:
856
+ vc = self.collection(collection_name)
857
+ vc.append(dicom_dirs=items, reorient=reorient, progress=progress)
858
+ else:
859
+ vc_uri = f"{collections_uri}/{collection_name}"
860
+ VolumeCollection.from_dicoms(
861
+ uri=vc_uri,
862
+ dicom_dirs=items,
863
+ reorient=reorient,
864
+ validate_dimensions=True,
865
+ name=collection_name,
866
+ ctx=self._ctx,
867
+ progress=progress,
868
+ )
869
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
870
+ grp.add(vc_uri, name=collection_name)
871
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
872
+ grp.meta["n_collections"] = self.n_collections + 1
873
+ existing_collections.add(collection_name)
874
+
875
+ # ===== Validation =====
876
+
877
+ def validate(self) -> None:
878
+ """Validate internal consistency of the RadiObject and all collections."""
879
+ self._check_not_view("validate")
880
+ obs_meta_data = self.obs_meta.read()
881
+ actual_subject_count = len(obs_meta_data)
882
+ stored_subject_count = self._metadata.get("subject_count", 0)
883
+ if actual_subject_count != stored_subject_count:
884
+ raise ValueError(
885
+ f"subject_count mismatch: metadata={stored_subject_count}, actual={actual_subject_count}"
886
+ )
887
+
888
+ actual_n_collections = len(self.collection_names)
889
+ stored_n_collections = self._metadata.get("n_collections", 0)
890
+ if actual_n_collections != stored_n_collections:
891
+ raise ValueError(
892
+ f"n_collections mismatch: metadata={stored_n_collections}, actual={actual_n_collections}"
893
+ )
894
+
895
+ for name in self.collection_names:
896
+ self.collection(name).validate()
897
+
898
+ obs_meta_subject_ids = set(obs_meta_data["obs_subject_id"])
899
+ for name in self.collection_names:
900
+ vc = self.collection(name)
901
+ vc_obs = vc.obs.read()
902
+ vc_subject_ids = set(vc_obs["obs_subject_id"])
903
+ orphan_subjects = vc_subject_ids - obs_meta_subject_ids
904
+ if orphan_subjects:
905
+ raise ValueError(
906
+ f"Collection '{name}' has obs_subject_ids not in obs_meta: "
907
+ f"{sorted(orphan_subjects)[:5]}"
908
+ )
909
+
910
+ seen_obs_ids: dict[str, str] = {}
911
+ for name in self.collection_names:
912
+ vc = self.collection(name)
913
+ for obs_id in vc.obs_ids:
914
+ if obs_id in seen_obs_ids:
915
+ raise ValueError(
916
+ f"obs_id '{obs_id}' is duplicated across collections: "
917
+ f"'{seen_obs_ids[obs_id]}' and '{name}'"
918
+ )
919
+ seen_obs_ids[obs_id] = name
920
+
921
+ # ===== Factory Methods =====
922
+
923
+ @classmethod
924
+ def _create(
925
+ cls,
926
+ uri: str,
927
+ obs_meta_schema: dict[str, np.dtype] | None = None,
928
+ n_subjects: int = 0,
929
+ ctx: tiledb.Ctx | None = None,
930
+ ) -> RadiObject:
931
+ """Internal: create an empty RadiObject with optional obs_meta schema."""
932
+ effective_ctx = ctx if ctx else global_ctx()
933
+
934
+ tiledb.Group.create(uri, ctx=effective_ctx)
935
+
936
+ obs_meta_uri = f"{uri}/obs_meta"
937
+ Dataframe.create(obs_meta_uri, schema=obs_meta_schema or {}, ctx=ctx)
938
+
939
+ collections_uri = f"{uri}/collections"
940
+ tiledb.Group.create(collections_uri, ctx=effective_ctx)
941
+
942
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
943
+ grp.meta["subject_count"] = n_subjects
944
+ grp.meta["n_collections"] = 0
945
+ grp.add(obs_meta_uri, name="obs_meta")
946
+ grp.add(collections_uri, name="collections")
947
+
948
+ return cls(uri, ctx=ctx)
949
+
950
+ @classmethod
951
+ def _from_volume_collections(
952
+ cls,
953
+ uri: str,
954
+ collections: dict[str, VolumeCollection],
955
+ obs_meta: pd.DataFrame | None = None,
956
+ ctx: tiledb.Ctx | None = None,
957
+ ) -> RadiObject:
958
+ """Internal: create RadiObject from existing VolumeCollections."""
959
+ if not collections:
960
+ raise ValueError("At least one VolumeCollection is required")
961
+
962
+ effective_ctx = ctx if ctx else global_ctx()
963
+
964
+ n_subjects = len(obs_meta) if obs_meta is not None else 0
965
+
966
+ obs_meta_schema = None
967
+ if obs_meta is not None:
968
+ obs_meta_schema = {}
969
+ for col in obs_meta.columns:
970
+ if col in ("obs_subject_id", "obs_id"):
971
+ continue
972
+ dtype = obs_meta[col].to_numpy().dtype
973
+ if dtype == np.dtype("O"):
974
+ dtype = np.dtype("U64")
975
+ obs_meta_schema[col] = dtype
976
+
977
+ cls._create(uri, obs_meta_schema=obs_meta_schema, n_subjects=n_subjects, ctx=ctx)
978
+
979
+ if obs_meta is not None and len(obs_meta) > 0:
980
+ obs_meta_uri = f"{uri}/obs_meta"
981
+ obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
982
+ obs_ids = (
983
+ obs_meta["obs_id"].astype(str).to_numpy()
984
+ if "obs_id" in obs_meta.columns
985
+ else obs_subject_ids
986
+ )
987
+ with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
988
+ attr_data = {}
989
+ for col in obs_meta.columns:
990
+ if col not in ("obs_subject_id", "obs_id"):
991
+ attr_data[col] = obs_meta[col].to_numpy()
992
+ arr[obs_subject_ids, obs_ids] = attr_data
993
+
994
+ collections_uri = f"{uri}/collections"
995
+ for coll_name, vc in collections.items():
996
+ new_vc_uri = f"{collections_uri}/{coll_name}"
997
+ _copy_volume_collection(vc, new_vc_uri, name=coll_name, ctx=ctx)
998
+
999
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
1000
+ grp.add(new_vc_uri, name=coll_name)
1001
+
1002
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
1003
+ grp.meta["n_collections"] = len(collections)
1004
+ grp.meta["subject_count"] = n_subjects
1005
+
1006
+ return cls(uri, ctx=ctx)
1008
+
1009
+ @classmethod
1010
+ def from_collections(
1011
+ cls,
1012
+ uri: str,
1013
+ collections: dict[str, VolumeCollection | str],
1014
+ obs_meta: pd.DataFrame | None = None,
1015
+ ctx: tiledb.Ctx | None = None,
1016
+ ) -> RadiObject:
1017
+ """Create RadiObject from existing VolumeCollections.
1018
+
1019
+ Links collections without copying when they're already at expected URIs
1020
+ ({uri}/collections/{name}). Copies collections that are elsewhere.
1021
+
1022
+ Args:
1023
+ uri: Target URI for RadiObject
1024
+ collections: Dict mapping collection names to VolumeCollection objects or URIs
1025
+ obs_meta: Optional subject-level metadata. If None, derived from collections.
1026
+ ctx: TileDB context
1027
+
1028
+ Example:
1029
+ # Collections already at expected locations (no copy)
1030
+ ct_vc = radi.CT.lazy().map(transform).materialize(uri=f"{URI}/collections/CT")
1031
+ seg_vc = radi.seg.lazy().map(transform).materialize(uri=f"{URI}/collections/seg")
1032
+ radi = RadiObject.from_collections(
1033
+ uri=URI,
1034
+ collections={"CT": ct_vc, "seg": seg_vc},
1035
+ )
1036
+
1037
+ # Collections from elsewhere (will be copied)
1038
+ radi = RadiObject.from_collections(
1039
+ uri="./new_dataset",
1040
+ collections={"T1w": existing_t1w_collection},
1041
+ )
1042
+ """
1043
+ if not collections:
1044
+ raise ValueError("At least one collection is required")
1045
+
1046
+ effective_ctx = ctx if ctx else global_ctx()
1047
+ collections_uri = f"{uri}/collections"
1048
+
1049
+ # Resolve string URIs to VolumeCollection objects
1050
+ resolved: dict[str, VolumeCollection] = {}
1051
+ for name, vc_or_uri in collections.items():
1052
+ if isinstance(vc_or_uri, str):
1053
+ resolved[name] = VolumeCollection(vc_or_uri, ctx=ctx)
1054
+ else:
1055
+ resolved[name] = vc_or_uri
1056
+
1057
+ # Determine which collections need copying vs linking
1058
+ in_place: dict[str, VolumeCollection] = {}
1059
+ to_copy: dict[str, VolumeCollection] = {}
1060
+
1061
+ for name, vc in resolved.items():
1062
+ expected_uri = f"{collections_uri}/{name}"
1063
+ if vc.uri == expected_uri:
1064
+ in_place[name] = vc
1065
+ else:
1066
+ to_copy[name] = vc
1067
+
1068
+ # Check if collections group already exists (from materialize)
1069
+ vfs = tiledb.VFS(ctx=effective_ctx)
1070
+ collections_group_exists = vfs.is_dir(collections_uri)
1071
+
1072
+ # Create root group
1073
+ tiledb.Group.create(uri, ctx=effective_ctx)
1074
+
1075
+ # Create or use existing collections group
1076
+ if not collections_group_exists:
1077
+ tiledb.Group.create(collections_uri, ctx=effective_ctx)
1078
+
1079
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
1080
+ grp.add(collections_uri, name="collections")
1081
+
1082
+ # Link in-place collections (no copy needed)
1083
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
1084
+ for name, vc in in_place.items():
1085
+ grp.add(vc.uri, name=name)
1086
+
1087
+ # Copy external collections
1088
+ for name, vc in to_copy.items():
1089
+ new_uri = f"{collections_uri}/{name}"
1090
+ _copy_volume_collection(vc, new_uri, name=name, ctx=ctx)
1091
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
1092
+ grp.add(new_uri, name=name)
1093
+
1094
+ # Derive obs_meta if not provided
1095
+ if obs_meta is None:
1096
+ all_subject_ids: set[str] = set()
1097
+ for vc in resolved.values():
1098
+ obs_df = vc.obs.read()
1099
+ all_subject_ids.update(obs_df["obs_subject_id"].tolist())
1100
+ sorted_ids = sorted(all_subject_ids)
1101
+ obs_meta = pd.DataFrame(
1102
+ {
1103
+ "obs_subject_id": sorted_ids,
1104
+ "obs_id": sorted_ids,
1105
+ }
1106
+ )
1107
+
1108
+ # Build obs_meta schema
1109
+ n_subjects = len(obs_meta)
1110
+ obs_meta_schema: dict[str, np.dtype] = {}
1111
+ for col in obs_meta.columns:
1112
+ if col in ("obs_subject_id", "obs_id"):
1113
+ continue
1114
+ dtype = obs_meta[col].to_numpy().dtype
1115
+ if dtype == np.dtype("O"):
1116
+ dtype = np.dtype("U64")
1117
+ obs_meta_schema[col] = dtype
1118
+
1119
+ # Create obs_meta
1120
+ obs_meta_uri = f"{uri}/obs_meta"
1121
+ Dataframe.create(obs_meta_uri, schema=obs_meta_schema, ctx=ctx)
1122
+
1123
+ if len(obs_meta) > 0:
1124
+ obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
1125
+ obs_ids = (
1126
+ obs_meta["obs_id"].astype(str).to_numpy()
1127
+ if "obs_id" in obs_meta.columns
1128
+ else obs_subject_ids
1129
+ )
1130
+ with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
1131
+ attr_data = {
1132
+ col: obs_meta[col].to_numpy()
1133
+ for col in obs_meta.columns
1134
+ if col not in ("obs_subject_id", "obs_id")
1135
+ }
1136
+ arr[obs_subject_ids, obs_ids] = attr_data
1137
+
1138
+ # Link obs_meta to root and set metadata
1139
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
1140
+ grp.add(obs_meta_uri, name="obs_meta")
1141
+ grp.meta["n_collections"] = len(resolved)
1142
+ grp.meta["subject_count"] = n_subjects
1143
+
1144
+ return cls(uri, ctx=ctx)
1145
+
1146
+ @classmethod
1147
+ def from_niftis(
1148
+ cls,
1149
+ uri: str,
1150
+ niftis: Sequence[tuple[str | Path, str]] | None = None,
1151
+ image_dir: str | Path | None = None,
1152
+ collection_name: str | None = None,
1153
+ images: dict[str, str | Path | Sequence[tuple[str | Path, str]]] | None = None,
1154
+ validate_alignment: bool = False,
1155
+ obs_meta: pd.DataFrame | None = None,
1156
+ reorient: bool | None = None,
1157
+ ctx: tiledb.Ctx | None = None,
1158
+ progress: bool = False,
1159
+ ) -> RadiObject:
1160
+ """Create RadiObject from NIfTI files with raw data storage.
1161
+
1162
+ Ingestion stores volumes in their original dimensions without any
1163
+ preprocessing. Use `collection.lazy().map()` for post-hoc transformations.
1164
+
1165
+ Three input modes:
1166
+ 1. images: Dict mapping collection names to paths/globs/lists (recommended)
1167
+ 2. niftis: List of (path, subject_id) tuples (legacy)
1168
+ 3. image_dir: Directory-based discovery (legacy)
1169
+
1170
+ Collection organization:
1171
+ - With images dict: each key becomes a collection
1172
+ - With collection_name: all volumes go to that single collection
1173
+ - Otherwise: auto-group by inferred modality (T1w, FLAIR, CT, etc.)
1174
+
1175
+ Args:
1176
+ uri: Target URI for RadiObject
1177
+ images: Dict mapping collection names to NIfTI sources. Sources can be:
1178
+ - Glob pattern: "./imagesTr/*.nii.gz"
1179
+ - Directory path: "./imagesTr"
1180
+ - Pre-resolved list: [(path, subject_id), ...]
1181
+ niftis: List of (nifti_path, obs_subject_id) tuples (legacy)
1182
+ image_dir: Directory containing image NIfTIs (legacy, mutually exclusive with niftis)
1183
+ collection_name: Explicit name for collection (legacy, all volumes go here)
1184
+ validate_alignment: If True, verify all collections have same subject IDs
1185
+ obs_meta: Subject-level metadata. Must contain obs_subject_id column.
1186
+ reorient: Reorient to canonical orientation (None uses config default)
1187
+ ctx: TileDB context
1188
+ progress: Show tqdm progress bar
1189
+
1190
+ Example (images dict with globs):
1191
+ radi = RadiObject.from_niftis(
1192
+ uri="./dataset",
1193
+ images={
1194
+ "CT": "./imagesTr/*.nii.gz",
1195
+ "seg": "./labelsTr/*.nii.gz",
1196
+ },
1197
+ )
1198
+
1199
+ Example (images dict with directories):
1200
+ radi = RadiObject.from_niftis(
1201
+ uri="./dataset",
1202
+ images={"CT": "./imagesTr", "seg": "./labelsTr"},
1203
+ )
1204
+
1205
+ Example (legacy explicit collection name):
1206
+ radi = RadiObject.from_niftis(
1207
+ uri="s3://bucket/raw",
1208
+ image_dir="./imagesTr",
1209
+ collection_name="lung_ct",
1210
+ )
1211
+
1212
+ Example (legacy auto-group by modality):
1213
+ radi = RadiObject.from_niftis(
1214
+ uri="s3://bucket/raw",
1215
+ niftis=[
1216
+ ("sub01_T1w.nii.gz", "sub-01"),
1217
+ ("sub01_FLAIR.nii.gz", "sub-01"),
1218
+ ],
1219
+ )
1220
+ # Result: radi.T1w, radi.FLAIR collections
1221
+ """
1222
+ from radiobject.ingest import resolve_nifti_source
1223
+
1224
+ # --- NORMALIZE ALL INPUTS TO images DICT ---
1225
+
1226
+ if images is not None:
1227
+ if niftis is not None or image_dir is not None or collection_name is not None:
1228
+ raise ValueError("Cannot use 'images' with legacy parameters")
1229
+ if not images:
1230
+ raise ValueError("images dict cannot be empty")
1231
+
1232
+ elif image_dir is not None:
1233
+ if niftis is not None:
1234
+ raise ValueError("Cannot specify both 'niftis' and 'image_dir'")
1235
+ from radiobject.ingest import discover_nifti_pairs
1236
+
1237
+ sources = discover_nifti_pairs(image_dir)
1238
+ niftis = [(s.image_path, s.subject_id) for s in sources]
1239
+
1240
+ if collection_name:
1241
+ images = {collection_name: niftis}
1242
+ else:
1243
+ modality_groups: dict[str, list[tuple[str | Path, str]]] = defaultdict(list)
1244
+ for path, sid in niftis:
1245
+ series_type = infer_series_type(Path(path))
1246
+ modality_groups[series_type].append((path, sid))
1247
+ images = dict(modality_groups)
1248
+
1249
+ elif niftis is not None:
1250
+ if collection_name:
1251
+ images = {collection_name: niftis}
1252
+ else:
1253
+ modality_groups: dict[str, list[tuple[str | Path, str]]] = defaultdict(list)
1254
+ for path, sid in niftis:
1255
+ series_type = infer_series_type(Path(path))
1256
+ modality_groups[series_type].append((path, sid))
1257
+ images = dict(modality_groups)
1258
+ else:
1259
+ raise ValueError("Must specify 'images', 'niftis', or 'image_dir'")
1260
+
1261
+ # --- SINGLE CODE PATH: Resolve images dict ---
1262
+
1263
+ groups: dict[str, list[tuple[Path, str]]] = {}
1264
+ for coll_name, source in images.items():
1265
+ groups[coll_name] = resolve_nifti_source(source)
1266
+
1267
+ # Optional alignment validation
1268
+ if validate_alignment and len(groups) > 1:
1269
+ subject_sets = {
1270
+ name: {sid for _, sid in nifti_list} for name, nifti_list in groups.items()
1271
+ }
1272
+ first_name, first_set = next(iter(subject_sets.items()))
1273
+ for name, sid_set in subject_sets.items():
1274
+ if sid_set != first_set:
1275
+ missing_in_first = sid_set - first_set
1276
+ missing_in_other = first_set - sid_set
1277
+ raise ValueError(
1278
+ f"Subject ID mismatch between '{first_name}' and '{name}': "
1279
+ f"missing in '{first_name}': {sorted(missing_in_first)[:3]}, "
1280
+ f"missing in '{name}': {sorted(missing_in_other)[:3]}"
1281
+ )
1282
+
1283
+ # Validate all files exist
1284
+ for coll_name, nifti_list in groups.items():
1285
+ for path, _ in nifti_list:
1286
+ if not path.exists():
1287
+ raise FileNotFoundError(f"NIfTI file not found: {path}")
1288
+
1289
+ # Collect all subject IDs
1290
+ all_subject_ids: set[str] = set()
1291
+ for nifti_list in groups.values():
1292
+ all_subject_ids.update(sid for _, sid in nifti_list)
1293
+
1294
+ # Validate FK constraint if obs_meta provided
1295
+ if obs_meta is not None:
1296
+ if "obs_subject_id" not in obs_meta.columns:
1297
+ raise ValueError("obs_meta must contain 'obs_subject_id' column")
1298
+ obs_meta_subject_ids = set(obs_meta["obs_subject_id"])
1299
+ missing = all_subject_ids - obs_meta_subject_ids
1300
+ if missing:
1301
+ raise ValueError(
1302
+ f"obs_subject_ids in niftis not found in obs_meta: {sorted(missing)[:5]}"
1303
+ )
1304
+ else:
1305
+ sorted_ids = sorted(all_subject_ids)
1306
+ obs_meta = pd.DataFrame(
1307
+ {
1308
+ "obs_subject_id": sorted_ids,
1309
+ "obs_id": sorted_ids,
1310
+ }
1311
+ )
1312
+
1313
+ if not groups or all(len(nifti_list) == 0 for nifti_list in groups.values()):
1314
+ raise ValueError("No NIfTI files found")
1315
+
1316
+ effective_ctx = ctx if ctx else global_ctx()
1317
+
1318
+ tiledb.Group.create(uri, ctx=effective_ctx)
1319
+ collections_uri = f"{uri}/collections"
1320
+ tiledb.Group.create(collections_uri, ctx=effective_ctx)
1321
+
1322
+ collections: dict[str, VolumeCollection] = {}
1323
+
1324
+ groups_iter = list(groups.items())
1325
+ if progress:
1326
+ from tqdm.auto import tqdm
1327
+
1328
+ groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")
1329
+
1330
+ for coll_name, items in groups_iter:
1331
+ vc_uri = f"{collections_uri}/{coll_name}"
1332
+ nifti_list = [(path, subject_id) for path, subject_id in items]
1333
+
1334
+ vc = VolumeCollection.from_niftis(
1335
+ uri=vc_uri,
1336
+ niftis=nifti_list,
1337
+ reorient=reorient,
1338
+ validate_dimensions=False,
1339
+ name=coll_name,
1340
+ ctx=ctx,
1341
+ progress=progress,
1342
+ )
1343
+ collections[coll_name] = vc
1344
+
1345
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
1346
+ grp.add(vc_uri, name=coll_name)
1347
+
1348
+ n_subjects = len(obs_meta)
1349
+ obs_meta_schema: dict[str, np.dtype] = {}
1350
+ for col in obs_meta.columns:
1351
+ if col in ("obs_subject_id", "obs_id"):
1352
+ continue
1353
+ dtype = obs_meta[col].to_numpy().dtype
1354
+ if dtype == np.dtype("O"):
1355
+ dtype = np.dtype("U64")
1356
+ obs_meta_schema[col] = dtype
1357
+
1358
+ obs_meta_uri = f"{uri}/obs_meta"
1359
+ Dataframe.create(obs_meta_uri, schema=obs_meta_schema, ctx=ctx)
1360
+
1361
+ if len(obs_meta) > 0:
1362
+ obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
1363
+ obs_ids = (
1364
+ obs_meta["obs_id"].astype(str).to_numpy()
1365
+ if "obs_id" in obs_meta.columns
1366
+ else obs_subject_ids
1367
+ )
1368
+ with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
1369
+ attr_data = {}
1370
+ for col in obs_meta.columns:
1371
+ if col not in ("obs_subject_id", "obs_id"):
1372
+ attr_data[col] = obs_meta[col].to_numpy()
1373
+ arr[obs_subject_ids, obs_ids] = attr_data
1374
+
1375
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
1376
+ grp.meta["n_collections"] = len(collections)
1377
+ grp.meta["subject_count"] = n_subjects
1378
+ grp.add(obs_meta_uri, name="obs_meta")
1379
+ grp.add(collections_uri, name="collections")
1380
+
1381
+ return cls(uri, ctx=ctx)
1382
+
1383
+ @classmethod
1384
+ def from_dicoms(
1385
+ cls,
1386
+ uri: str,
1387
+ dicom_dirs: Sequence[tuple[str | Path, str]],
1388
+ obs_meta: pd.DataFrame | None = None,
1389
+ reorient: bool | None = None,
1390
+ ctx: tiledb.Ctx | None = None,
1391
+ progress: bool = False,
1392
+ ) -> RadiObject:
1393
+ """Create RadiObject from DICOM series with automatic grouping.
1394
+
1395
+ Files are automatically grouped into VolumeCollections by:
1396
+ 1. Dimensions (rows, columns, n_slices)
1397
+ 2. Modality tag (CT, MR) + SeriesDescription
1398
+
1399
+ Args:
1400
+ uri: Target URI for RadiObject
1401
+ dicom_dirs: List of (dicom_dir, obs_subject_id) tuples
1402
+ obs_meta: Subject-level metadata (user-provided). Must contain obs_subject_id column.
1403
+ reorient: Reorient to canonical orientation (None uses config default)
1404
+ ctx: TileDB context
1405
+ progress: Show tqdm progress bar during volume writes
1406
+
1407
+ Example:
1408
+ radi = RadiObject.from_dicoms(
1409
+ uri="/storage/ct_study",
1410
+ dicom_dirs=[
1411
+ ("/dicom/sub01/CT_HEAD", "sub-01"),
1412
+ ("/dicom/sub01/CT_CHEST", "sub-01"),
1413
+ ("/dicom/sub02/CT_HEAD", "sub-02"),
1414
+ ],
1415
+ obs_meta=obs_meta_df,
1416
+ )
1417
+ """
1418
+ if not dicom_dirs:
1419
+ raise ValueError("At least one DICOM directory is required")
1420
+
1421
+ all_subject_ids = {sid for _, sid in dicom_dirs}
1422
+
1423
+ if obs_meta is not None:
1424
+ if "obs_subject_id" not in obs_meta.columns:
1425
+ raise ValueError("obs_meta must contain 'obs_subject_id' column")
1426
+ obs_meta_subject_ids = set(obs_meta["obs_subject_id"])
1427
+ missing = all_subject_ids - obs_meta_subject_ids
1428
+ if missing:
1429
+ raise ValueError(
1430
+ f"obs_subject_ids in dicom_dirs not found in obs_meta: {sorted(missing)[:5]}"
1431
+ )
1432
+ else:
1433
+ sorted_ids = sorted(all_subject_ids)
1434
+ obs_meta = pd.DataFrame(
1435
+ {
1436
+ "obs_subject_id": sorted_ids,
1437
+ "obs_id": sorted_ids,
1438
+ }
1439
+ )
1440
+
1441
+ file_info: list[tuple[Path, str, tuple[int, int, int], str]] = []
1442
+ for dicom_dir, obs_subject_id in dicom_dirs:
1443
+ path = Path(dicom_dir)
1444
+ if not path.exists():
1445
+ raise FileNotFoundError(f"DICOM directory not found: {path}")
1446
+
1447
+ metadata = extract_dicom_metadata(path)
1448
+ dims = metadata.dimensions
1449
+ # metadata.dimensions is (rows, columns, n_slices); reorder to the
+ # (cols, rows, slices) shape used for grouping
+ shape = (dims[1], dims[0], dims[2])
1450
+ group_key = metadata.modality
1451
+ file_info.append((path, obs_subject_id, shape, group_key))
1452
+
1453
+ groups: dict[tuple[tuple[int, int, int], str], list[tuple[Path, str]]] = defaultdict(list)
1454
+ for path, subject_id, shape, group_key in file_info:
1455
+ key = (shape, group_key)
1456
+ groups[key].append((path, subject_id))
1457
+
1458
+ effective_ctx = ctx if ctx else global_ctx()
1459
+
1460
+ tiledb.Group.create(uri, ctx=effective_ctx)
1461
+ collections_uri = f"{uri}/collections"
1462
+ tiledb.Group.create(collections_uri, ctx=effective_ctx)
1463
+
1464
+ collections: dict[str, VolumeCollection] = {}
1465
+ used_names: set[str] = set()
1466
+
1467
+ groups_iter = groups.items()
1468
+ if progress:
1469
+ from tqdm.auto import tqdm
1470
+
1471
+ groups_iter = tqdm(groups_iter, desc="Collections", unit="coll")
1472
+
1473
+ for (shape, modality), items in groups_iter:
1474
+ coll_name = modality
1475
+ if coll_name in used_names:
1476
+ coll_name = f"{modality}_{shape[0]}x{shape[1]}x{shape[2]}"
1477
+ used_names.add(coll_name)
1478
+
1479
+ vc_uri = f"{collections_uri}/{coll_name}"
1480
+ dicom_list = [(path, subject_id) for path, subject_id in items]
1481
+
1482
+ vc = VolumeCollection.from_dicoms(
1483
+ uri=vc_uri,
1484
+ dicom_dirs=dicom_list,
1485
+ reorient=reorient,
1486
+ validate_dimensions=True,
1487
+ name=coll_name,
1488
+ ctx=ctx,
1489
+ progress=progress,
1490
+ )
1491
+ collections[coll_name] = vc
1492
+
1493
+ with tiledb.Group(collections_uri, "w", ctx=effective_ctx) as grp:
1494
+ grp.add(vc_uri, name=coll_name)
1495
+
1496
+ n_subjects = len(obs_meta)
1497
+ obs_meta_schema: dict[str, np.dtype] = {}
1498
+ for col in obs_meta.columns:
1499
+ if col in ("obs_subject_id", "obs_id"):
1500
+ continue
1501
+ dtype = obs_meta[col].to_numpy().dtype
1502
+ if dtype == np.dtype("O"):
1503
+ dtype = np.dtype("U64")
1504
+ obs_meta_schema[col] = dtype
1505
+
1506
+ obs_meta_uri = f"{uri}/obs_meta"
1507
+ Dataframe.create(obs_meta_uri, schema=obs_meta_schema, ctx=ctx)
1508
+
1509
+ if len(obs_meta) > 0:
1510
+ obs_subject_ids = obs_meta["obs_subject_id"].astype(str).to_numpy()
1511
+ obs_ids = (
1512
+ obs_meta["obs_id"].astype(str).to_numpy()
1513
+ if "obs_id" in obs_meta.columns
1514
+ else obs_subject_ids
1515
+ )
1516
+ with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
1517
+ attr_data = {}
1518
+ for col in obs_meta.columns:
1519
+ if col not in ("obs_subject_id", "obs_id"):
1520
+ attr_data[col] = obs_meta[col].to_numpy()
1521
+ arr[obs_subject_ids, obs_ids] = attr_data
1522
+
1523
+ with tiledb.Group(uri, "w", ctx=effective_ctx) as grp:
1524
+ grp.meta["n_collections"] = len(collections)
1525
+ grp.meta["subject_count"] = n_subjects
1526
+ grp.add(obs_meta_uri, name="obs_meta")
1527
+ grp.add(collections_uri, name="collections")
1528
+
1529
+ return cls(uri, ctx=ctx)
1530
+
1531
+
1532
+ # ===== Helper Functions =====
1533
+
1534
+
1535
+ def _extract_obs_schema(obs: Dataframe) -> dict[str, np.dtype]:
1536
+ """Extract schema from an obs Dataframe (excluding obs_id and obs_subject_id)."""
1537
+ schema = {}
1538
+ for col in obs.columns:
1539
+ if col in ("obs_id", "obs_subject_id"):
1540
+ continue
1541
+ schema[col] = obs.dtypes[col]
1542
+ return schema
1543
+
1544
+
1545
+ def _copy_volume_collection(
1546
+ src: VolumeCollection,
1547
+ dst_uri: str,
1548
+ name: str | None = None,
1549
+ ctx: tiledb.Ctx | None = None,
1550
+ ) -> None:
1551
+ """Copy a VolumeCollection to a new URI."""
1552
+ effective_ctx = ctx if ctx else global_ctx()
1553
+
1554
+ collection_name = name if name is not None else src.name
1555
+
1556
+ VolumeCollection._create(
1557
+ dst_uri,
1558
+ shape=src.shape,
1559
+ obs_schema=_extract_obs_schema(src.obs),
1560
+ n_volumes=len(src),
1561
+ name=collection_name,
1562
+ ctx=ctx,
1563
+ )
1564
+
1565
+ obs_df = src.obs.read()
1566
+ obs_uri = f"{dst_uri}/obs"
1567
+ obs_subject_ids = obs_df["obs_subject_id"].astype(str).to_numpy()
1568
+ obs_ids = obs_df["obs_id"].astype(str).to_numpy()
1569
+ with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
1570
+ attr_data = {
1571
+ col: obs_df[col].to_numpy()
1572
+ for col in obs_df.columns
1573
+ if col not in ("obs_subject_id", "obs_id")
1574
+ }
1575
+ arr[obs_subject_ids, obs_ids] = attr_data
1576
+
1577
+ def write_volume(args: tuple[int, str, Volume]) -> WriteResult:
1578
+ idx, obs_id, vol = args
1579
+ worker_ctx = create_worker_ctx(ctx)
1580
+ new_vol_uri = f"{dst_uri}/volumes/{idx}"
1581
+ try:
1582
+ data = vol.to_numpy()
1583
+ new_vol = Volume.from_numpy(new_vol_uri, data, ctx=worker_ctx)
1584
+ new_vol.set_obs_id(obs_id)
1585
+ return WriteResult(idx, new_vol_uri, obs_id, success=True)
1586
+ except Exception as e:
1587
+ return WriteResult(idx, new_vol_uri, obs_id, success=False, error=e)
1588
+
1589
+ write_args = [(idx, obs_id, src.iloc[idx]) for idx, obs_id in enumerate(src.obs_ids)]
1590
+ results = _write_volumes_parallel(
1591
+ write_volume, write_args, progress=False, desc="Copying volumes"
1592
+ )
1593
+
1594
+ with tiledb.Group(f"{dst_uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
1595
+ for result in results:
1596
+ vol_grp.add(result.uri, name=str(result.index))
1597
+
1598
+
1599
+ def _copy_filtered_volume_collection(
1600
+ src: VolumeCollection,
1601
+ dst_uri: str,
1602
+ obs_subject_ids: list[str],
1603
+ name: str | None = None,
1604
+ ctx: tiledb.Ctx | None = None,
1605
+ ) -> None:
1606
+ """Copy a VolumeCollection, filtering to volumes matching obs_subject_ids."""
1607
+ effective_ctx = ctx if ctx else global_ctx()
1608
+
1609
+ collection_name = name if name is not None else src.name
1610
+
1611
+ obs_df = src.obs.read()
1612
+ subject_id_set = set(obs_subject_ids)
1613
+
1614
+ filtered_obs = obs_df[obs_df["obs_subject_id"].isin(subject_id_set)].reset_index(drop=True)
1615
+
1616
+ if len(filtered_obs) == 0:
1617
+ raise ValueError("No volumes match the specified obs_subject_ids")
1618
+
1619
+ VolumeCollection._create(
1620
+ dst_uri,
1621
+ shape=src.shape,
1622
+ obs_schema=_extract_obs_schema(src.obs),
1623
+ n_volumes=len(filtered_obs),
1624
+ name=collection_name,
1625
+ ctx=ctx,
1626
+ )
1627
+
1628
+ obs_uri = f"{dst_uri}/obs"
1629
+ obs_subject_ids_arr = filtered_obs["obs_subject_id"].astype(str).to_numpy()
1630
+ obs_ids_arr = filtered_obs["obs_id"].astype(str).to_numpy()
1631
+ with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
1632
+ attr_data = {
1633
+ col: filtered_obs[col].to_numpy()
1634
+ for col in filtered_obs.columns
1635
+ if col not in ("obs_subject_id", "obs_id")
1636
+ }
1637
+ arr[obs_subject_ids_arr, obs_ids_arr] = attr_data
1638
+
1639
+ selected_obs_ids = set(filtered_obs["obs_id"])
1640
+ selected_indices = [i for i, oid in enumerate(src.obs_ids) if oid in selected_obs_ids]
1641
+
1642
+ def write_volume(args: tuple[int, int, str]) -> WriteResult:
1643
+ new_idx, orig_idx, obs_id = args
1644
+ worker_ctx = create_worker_ctx(ctx)
1645
+ new_vol_uri = f"{dst_uri}/volumes/{new_idx}"
1646
+ try:
1647
+ vol = src.iloc[orig_idx]
1648
+ data = vol.to_numpy()
1649
+ new_vol = Volume.from_numpy(new_vol_uri, data, ctx=worker_ctx)
1650
+ new_vol.set_obs_id(obs_id)
1651
+ return WriteResult(new_idx, new_vol_uri, obs_id, success=True)
1652
+ except Exception as e:
1653
+ return WriteResult(new_idx, new_vol_uri, obs_id, success=False, error=e)
1654
+
1655
+ write_args = [
1656
+ (new_idx, orig_idx, src.obs_ids[orig_idx])
1657
+ for new_idx, orig_idx in enumerate(selected_indices)
1658
+ ]
1659
+ results = _write_volumes_parallel(
1660
+ write_volume, write_args, progress=False, desc="Filtering volumes"
1661
+ )
1662
+
1663
+ with tiledb.Group(f"{dst_uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
1664
+ for result in results:
1665
+ vol_grp.add(result.uri, name=str(result.index))