radiobject 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,389 @@
1
+ """Streaming writers for memory-efficient RadiObject and VolumeCollection creation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import pandas as pd
10
+ import tiledb
11
+
12
+ from radiobject._types import AttrValue
13
+ from radiobject.ctx import ctx as global_ctx
14
+ from radiobject.dataframe import Dataframe
15
+ from radiobject.volume import Volume
16
+
17
+ if TYPE_CHECKING:
18
+ from radiobject.radi_object import RadiObject
19
+
20
+
21
class StreamingWriter:
    """Write volumes incrementally without full memory load.

    Context manager that creates a VolumeCollection and writes volumes
    one at a time, keeping memory usage bounded.

    Entering the context creates the TileDB groups ``<uri>`` and
    ``<uri>/volumes`` plus the obs dataframe at ``<uri>/obs``. Each
    ``write_volume`` call stores one volume at ``<uri>/volumes/<index>`` and
    buffers its obs row in ``_obs_rows``. Leaving the context without an
    exception writes the buffered obs rows and the final ``n_volumes`` value.

    Example:
        # Uniform shape collection
        with StreamingWriter(uri, shape=(256, 256, 128)) as writer:
            for nifti_path, subject_id in niftis:
                data = load_nifti(nifti_path)
                writer.write_volume(data, obs_id=f"{subject_id}_T1w", obs_subject_id=subject_id)

        # Heterogeneous shape collection (raw ingestion)
        with StreamingWriter(uri) as writer:
            for nifti_path, subject_id in niftis:
                data = load_nifti(nifti_path)
                writer.write_volume(data, obs_id=f"{subject_id}_T1w", obs_subject_id=subject_id)
    """

    def __init__(
        self,
        uri: str,
        shape: tuple[int, int, int] | None = None,
        obs_schema: dict[str, np.dtype] | None = None,
        name: str | None = None,
        ctx: tiledb.Ctx | None = None,
    ):
        """Store the configuration; nothing is created on disk until ``__enter__``.

        Args:
            uri: Root URI of the collection group to create.
            shape: Required (X, Y, Z) dimensions of every volume, or None to
                allow heterogeneous shapes. Any 3-element sequence is accepted
                and stored as a tuple of ints.
            obs_schema: Schema passed to ``Dataframe.create`` for the obs
                dataframe; defaults to an empty dict.
            name: Optional collection name stored in the group metadata.
            ctx: TileDB context; when None the global context is used.

        Raises:
            ValueError: If ``shape`` is given but does not have 3 elements.
        """
        if shape is not None:
            # Normalise to a tuple so the comparison with ``data.shape[:3]``
            # (always a tuple) also works when a list is passed.
            shape = tuple(int(dim) for dim in shape)
            if len(shape) != 3:
                # Fail here rather than in __enter__, which indexes shape[0..2]
                # only after the groups have already been created.
                raise ValueError(f"shape must have exactly 3 dimensions, got {shape}")

        self.uri = uri
        self.shape = shape  # None = heterogeneous shapes allowed
        self.obs_schema = obs_schema or {}
        self.name = name
        self._ctx = ctx
        self._volume_count = 0
        self._obs_rows: list[dict[str, AttrValue]] = []
        self._initialized = False
        self._finalized = False

    def _effective_ctx(self) -> tiledb.Ctx:
        """Return the explicitly supplied context, else the global one."""
        # Compare against None explicitly: a context object's truthiness is
        # not part of its contract.
        return self._ctx if self._ctx is not None else global_ctx()

    def __enter__(self) -> StreamingWriter:
        """Initialize the VolumeCollection structure.

        Raises:
            RuntimeError: If this writer has already been entered.
        """
        if self._initialized:
            raise RuntimeError("StreamingWriter already initialized")

        effective_ctx = self._effective_ctx()

        # Create group structure
        tiledb.Group.create(self.uri, ctx=effective_ctx)

        volumes_uri = f"{self.uri}/volumes"
        tiledb.Group.create(volumes_uri, ctx=effective_ctx)

        obs_uri = f"{self.uri}/obs"
        Dataframe.create(obs_uri, schema=self.obs_schema, ctx=self._ctx)

        # Set initial metadata (n_volumes=0, will update on finalize)
        with tiledb.Group(self.uri, "w", ctx=effective_ctx) as grp:
            if self.shape is not None:
                grp.meta["x_dim"] = self.shape[0]
                grp.meta["y_dim"] = self.shape[1]
                grp.meta["z_dim"] = self.shape[2]
            grp.meta["n_volumes"] = 0
            if self.name is not None:
                grp.meta["name"] = self.name
            grp.add(volumes_uri, name="volumes")
            grp.add(obs_uri, name="obs")

        self._initialized = True
        return self

    def write_volume(
        self,
        data: npt.NDArray[np.floating],
        obs_id: str,
        obs_subject_id: str,
        **attrs: AttrValue,
    ) -> None:
        """Write a single volume to the collection.

        Args:
            data: Volume data array (must match shape if uniform collection)
            obs_id: Unique identifier for this volume
            obs_subject_id: Subject identifier (foreign key)
            **attrs: Additional obs attributes matching obs_schema

        Raises:
            RuntimeError: If the writer was not entered or is already finalized.
            ValueError: If a collection shape is set and the first three
                dimensions of ``data`` differ from it.
        """
        if not self._initialized:
            raise RuntimeError("StreamingWriter not initialized. Use as context manager.")
        if self._finalized:
            raise RuntimeError("StreamingWriter already finalized")

        # Only validate shape if collection requires uniform dimensions
        if self.shape is not None and tuple(data.shape[:3]) != self.shape:
            raise ValueError(
                f"Volume shape {data.shape[:3]} doesn't match collection shape {self.shape}"
            )

        effective_ctx = self._effective_ctx()
        idx = self._volume_count
        volume_uri = f"{self.uri}/volumes/{idx}"

        # Write volume array
        vol = Volume.from_numpy(volume_uri, data, ctx=self._ctx)
        vol.set_obs_id(obs_id)

        # Register with volumes group
        with tiledb.Group(f"{self.uri}/volumes", "w", ctx=effective_ctx) as vol_grp:
            vol_grp.add(volume_uri, name=str(idx))

        # Collect obs row data; it is written in one go by __exit__
        row = {"obs_id": obs_id, "obs_subject_id": obs_subject_id, **attrs}
        self._obs_rows.append(row)

        # Incremented last so a failed write does not consume an index
        self._volume_count += 1

    def write_batch(
        self, volumes: list[tuple[npt.NDArray[np.floating], str, str, dict[str, AttrValue]]]
    ) -> None:
        """Write multiple volumes at once.

        Volumes are written one by one through ``write_volume``; an error
        stops the batch and leaves the earlier volumes written.

        Args:
            volumes: List of (data, obs_id, obs_subject_id, attrs) tuples
        """
        for data, obs_id, obs_subject_id, attrs in volumes:
            self.write_volume(data, obs_id, obs_subject_id, **attrs)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Finalize collection and write obs metadata.

        Does nothing if the writer is already finalized or if the ``with``
        body raised; in the latter case the exception propagates and the
        buffered obs rows are not written.
        """
        if self._finalized:
            return

        if exc_type is not None:
            # Exception occurred, don't finalize
            return

        effective_ctx = self._effective_ctx()

        # Write obs data
        if self._obs_rows:
            obs_uri = f"{self.uri}/obs"
            obs_df = pd.DataFrame(self._obs_rows)

            # The two id columns are used as coordinates of the obs array
            obs_subject_ids = obs_df["obs_subject_id"].astype(str).to_numpy()
            obs_ids = obs_df["obs_id"].astype(str).to_numpy()

            with tiledb.open(obs_uri, "w", ctx=effective_ctx) as arr:
                attr_data = {
                    col: obs_df[col].to_numpy()
                    for col in obs_df.columns
                    if col not in ("obs_subject_id", "obs_id")
                }
                arr[obs_subject_ids, obs_ids] = attr_data

        # Update final volume count
        with tiledb.Group(self.uri, "w", ctx=effective_ctx) as grp:
            grp.meta["n_volumes"] = self._volume_count

        self._finalized = True

    @property
    def n_written(self) -> int:
        """Number of volumes written so far."""
        return self._volume_count
185
+
186
+
187
class RadiObjectWriter:
    """Build RadiObject incrementally from multiple collections.

    Context manager that creates a RadiObject structure and allows
    adding collections one at a time via StreamingWriter instances.

    Entering the context creates the root group, the ``obs_meta`` dataframe
    and the ``collections`` group. Leaving it without an exception stores the
    final ``subject_count`` and ``n_collections`` in the root group metadata.

    Example:
        with RadiObjectWriter(uri) as writer:
            writer.write_obs_meta(obs_meta_df)

            with writer.add_collection("T1w", shape=(256, 256, 128)) as t1w_writer:
                for path, subj_id in t1w_files:
                    t1w_writer.write_volume(load_nifti(path), f"{subj_id}_T1w", subj_id)

            with writer.add_collection("FLAIR", shape=(256, 256, 128)) as flair_writer:
                for path, subj_id in flair_files:
                    flair_writer.write_volume(load_nifti(path), f"{subj_id}_FLAIR", subj_id)
    """

    def __init__(
        self,
        uri: str,
        obs_meta_schema: dict[str, np.dtype] | None = None,
        ctx: tiledb.Ctx | None = None,
    ):
        """Store the configuration; nothing is created on disk until ``__enter__``.

        Args:
            uri: Root URI of the RadiObject group to create.
            obs_meta_schema: Schema passed to ``Dataframe.create`` for the
                obs_meta dataframe; defaults to an empty dict.
            ctx: TileDB context; when None the global context is used.
        """
        self.uri = uri
        self.obs_meta_schema = obs_meta_schema or {}
        self._ctx = ctx
        self._collection_names: list[str] = []
        self._subject_count = 0
        self._initialized = False
        self._finalized = False
        self._all_obs_ids: set[str] = set()  # Track obs_ids across all collections

    def _effective_ctx(self) -> tiledb.Ctx:
        """Return the explicitly supplied context, else the global one."""
        # Compare against None explicitly: a context object's truthiness is
        # not part of its contract.
        return self._ctx if self._ctx is not None else global_ctx()

    def __enter__(self) -> RadiObjectWriter:
        """Initialize the RadiObject structure.

        Raises:
            RuntimeError: If this writer has already been entered.
        """
        if self._initialized:
            raise RuntimeError("RadiObjectWriter already initialized")

        effective_ctx = self._effective_ctx()

        # Create main group
        tiledb.Group.create(self.uri, ctx=effective_ctx)

        # Create obs_meta dataframe
        obs_meta_uri = f"{self.uri}/obs_meta"
        Dataframe.create(obs_meta_uri, schema=self.obs_meta_schema, ctx=self._ctx)

        # Create collections group
        collections_uri = f"{self.uri}/collections"
        tiledb.Group.create(collections_uri, ctx=effective_ctx)

        # Set initial metadata
        with tiledb.Group(self.uri, "w", ctx=effective_ctx) as grp:
            grp.meta["subject_count"] = 0
            grp.meta["n_collections"] = 0
            grp.add(obs_meta_uri, name="obs_meta")
            grp.add(collections_uri, name="collections")

        self._initialized = True
        return self

    def write_obs_meta(self, df: pd.DataFrame) -> None:
        """Write subject-level metadata.

        When ``df`` has no ``obs_id`` column, the ``obs_subject_id`` values
        are used for both coordinates of the obs_meta array.

        Args:
            df: DataFrame with obs_subject_id column and optional obs_id column

        Raises:
            RuntimeError: If the writer was not entered or is already finalized.
            ValueError: If ``df`` has no ``obs_subject_id`` column.
        """
        if not self._initialized:
            raise RuntimeError("RadiObjectWriter not initialized. Use as context manager.")
        if self._finalized:
            # Same guard as add_collection: the stored subject_count would
            # otherwise go stale after finalization.
            raise RuntimeError("RadiObjectWriter already finalized")
        if "obs_subject_id" not in df.columns:
            raise ValueError("DataFrame must contain 'obs_subject_id' column")

        effective_ctx = self._effective_ctx()
        obs_meta_uri = f"{self.uri}/obs_meta"

        obs_subject_ids = df["obs_subject_id"].astype(str).to_numpy()
        obs_ids = df["obs_id"].astype(str).to_numpy() if "obs_id" in df.columns else obs_subject_ids

        with tiledb.open(obs_meta_uri, "w", ctx=effective_ctx) as arr:
            attr_data = {
                col: df[col].to_numpy()
                for col in df.columns
                if col not in ("obs_subject_id", "obs_id")
            }
            arr[obs_subject_ids, obs_ids] = attr_data

        # NOTE(review): this is the row count of the most recent call, not a
        # running total or a count of distinct subjects - confirm intended.
        self._subject_count = len(df)

    def add_collection(
        self,
        name: str,
        shape: tuple[int, int, int] | None = None,
        obs_schema: dict[str, np.dtype] | None = None,
    ) -> StreamingWriter:
        """Add a new collection and return a StreamingWriter for it.

        The collection is created at ``<uri>/collections/<name>`` and is
        registered with this writer when the returned writer's context exits
        cleanly.

        Args:
            name: Collection name (e.g., "T1w", "FLAIR")
            shape: Volume dimensions (X, Y, Z). None for heterogeneous shapes.
            obs_schema: Schema for volume-level obs attributes

        Returns:
            StreamingWriter context manager for writing volumes

        Raises:
            RuntimeError: If the writer was not entered or is already finalized.
            ValueError: If a collection with this name was already registered.
        """
        if not self._initialized:
            raise RuntimeError("RadiObjectWriter not initialized. Use as context manager.")
        if self._finalized:
            raise RuntimeError("RadiObjectWriter already finalized")
        if name in self._collection_names:
            raise ValueError(f"Collection '{name}' already added")

        collection_uri = f"{self.uri}/collections/{name}"
        return _CollectionStreamingWriter(
            uri=collection_uri,
            shape=shape,
            obs_schema=obs_schema,
            name=name,
            ctx=self._ctx,
            parent=self,
        )

    def _register_collection(self, name: str, uri: str) -> None:
        """Internal: register a completed collection."""
        effective_ctx = self._effective_ctx()

        with tiledb.Group(f"{self.uri}/collections", "w", ctx=effective_ctx) as grp:
            grp.add(uri, name=name)

        self._collection_names.append(name)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Finalize the RadiObject by updating metadata.

        Does nothing if already finalized or if the ``with`` body raised.
        """
        if self._finalized:
            return

        if exc_type is not None:
            return

        effective_ctx = self._effective_ctx()

        with tiledb.Group(self.uri, "w", ctx=effective_ctx) as grp:
            grp.meta["subject_count"] = self._subject_count
            grp.meta["n_collections"] = len(self._collection_names)

        self._finalized = True

    def finalize(self) -> RadiObject:
        """Finalize and return the created RadiObject.

        Raises:
            RuntimeError: If the writer was never entered, since there is no
                group on disk to finalize.
        """
        if not self._initialized:
            raise RuntimeError("RadiObjectWriter not initialized. Use as context manager.")

        # Imported here; at module level it is only imported under TYPE_CHECKING.
        from radiobject.radi_object import RadiObject

        if not self._finalized:
            self.__exit__(None, None, None)

        return RadiObject(self.uri, ctx=self._ctx)

    @property
    def collection_names(self) -> list[str]:
        """Names of collections added so far."""
        # Return a copy so callers cannot mutate the internal list
        return list(self._collection_names)
351
+
352
+
353
class _CollectionStreamingWriter(StreamingWriter):
    """StreamingWriter that registers with parent RadiObjectWriter on completion."""

    def __init__(
        self,
        uri: str,
        shape: tuple[int, int, int] | None,
        obs_schema: dict[str, np.dtype] | None,
        name: str,
        ctx: tiledb.Ctx | None,
        parent: RadiObjectWriter,
    ):
        """Configure the writer and remember the parent to register with.

        Args:
            uri: URI of the collection group to create.
            shape: Required volume dimensions, or None for heterogeneous shapes.
            obs_schema: Schema for the collection's obs dataframe.
            name: Collection name, also used when registering with the parent.
            ctx: TileDB context, or None.
            parent: The RadiObjectWriter that owns this collection.
        """
        super().__init__(uri, shape, obs_schema, name, ctx)
        self._parent = parent

    def write_volume(
        self,
        data: npt.NDArray[np.floating],
        obs_id: str,
        obs_subject_id: str,
        **attrs: AttrValue,
    ) -> None:
        """Write a volume, checking obs_id uniqueness across all collections.

        Raises:
            ValueError: If ``obs_id`` was already written through any
                collection of the parent writer.
        """
        if obs_id in self._parent._all_obs_ids:
            raise ValueError(
                f"obs_id '{obs_id}' already exists in RadiObject. "
                f"obs_id must be unique across all collections."
            )
        super().write_volume(data, obs_id, obs_subject_id, **attrs)
        # Reserve the id only after the write succeeded; otherwise a rejected
        # volume (e.g. wrong shape) would block a retry with the same obs_id.
        self._parent._all_obs_ids.add(obs_id)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Finalize and register with parent.

        Registration happens only on the call that actually finalizes the
        collection, so a repeated clean exit does not register it twice.
        """
        was_finalized = self._finalized
        super().__exit__(exc_type, exc_val, exc_tb)

        if exc_type is None and self._finalized and not was_finalized:
            self._parent._register_collection(self.name, self.uri)
radiobject/utils.py ADDED
@@ -0,0 +1,17 @@
1
+ """Shared utilities for RadiObject."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+
7
+ import numpy as np
8
+
9
+
10
def affine_to_list(affine: np.ndarray) -> list[list[float]]:
    """Return the matrix as a list of rows, each a list of built-in floats.

    Every element is passed through ``float`` so the result holds plain
    Python numbers instead of numpy scalars and can be handed to ``json``.
    """
    rows = []
    for row in affine:
        rows.append(list(map(float, row)))
    return rows
13
+
14
+
15
def affine_to_json(affine: np.ndarray) -> str:
    """Encode an affine matrix as a JSON string of nested number lists.

    The matrix is first converted with ``affine_to_list``; its shape is not
    checked here.
    """
    nested_rows = affine_to_list(affine)
    return json.dumps(nested_rows)