radiobject-0.1.0-py3-none-any.whl

radiobject/__init__.py ADDED
@@ -0,0 +1,24 @@
+ """RadiObject - TileDB-backed data structure for radiology data at scale."""
+
+ from radiobject._types import AttrValue, LabelSource, TransformFn
+ from radiobject.ctx import ReadConfig, WriteConfig, configure, ctx, get_config
+ from radiobject.dataframe import Dataframe
+ from radiobject.radi_object import RadiObject
+ from radiobject.volume import Volume
+ from radiobject.volume_collection import VolumeCollection
+
+ __version__ = "0.1.0"
+ __all__ = [
+     "RadiObject",
+     "Volume",
+     "VolumeCollection",
+     "Dataframe",
+     "ctx",
+     "configure",
+     "get_config",
+     "WriteConfig",
+     "ReadConfig",
+     "TransformFn",
+     "AttrValue",
+     "LabelSource",
+ ]
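The exports above define the package's public surface. A minimal usage sketch of the configuration side, written by the editor for illustration (not code from the package); `RadiObject`, `Volume`, and `VolumeCollection` come from modules not shown in this diff:

import radiobject

# Raise the read-side memory budget and inspect the resulting config.
radiobject.configure(read=radiobject.ReadConfig(memory_budget_mb=2048))
print(radiobject.get_config().read.memory_budget_mb)  # 2048

tdb_ctx = radiobject.ctx()  # tiledb.Ctx built lazily from the global config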
radiobject/_types.py ADDED
@@ -0,0 +1,19 @@
+ """Shared type aliases for the radiobject package."""
+
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from typing import Any
+
+ import numpy as np
+ import numpy.typing as npt
+ import pandas as pd
+
+ # Volume transform function: (X, Y, Z) -> (X', Y', Z')
+ TransformFn = Callable[[npt.NDArray[np.floating]], npt.NDArray[np.floating]]
+
+ # Scalar values storable in TileDB obs attributes
+ AttrValue = int | float | bool | str
+
+ # Flexible label specification for ML datasets
+ LabelSource = str | pd.DataFrame | dict[str, Any] | Callable[[str], Any] | None
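To make the aliases concrete, here is a short sketch of values that satisfy each one; the function and label values are the editor's illustrations, not part of the package:

import numpy as np
import numpy.typing as npt

from radiobject._types import AttrValue, LabelSource, TransformFn

def zscore(vol: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
    # A TransformFn maps one float volume to another of the same rank.
    return (vol - vol.mean()) / (vol.std() + 1e-8)

normalize: TransformFn = zscore
patient_age: AttrValue = 47  # any of int | float | bool | str

# LabelSource is deliberately flexible: a column name, a mapping,
# a per-observation callable, or None.
label_col: LabelSource = "diagnosis"
label_fn: LabelSource = lambda obs_id: obs_id.startswith("sub-01")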
radiobject/ctx.py ADDED
@@ -0,0 +1,359 @@
+ """TileDB context configuration for radiology data."""
+
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import Self
+
+ import boto3
+ import tiledb
+ from pydantic import BaseModel, Field, model_validator
+
+
+ class SliceOrientation(str, Enum):
+     """Preferred slicing orientation for tile optimization."""
+
+     AXIAL = "axial"  # X-Y slices, vary Z (most common for neuro)
+     SAGITTAL = "sagittal"  # Y-Z slices, vary X
+     CORONAL = "coronal"  # X-Z slices, vary Y
+     ISOTROPIC = "isotropic"  # Balanced 64³ chunks for 3D ROI
+
+
+ class Compressor(str, Enum):
+     """Compression algorithms suited for radiology data."""
+
+     ZSTD = "zstd"  # Good balance of speed and ratio
+     LZ4 = "lz4"  # Fast, lower ratio
+     GZIP = "gzip"  # High ratio, slower
+     NONE = "none"
+
+
+ class TileConfig(BaseModel):
+     """Tile dimensions for chunked storage."""
+
+     orientation: SliceOrientation = Field(
+         default=SliceOrientation.AXIAL,
+         description="Primary slicing orientation for tile optimization",
+     )
+     x: int | None = Field(default=None, ge=1, description="Tile extent in X (None = auto)")
+     y: int | None = Field(default=None, ge=1, description="Tile extent in Y (None = auto)")
+     z: int | None = Field(default=None, ge=1, description="Tile extent in Z (None = auto)")
+     t: int = Field(default=1, ge=1, description="Tile extent in T dimension")
+
+     def extents_for_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
+         """Compute optimal tile extents based on orientation and volume shape."""
+         sx, sy, sz = shape[0], shape[1], shape[2]
+
+         match self.orientation:
+             case SliceOrientation.AXIAL:
+                 extents = (self.x or sx, self.y or sy, self.z or 1)
+             case SliceOrientation.SAGITTAL:
+                 extents = (self.x or 1, self.y or sy, self.z or sz)
+             case SliceOrientation.CORONAL:
+                 extents = (self.x or sx, self.y or 1, self.z or sz)
+             case SliceOrientation.ISOTROPIC:
+                 extents = (self.x or 64, self.y or 64, self.z or 64)
+
+         if len(shape) == 4:
+             extents = (*extents, self.t)
+         return extents
+
+
+ class CompressionConfig(BaseModel):
+     """Compression settings for volume data."""
+
+     algorithm: Compressor = Field(
+         default=Compressor.ZSTD,
+         description="Compression algorithm",
+     )
+     level: int = Field(
+         default=3,
+         ge=-1,
+         le=22,
+         description="Compression level (algorithm-dependent)",
+     )
+
+     def as_filter(self) -> tiledb.Filter | None:
+         match self.algorithm:
+             case Compressor.ZSTD:
+                 return tiledb.ZstdFilter(level=self.level)
+             case Compressor.LZ4:
+                 return tiledb.LZ4Filter(level=self.level)
+             case Compressor.GZIP:
+                 return tiledb.GzipFilter(level=self.level)
+             case Compressor.NONE:
+                 return None
+
+
+ class WriteConfig(BaseModel):
+     """Settings applied when creating new TileDB arrays (immutable after creation)."""
+
+     tile: TileConfig = Field(default_factory=TileConfig)
+     compression: CompressionConfig = Field(default_factory=CompressionConfig)
+     orientation: "OrientationConfig" = Field(default_factory=lambda: OrientationConfig())
+
+
+ class ReadConfig(BaseModel):
+     """Settings for reading TileDB arrays."""
+
+     memory_budget_mb: int = Field(
+         default=1024,
+         ge=64,
+         description="Memory budget for TileDB operations (MB)",
+     )
+     concurrency: int = Field(
+         default=4,
+         ge=1,
+         le=64,
+         description="Number of TileDB internal I/O threads",
+     )
+     max_workers: int = Field(
+         default=4,
+         ge=1,
+         le=32,
+         description="Max parallel workers for volume I/O operations",
+     )
+
+
+ class IOConfig(BaseModel):
+     """I/O and memory settings (deprecated, use ReadConfig)."""
+
+     memory_budget_mb: int = Field(
+         default=1024,
+         ge=64,
+         description="Memory budget for TileDB operations (MB)",
+     )
+     concurrency: int = Field(
+         default=4,
+         ge=1,
+         le=64,
+         description="Number of TileDB internal I/O threads",
+     )
+     max_workers: int = Field(
+         default=4,
+         ge=1,
+         le=32,
+         description="Max parallel workers for volume I/O operations",
+     )
+
+
+ class S3Config(BaseModel):
+     """S3/cloud storage settings."""
+
+     region: str = Field(default="us-east-1")
+     endpoint: str | None = Field(default=None, description="Custom S3 endpoint")
+     use_virtual_addressing: bool = Field(default=True)
+     max_parallel_ops: int = Field(default=8, ge=1)
+     multipart_part_size_mb: int = Field(default=50, ge=5)
+     include_credentials: bool = Field(
+         default=True, description="Include AWS credentials from boto3 session"
+     )
+
+
+ class OrientationConfig(BaseModel):
+     """Orientation detection and standardization settings."""
+
+     auto_detect: bool = Field(
+         default=True,
+         description="Automatically detect orientation from file headers",
+     )
+     canonical_target: str = Field(
+         default="RAS",
+         description="Target canonical orientation (RAS, LAS, or LPS)",
+     )
+     reorient_on_load: bool = Field(
+         default=False,
+         description="Reorient to canonical when loading (preserves original by default)",
+     )
+     store_original_affine: bool = Field(
+         default=True,
+         description="Store original affine in metadata when reorienting",
+     )
+
+     @model_validator(mode="after")
+     def validate_canonical_target(self) -> Self:
+         """Ensure canonical target is valid."""
+         valid_targets = {"RAS", "LAS", "LPS"}
+         if self.canonical_target not in valid_targets:
+             raise ValueError(
+                 f"canonical_target must be one of {valid_targets}, got {self.canonical_target}"
+             )
+         return self
+
+
+ class RadiObjectConfig(BaseModel):
+     """Configuration for RadiObject TileDB context."""
+
+     write: WriteConfig = Field(default_factory=WriteConfig)
+     read: ReadConfig = Field(default_factory=ReadConfig)
+     s3: S3Config = Field(default_factory=S3Config)
+
+     # Deprecated flat access (for backwards compatibility)
+     _tile: TileConfig | None = None
+     _compression: CompressionConfig | None = None
+     _io: IOConfig | None = None
+     _orientation: OrientationConfig | None = None
+
+     @property
+     def tile(self) -> TileConfig:
+         """Tile configuration (deprecated, use write.tile)."""
+         return self.write.tile
+
+     @property
+     def compression(self) -> CompressionConfig:
+         """Compression configuration (deprecated, use write.compression)."""
+         return self.write.compression
+
+     @property
+     def orientation(self) -> OrientationConfig:
+         """Orientation configuration (deprecated, use write.orientation)."""
+         return self.write.orientation
+
+     @property
+     def io(self) -> ReadConfig:
+         """I/O configuration (deprecated, use read)."""
+         return self.read
+
+     @model_validator(mode="after")
+     def validate_compression_level(self) -> Self:
+         """Ensure compression level is valid for the algorithm."""
+         max_levels = {
+             Compressor.ZSTD: 22,
+             Compressor.LZ4: 16,
+             Compressor.GZIP: 9,
+             Compressor.NONE: 0,
+         }
+         max_level = max_levels[self.write.compression.algorithm]
+         if self.write.compression.level > max_level:
+             self.write.compression.level = max_level
+         return self
+
+     def to_tiledb_config(self, include_s3_credentials: bool = False) -> tiledb.Config:
+         """Convert to TileDB Config object.
+
+         Args:
+             include_s3_credentials: If True, fetch AWS credentials from boto3.
+                 This is off by default to avoid expensive/failing credential
+                 lookups for local-only operations.
+         """
+         cfg = tiledb.Config()
+
+         # Memory settings from read config
+         cfg["sm.memory_budget"] = str(self.read.memory_budget_mb * 1024 * 1024)
+         cfg["sm.compute_concurrency_level"] = str(self.read.concurrency)
+         cfg["sm.io_concurrency_level"] = str(self.read.concurrency)
+
+         # S3 settings (configuration only, no credential lookup)
+         cfg["vfs.s3.region"] = self.s3.region
+         cfg["vfs.s3.use_virtual_addressing"] = "true" if self.s3.use_virtual_addressing else "false"
+         cfg["vfs.s3.max_parallel_ops"] = str(self.s3.max_parallel_ops)
+         cfg["vfs.s3.multipart_part_size"] = str(self.s3.multipart_part_size_mb * 1024 * 1024)
+         if self.s3.endpoint:
+             cfg["vfs.s3.endpoint_override"] = self.s3.endpoint
+
+         # Fetch AWS credentials if configured or explicitly requested
+         if include_s3_credentials or self.s3.include_credentials:
+             self._add_s3_credentials(cfg)
+
+         return cfg
+
+     def _add_s3_credentials(self, cfg: tiledb.Config) -> None:
+         """Add AWS credentials to config from boto3 session."""
+         try:
+             session = boto3.Session()
+             credentials = session.get_credentials()
+             if credentials:
+                 frozen = credentials.get_frozen_credentials()
+                 cfg["vfs.s3.aws_access_key_id"] = frozen.access_key
+                 cfg["vfs.s3.aws_secret_access_key"] = frozen.secret_key
+                 if frozen.token:
+                     cfg["vfs.s3.aws_session_token"] = frozen.token
+         except Exception:
+             pass  # AWS credentials unavailable; S3 operations will fail
+
+     def to_tiledb_ctx(self, include_s3_credentials: bool = False) -> tiledb.Ctx:
+         """Convert to TileDB Ctx object.
+
+         Args:
+             include_s3_credentials: If True, fetch AWS credentials from boto3.
+         """
+         return tiledb.Ctx(self.to_tiledb_config(include_s3_credentials))
+
+
+ # Global mutable configuration
+ _config: RadiObjectConfig = RadiObjectConfig()
+ _ctx: tiledb.Ctx | None = None
+
+
+ def get_config() -> RadiObjectConfig:
+     """Get the current global configuration."""
+     return _config
+
+
+ def ctx() -> tiledb.Ctx:
+     """Get the global TileDB context (lazily built from config)."""
+     global _ctx
+     if _ctx is None:
+         _ctx = _config.to_tiledb_ctx()
+     return _ctx
+
+
+ def configure(
+     *,
+     write: WriteConfig | None = None,
+     read: ReadConfig | None = None,
+     s3: S3Config | None = None,
+     # Deprecated flat arguments for backwards compatibility
+     tile: TileConfig | None = None,
+     compression: CompressionConfig | None = None,
+     io: IOConfig | ReadConfig | None = None,
+     orientation: OrientationConfig | None = None,
+ ) -> None:
+     """Update global configuration.
+
+     Example (new API):
+         configure(write=WriteConfig(tile=TileConfig(orientation=SliceOrientation.AXIAL)))
+         configure(read=ReadConfig(memory_budget_mb=2048))
+
+     Example (deprecated flat API, still supported):
+         configure(tile=TileConfig(x=128, y=128, z=32))
+         configure(compression=CompressionConfig(algorithm=Compressor.LZ4))
+     """
+     global _config, _ctx
+
+     updates: dict = {}
+
+     # Handle new nested API
+     if write is not None:
+         updates["write"] = write
+     if read is not None:
+         updates["read"] = read
+     if s3 is not None:
+         updates["s3"] = s3
+
+     # Handle deprecated flat API by mapping to nested structure
+     if tile is not None or compression is not None or orientation is not None:
+         write_updates = {}
+         if tile is not None:
+             write_updates["tile"] = tile
+         if compression is not None:
+             write_updates["compression"] = compression
+         if orientation is not None:
+             write_updates["orientation"] = orientation
+         if write_updates:
+             current_write = _config.write.model_copy(update=write_updates)
+             updates["write"] = current_write
+
+     if io is not None:
+         # Map deprecated io to read config
+         if isinstance(io, ReadConfig):
+             updates["read"] = io
+         else:
+             updates["read"] = ReadConfig(
+                 memory_budget_mb=io.memory_budget_mb,
+                 concurrency=io.concurrency,
+                 max_workers=io.max_workers,
+             )
+
+     _config = _config.model_copy(update=updates)
+     _ctx = None
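Putting `TileConfig`, `CompressionConfig`, and `configure` together: the editor's sketch below (the volume shape is illustrative) shows how the default AXIAL orientation yields one full X-Y slice per tile, so slice-wise reads touch a single tile:

from radiobject.ctx import (
    CompressionConfig,
    Compressor,
    ReadConfig,
    SliceOrientation,
    TileConfig,
    WriteConfig,
    configure,
    ctx,
    get_config,
)

configure(
    write=WriteConfig(
        tile=TileConfig(orientation=SliceOrientation.AXIAL),
        compression=CompressionConfig(algorithm=Compressor.ZSTD, level=5),
    ),
    read=ReadConfig(memory_budget_mb=2048, concurrency=8),
)

# With no explicit x/y/z, AXIAL tiles span the full slice and one Z step.
extents = get_config().write.tile.extents_for_shape((256, 256, 180))
assert extents == (256, 256, 1)

tiledb_ctx = ctx()  # rebuilt lazily, since configure() reset the cached Ctx

Note that `configure()` invalidates the cached context, so the next `ctx()` call reflects the updated settings.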
radiobject/dataframe.py ADDED
@@ -0,0 +1,186 @@
+ """Dataframe - a 2D heterogeneous array backed by TileDB."""
+
+ from __future__ import annotations
+
+ from functools import cached_property
+
+ import numpy as np
+ import pandas as pd
+ import tiledb
+
+ from radiobject.ctx import ctx as global_ctx
+ from radiobject.ctx import get_config
+
+ # Mandatory index columns for all Dataframes
+ INDEX_COLUMNS = ("obs_subject_id", "obs_id")
+
+
+ class Dataframe:
+     """TileDB-backed sparse dataframe for observation metadata.
+
+     Used internally for obs_meta (subject-level) and obs (volume-level) storage.
+     Indexed by (obs_subject_id, obs_id) with user-defined attribute columns.
+
+     Example:
+         df = dataframe.read(columns=["age"], value_filter="age > 40")
+     """
+
+     def __init__(self, uri: str, ctx: tiledb.Ctx | None = None):
+         self.uri: str = uri
+         self._ctx: tiledb.Ctx | None = ctx
+
+     def _effective_ctx(self) -> tiledb.Ctx:
+         return self._ctx if self._ctx else global_ctx()
+
+     @cached_property
+     def _schema(self) -> tiledb.ArraySchema:
+         """Cached schema loaded once from disk."""
+         return tiledb.ArraySchema.load(self.uri, ctx=self._effective_ctx())
+
+     @property
+     def shape(self) -> tuple[int, int]:
+         """(n_rows, n_columns) dimensions."""
+         with tiledb.open(self.uri, "r", ctx=self._effective_ctx()) as arr:
+             n_rows = arr.nonempty_domain()[0][1] if arr.nonempty_domain()[0] else 0
+             if isinstance(n_rows, str):
+                 n_rows = len(arr.query(attrs=[])[:][INDEX_COLUMNS[0]])
+         n_cols = self._schema.nattr
+         return (n_rows, n_cols)
+
+     @property
+     def index_columns(self) -> tuple[str, str]:
+         """Index column names (dimension names)."""
+         return INDEX_COLUMNS
+
+     @property
+     def columns(self) -> list[str]:
+         """Attribute column names (excluding index columns)."""
+         return [self._schema.attr(i).name for i in range(self._schema.nattr)]
+
+     @property
+     def all_columns(self) -> list[str]:
+         """All column names including index columns."""
+         return list(INDEX_COLUMNS) + self.columns
+
+     @cached_property
+     def dtypes(self) -> dict[str, np.dtype]:
+         """Column data types (attributes only)."""
+         schema = self._schema
+         return {schema.attr(i).name: schema.attr(i).dtype for i in range(schema.nattr)}
+
+     def __len__(self) -> int:
+         with tiledb.open(self.uri, "r", ctx=self._effective_ctx()) as arr:
+             result = arr.query(attrs=[])[:][INDEX_COLUMNS[0]]
+             return len(result)
+
+     def __repr__(self) -> str:
+         return f"Dataframe(uri={self.uri!r}, shape={self.shape}, columns={self.columns})"
+
+     def read(
+         self,
+         columns: list[str] | None = None,
+         value_filter: str | None = None,
+         include_index: bool = True,
+     ) -> pd.DataFrame:
+         """Read data with optional column selection and value filtering."""
+         # Filter out index columns from requested columns (they're dimensions, not attributes)
+         if columns is not None:
+             attrs = [c for c in columns if c not in INDEX_COLUMNS]
+         else:
+             attrs = self.columns
+         with tiledb.open(self.uri, "r", ctx=self._effective_ctx()) as arr:
+             if value_filter is not None:
+                 result = arr.query(cond=value_filter, attrs=attrs)[:]
+             else:
+                 result = arr.query(attrs=attrs)[:]
+         data = {col: result[col] for col in attrs}
+         if include_index:
+             for idx_col in INDEX_COLUMNS:
+                 # Convert bytes to strings for index columns
+                 raw = result[idx_col]
+                 data[idx_col] = np.array(
+                     [v.decode() if isinstance(v, bytes) else str(v) for v in raw]
+                 )
+         df = pd.DataFrame(data)
+         if include_index:
+             col_order = list(INDEX_COLUMNS) + attrs
+             df = df[col_order]
+         return df
+
+     @staticmethod
+     def _validate_schema(schema: dict[str, np.dtype]) -> None:
+         """Validate column names and types (for non-index attributes)."""
+         for name, dtype in schema.items():
+             if "\x00" in name:
+                 raise ValueError(f"Column name contains null byte: {name!r}")
+             if name in INDEX_COLUMNS:
+                 raise ValueError(f"Column name conflicts with index column: {name!r}")
+             if not isinstance(dtype, np.dtype):
+                 try:
+                     np.dtype(dtype)
+                 except TypeError as e:
+                     raise TypeError(f"Invalid dtype for column {name!r}: {dtype}") from e
+
+     @classmethod
+     def create(
+         cls,
+         uri: str,
+         schema: dict[str, np.dtype],
+         ctx: tiledb.Ctx | None = None,
+     ) -> Dataframe:
+         """Create an empty sparse Dataframe indexed by obs_subject_id and obs_id."""
+         cls._validate_schema(schema)
+         effective_ctx = ctx if ctx else global_ctx()
+
+         dims = [
+             tiledb.Dim(name=INDEX_COLUMNS[0], dtype="ascii", ctx=effective_ctx),
+             tiledb.Dim(name=INDEX_COLUMNS[1], dtype="ascii", ctx=effective_ctx),
+         ]
+         domain = tiledb.Domain(*dims, ctx=effective_ctx)
+
+         config = get_config()
+         compression_filter = config.compression.as_filter()
+         compression = (
+             tiledb.FilterList([compression_filter]) if compression_filter else tiledb.FilterList()
+         )
+         attrs = [
+             tiledb.Attr(name=name, dtype=dtype, filters=compression, ctx=effective_ctx)
+             for name, dtype in schema.items()
+         ]
+
+         array_schema = tiledb.ArraySchema(
+             domain=domain,
+             attrs=attrs,
+             sparse=True,
+             ctx=effective_ctx,
+         )
+         tiledb.Array.create(uri, array_schema, ctx=effective_ctx)
+
+         return cls(uri, ctx=ctx)
+
+     @classmethod
+     def from_pandas(
+         cls,
+         uri: str,
+         df: pd.DataFrame,
+         ctx: tiledb.Ctx | None = None,
+     ) -> Dataframe:
+         """Create a new Dataframe from a pandas DataFrame with obs_subject_id and obs_id columns."""
+         for idx_col in INDEX_COLUMNS:
+             if idx_col not in df.columns:
+                 raise ValueError(f"DataFrame must contain index column: {idx_col!r}")
+
+         attr_cols = [col for col in df.columns if col not in INDEX_COLUMNS]
+         schema = {col: df[col].to_numpy().dtype for col in attr_cols}
+         dataframe = cls.create(uri, schema=schema, ctx=ctx)
+
+         effective_ctx = ctx if ctx else global_ctx()
+         with tiledb.open(uri, mode="w", ctx=effective_ctx) as arr:
+             coords = {
+                 INDEX_COLUMNS[0]: df[INDEX_COLUMNS[0]].astype(str).to_numpy(),
+                 INDEX_COLUMNS[1]: df[INDEX_COLUMNS[1]].astype(str).to_numpy(),
+             }
+             data = {col: df[col].to_numpy() for col in attr_cols}
+             arr[coords[INDEX_COLUMNS[0]], coords[INDEX_COLUMNS[1]]] = data
+
+         return dataframe
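A round-trip sketch for `Dataframe`, written by the editor against the API above; the URI and the non-index column names ("age", "field_strength") are illustrative assumptions:

import numpy as np
import pandas as pd

from radiobject.dataframe import Dataframe

obs = pd.DataFrame(
    {
        "obs_subject_id": ["sub-01", "sub-01", "sub-02"],
        "obs_id": ["vol-001", "vol-002", "vol-003"],
        "age": np.array([52, 52, 47], dtype=np.int64),
        "field_strength": np.array([3.0, 3.0, 1.5], dtype=np.float64),
    }
)

df = Dataframe.from_pandas("/tmp/obs_meta", obs)  # assumed local path
print(df.columns)  # ["age", "field_strength"]
older = df.read(columns=["age"], value_filter="age > 50")

The two mandatory columns become string (ascii) dimensions of the sparse array, so every write must supply both coordinates; everything else is stored as a compressed attribute using the globally configured filter.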