radiobject 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
radiobject/__init__.py
ADDED
@@ -0,0 +1,24 @@
"""RadiObject - TileDB-backed data structure for radiology data at scale."""

from radiobject._types import AttrValue, LabelSource, TransformFn
from radiobject.ctx import ReadConfig, WriteConfig, configure, ctx, get_config
from radiobject.dataframe import Dataframe
from radiobject.radi_object import RadiObject
from radiobject.volume import Volume
from radiobject.volume_collection import VolumeCollection

__version__ = "0.1.0"
__all__ = [
    "RadiObject",
    "Volume",
    "VolumeCollection",
    "Dataframe",
    "ctx",
    "configure",
    "get_config",
    "WriteConfig",
    "ReadConfig",
    "TransformFn",
    "AttrValue",
    "LabelSource",
]
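A quick smoke test of the surface exported above; this sketch only exercises names re-exported in `__all__` (the RadiObject/Volume constructors live in later modules and are not touched here):

import radiobject

assert radiobject.__version__ == "0.1.0"
# Reconfigure the global TileDB context through the re-exported helpers
radiobject.configure(read=radiobject.ReadConfig(memory_budget_mb=2048))
tdb_ctx = radiobject.ctx()  # built lazily from the updated config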
radiobject/_types.py
ADDED
@@ -0,0 +1,19 @@
"""Shared type aliases for the radiobject package."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd

# Volume transform function: (X, Y, Z) -> (X', Y', Z')
TransformFn = Callable[[npt.NDArray[np.floating]], npt.NDArray[np.floating]]

# Scalar values storable in TileDB obs attributes
AttrValue = int | float | bool | str

# Flexible label specification for ML datasets
LabelSource = str | pd.DataFrame | dict[str, Any] | Callable[[str], Any] | None
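To make the aliases concrete, a short sketch of values that satisfy each one (the z-score transform and the label dict are illustrative, not helpers shipped in the wheel):

import numpy as np
import numpy.typing as npt

from radiobject._types import AttrValue, LabelSource, TransformFn

def zscore(vol: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
    # Any array-in/array-out callable matches TransformFn
    return (vol - vol.mean()) / (vol.std() + 1e-8)

transform: TransformFn = zscore
label_map: LabelSource = {"vol-001": 1, "vol-002": 0}  # a column name, DataFrame, or callable also qualifies
modality: AttrValue = "T1w"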
radiobject/ctx.py
ADDED
@@ -0,0 +1,359 @@
"""TileDB context configuration for radiology data."""

from __future__ import annotations

from enum import Enum
from typing import Self

import boto3
import tiledb
from pydantic import BaseModel, Field, model_validator


class SliceOrientation(str, Enum):
    """Preferred slicing orientation for tile optimization."""

    AXIAL = "axial"  # X-Y slices, vary Z (most common for neuro)
    SAGITTAL = "sagittal"  # Y-Z slices, vary X
    CORONAL = "coronal"  # X-Z slices, vary Y
    ISOTROPIC = "isotropic"  # Balanced 64³ chunks for 3D ROI


class Compressor(str, Enum):
    """Compression algorithms suited for radiology data."""

    ZSTD = "zstd"  # Good balance of speed and ratio
    LZ4 = "lz4"  # Fast, lower ratio
    GZIP = "gzip"  # High ratio, slower
    NONE = "none"


class TileConfig(BaseModel):
    """Tile dimensions for chunked storage."""

    orientation: SliceOrientation = Field(
        default=SliceOrientation.AXIAL,
        description="Primary slicing orientation for tile optimization",
    )
    x: int | None = Field(default=None, ge=1, description="Tile extent in X (None = auto)")
    y: int | None = Field(default=None, ge=1, description="Tile extent in Y (None = auto)")
    z: int | None = Field(default=None, ge=1, description="Tile extent in Z (None = auto)")
    t: int = Field(default=1, ge=1, description="Tile extent in T dimension")

    def extents_for_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]:
        """Compute optimal tile extents based on orientation and volume shape."""
        sx, sy, sz = shape[0], shape[1], shape[2]

        match self.orientation:
            case SliceOrientation.AXIAL:
                extents = (self.x or sx, self.y or sy, self.z or 1)
            case SliceOrientation.SAGITTAL:
                extents = (self.x or 1, self.y or sy, self.z or sz)
            case SliceOrientation.CORONAL:
                extents = (self.x or sx, self.y or 1, self.z or sz)
            case SliceOrientation.ISOTROPIC:
                extents = (self.x or 64, self.y or 64, self.z or 64)

        if len(shape) == 4:
            extents = (*extents, self.t)
        return extents
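
# Worked example (reviewer's addition, not part of the released file): with the
# defaults above, a 256x256x180 volume tiles as whole axial slices, while the
# isotropic preset falls back to 64-cubed chunks:
#   TileConfig().extents_for_shape((256, 256, 180))  -> (256, 256, 1)
#   TileConfig(orientation=SliceOrientation.ISOTROPIC).extents_for_shape(
#       (256, 256, 180))  -> (64, 64, 64)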


class CompressionConfig(BaseModel):
    """Compression settings for volume data."""

    algorithm: Compressor = Field(
        default=Compressor.ZSTD,
        description="Compression algorithm",
    )
    level: int = Field(
        default=3,
        ge=-1,
        le=22,
        description="Compression level (algorithm-dependent)",
    )

    def as_filter(self) -> tiledb.Filter | None:
        match self.algorithm:
            case Compressor.ZSTD:
                return tiledb.ZstdFilter(level=self.level)
            case Compressor.LZ4:
                return tiledb.LZ4Filter(level=self.level)
            case Compressor.GZIP:
                return tiledb.GzipFilter(level=self.level)
            case Compressor.NONE:
                return None
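
# Worked example (reviewer's addition, not part of the released file): the
# defaults map to a zstd filter, and NONE disables compression entirely:
#   CompressionConfig().as_filter()                           -> tiledb.ZstdFilter(level=3)
#   CompressionConfig(algorithm=Compressor.NONE).as_filter()  -> None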


class WriteConfig(BaseModel):
    """Settings applied when creating new TileDB arrays (immutable after creation)."""

    tile: TileConfig = Field(default_factory=TileConfig)
    compression: CompressionConfig = Field(default_factory=CompressionConfig)
    orientation: "OrientationConfig" = Field(default_factory=lambda: OrientationConfig())


class ReadConfig(BaseModel):
    """Settings for reading TileDB arrays."""

    memory_budget_mb: int = Field(
        default=1024,
        ge=64,
        description="Memory budget for TileDB operations (MB)",
    )
    concurrency: int = Field(
        default=4,
        ge=1,
        le=64,
        description="Number of TileDB internal I/O threads",
    )
    max_workers: int = Field(
        default=4,
        ge=1,
        le=32,
        description="Max parallel workers for volume I/O operations",
    )


class IOConfig(BaseModel):
    """I/O and memory settings (deprecated, use ReadConfig)."""

    memory_budget_mb: int = Field(
        default=1024,
        ge=64,
        description="Memory budget for TileDB operations (MB)",
    )
    concurrency: int = Field(
        default=4,
        ge=1,
        le=64,
        description="Number of TileDB internal I/O threads",
    )
    max_workers: int = Field(
        default=4,
        ge=1,
        le=32,
        description="Max parallel workers for volume I/O operations",
    )


class S3Config(BaseModel):
    """S3/cloud storage settings."""

    region: str = Field(default="us-east-1")
    endpoint: str | None = Field(default=None, description="Custom S3 endpoint")
    use_virtual_addressing: bool = Field(default=True)
    max_parallel_ops: int = Field(default=8, ge=1)
    multipart_part_size_mb: int = Field(default=50, ge=5)
    include_credentials: bool = Field(
        default=True, description="Include AWS credentials from boto3 session"
    )


class OrientationConfig(BaseModel):
    """Orientation detection and standardization settings."""

    auto_detect: bool = Field(
        default=True,
        description="Automatically detect orientation from file headers",
    )
    canonical_target: str = Field(
        default="RAS",
        description="Target canonical orientation (RAS, LAS, or LPS)",
    )
    reorient_on_load: bool = Field(
        default=False,
        description="Reorient to canonical when loading (preserves original by default)",
    )
    store_original_affine: bool = Field(
        default=True,
        description="Store original affine in metadata when reorienting",
    )

    @model_validator(mode="after")
    def validate_canonical_target(self) -> Self:
        """Ensure canonical target is valid."""
        valid_targets = {"RAS", "LAS", "LPS"}
        if self.canonical_target not in valid_targets:
            raise ValueError(
                f"canonical_target must be one of {valid_targets}, got {self.canonical_target}"
            )
        return self


class RadiObjectConfig(BaseModel):
    """Configuration for RadiObject TileDB context."""

    write: WriteConfig = Field(default_factory=WriteConfig)
    read: ReadConfig = Field(default_factory=ReadConfig)
    s3: S3Config = Field(default_factory=S3Config)

    # Deprecated flat access (for backwards compatibility)
    _tile: TileConfig | None = None
    _compression: CompressionConfig | None = None
    _io: IOConfig | None = None
    _orientation: OrientationConfig | None = None

    @property
    def tile(self) -> TileConfig:
        """Tile configuration (deprecated, use write.tile)."""
        return self.write.tile

    @property
    def compression(self) -> CompressionConfig:
        """Compression configuration (deprecated, use write.compression)."""
        return self.write.compression

    @property
    def orientation(self) -> OrientationConfig:
        """Orientation configuration (deprecated, use write.orientation)."""
        return self.write.orientation

    @property
    def io(self) -> ReadConfig:
        """I/O configuration (deprecated, use read)."""
        return self.read

    @model_validator(mode="after")
    def validate_compression_level(self) -> Self:
        """Ensure compression level is valid for the algorithm."""
        max_levels = {
            Compressor.ZSTD: 22,
            Compressor.LZ4: 16,
            Compressor.GZIP: 9,
            Compressor.NONE: 0,
        }
        max_level = max_levels[self.write.compression.algorithm]
        if self.write.compression.level > max_level:
            self.write.compression.level = max_level
        return self

    def to_tiledb_config(self, include_s3_credentials: bool = False) -> tiledb.Config:
        """Convert to TileDB Config object.

        Args:
            include_s3_credentials: If True, fetch AWS credentials from boto3.
                This is off by default to avoid expensive/failing credential
                lookups for local-only operations.
        """
        cfg = tiledb.Config()

        # Memory settings from read config
        cfg["sm.memory_budget"] = str(self.read.memory_budget_mb * 1024 * 1024)
        cfg["sm.compute_concurrency_level"] = str(self.read.concurrency)
        cfg["sm.io_concurrency_level"] = str(self.read.concurrency)

        # S3 settings (configuration only, no credential lookup)
        cfg["vfs.s3.region"] = self.s3.region
        cfg["vfs.s3.use_virtual_addressing"] = "true" if self.s3.use_virtual_addressing else "false"
        cfg["vfs.s3.max_parallel_ops"] = str(self.s3.max_parallel_ops)
        cfg["vfs.s3.multipart_part_size"] = str(self.s3.multipart_part_size_mb * 1024 * 1024)
        if self.s3.endpoint:
            cfg["vfs.s3.endpoint_override"] = self.s3.endpoint

        # Fetch AWS credentials if configured or explicitly requested
        if include_s3_credentials or self.s3.include_credentials:
            self._add_s3_credentials(cfg)

        return cfg

    def _add_s3_credentials(self, cfg: tiledb.Config) -> None:
        """Add AWS credentials to config from boto3 session."""
        try:
            session = boto3.Session()
            credentials = session.get_credentials()
            if credentials:
                frozen = credentials.get_frozen_credentials()
                cfg["vfs.s3.aws_access_key_id"] = frozen.access_key
                cfg["vfs.s3.aws_secret_access_key"] = frozen.secret_key
                if frozen.token:
                    cfg["vfs.s3.aws_session_token"] = frozen.token
        except Exception:
            pass  # AWS credentials unavailable; S3 operations will fail

    def to_tiledb_ctx(self, include_s3_credentials: bool = False) -> tiledb.Ctx:
        """Convert to TileDB Ctx object.

        Args:
            include_s3_credentials: If True, fetch AWS credentials from boto3.
        """
        return tiledb.Ctx(self.to_tiledb_config(include_s3_credentials))
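
# Note (reviewer's addition, not part of the released file): although the
# docstring calls the credential lookup "off by default", S3Config.include_credentials
# defaults to True, so to_tiledb_ctx() still hits boto3 unless disabled explicitly:
#   RadiObjectConfig(s3=S3Config(include_credentials=False)).to_tiledb_ctx()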


# Global mutable configuration
_config: RadiObjectConfig = RadiObjectConfig()
_ctx: tiledb.Ctx | None = None


def get_config() -> RadiObjectConfig:
    """Get the current global configuration."""
    return _config


def ctx() -> tiledb.Ctx:
    """Get the global TileDB context (lazily built from config)."""
    global _ctx
    if _ctx is None:
        _ctx = _config.to_tiledb_ctx()
    return _ctx


def configure(
    *,
    write: WriteConfig | None = None,
    read: ReadConfig | None = None,
    s3: S3Config | None = None,
    # Deprecated flat arguments for backwards compatibility
    tile: TileConfig | None = None,
    compression: CompressionConfig | None = None,
    io: IOConfig | ReadConfig | None = None,
    orientation: OrientationConfig | None = None,
) -> None:
    """Update global configuration.

    Example (new API):
        configure(write=WriteConfig(tile=TileConfig(orientation=SliceOrientation.AXIAL)))
        configure(read=ReadConfig(memory_budget_mb=2048))

    Example (deprecated flat API, still supported):
        configure(tile=TileConfig(x=128, y=128, z=32))
        configure(compression=CompressionConfig(algorithm=Compressor.LZ4))
    """
    global _config, _ctx

    updates: dict = {}

    # Handle new nested API
    if write is not None:
        updates["write"] = write
    if read is not None:
        updates["read"] = read
    if s3 is not None:
        updates["s3"] = s3

    # Handle deprecated flat API by mapping to nested structure
    if tile is not None or compression is not None or orientation is not None:
        write_updates = {}
        if tile is not None:
            write_updates["tile"] = tile
        if compression is not None:
            write_updates["compression"] = compression
        if orientation is not None:
            write_updates["orientation"] = orientation
        if write_updates:
            current_write = _config.write.model_copy(update=write_updates)
            updates["write"] = current_write

    if io is not None:
        # Map deprecated io to read config
        if isinstance(io, ReadConfig):
            updates["read"] = io
        else:
            updates["read"] = ReadConfig(
                memory_budget_mb=io.memory_budget_mb,
                concurrency=io.concurrency,
                max_workers=io.max_workers,
            )

    _config = _config.model_copy(update=updates)
    _ctx = None
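Putting ctx.py together: the nested configure() calls below come straight from the docstring examples above, followed by a context rebuild. Only names defined in this module are used:

from radiobject.ctx import ReadConfig, TileConfig, WriteConfig, configure, ctx

configure(write=WriteConfig(tile=TileConfig(x=128, y=128, z=32)))
configure(read=ReadConfig(memory_budget_mb=2048))
tdb_ctx = ctx()  # configure() cleared the cached context, so this rebuilds it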
radiobject/dataframe.py
ADDED
@@ -0,0 +1,186 @@
"""Dataframe - a 2D heterogeneous array backed by TileDB."""

from __future__ import annotations

from functools import cached_property

import numpy as np
import pandas as pd
import tiledb

from radiobject.ctx import ctx as global_ctx
from radiobject.ctx import get_config

# Mandatory index columns for all Dataframes
INDEX_COLUMNS = ("obs_subject_id", "obs_id")


class Dataframe:
    """TileDB-backed sparse dataframe for observation metadata.

    Used internally for obs_meta (subject-level) and obs (volume-level) storage.
    Indexed by (obs_subject_id, obs_id) with user-defined attribute columns.

    Example:
        df = dataframe.read(columns=["age"], value_filter="age > 40")
    """

    def __init__(self, uri: str, ctx: tiledb.Ctx | None = None):
        self.uri: str = uri
        self._ctx: tiledb.Ctx | None = ctx

    def _effective_ctx(self) -> tiledb.Ctx:
        return self._ctx if self._ctx else global_ctx()

    @cached_property
    def _schema(self) -> tiledb.ArraySchema:
        """Cached schema loaded once from disk."""
        return tiledb.ArraySchema.load(self.uri, ctx=self._effective_ctx())

    @property
    def shape(self) -> tuple[int, int]:
        """(n_rows, n_columns) dimensions."""
        with tiledb.open(self.uri, "r", ctx=self._effective_ctx()) as arr:
            n_rows = arr.nonempty_domain()[0][1] if arr.nonempty_domain()[0] else 0
            if isinstance(n_rows, str):
                n_rows = len(arr.query(attrs=[])[:][INDEX_COLUMNS[0]])
        n_cols = self._schema.nattr
        return (n_rows, n_cols)

    @property
    def index_columns(self) -> tuple[str, str]:
        """Index column names (dimension names)."""
        return INDEX_COLUMNS

    @property
    def columns(self) -> list[str]:
        """Attribute column names (excluding index columns)."""
        return [self._schema.attr(i).name for i in range(self._schema.nattr)]

    @property
    def all_columns(self) -> list[str]:
        """All column names including index columns."""
        return list(INDEX_COLUMNS) + self.columns

    @cached_property
    def dtypes(self) -> dict[str, np.dtype]:
        """Column data types (attributes only)."""
        schema = self._schema
        return {schema.attr(i).name: schema.attr(i).dtype for i in range(schema.nattr)}

    def __len__(self) -> int:
        with tiledb.open(self.uri, "r", ctx=self._effective_ctx()) as arr:
            result = arr.query(attrs=[])[:][INDEX_COLUMNS[0]]
            return len(result)

    def __repr__(self) -> str:
        return f"Dataframe(uri={self.uri!r}, shape={self.shape}, columns={self.columns})"

    def read(
        self,
        columns: list[str] | None = None,
        value_filter: str | None = None,
        include_index: bool = True,
    ) -> pd.DataFrame:
        """Read data with optional column selection and value filtering."""
        # Filter out index columns from requested columns (they're dimensions, not attributes)
        if columns is not None:
            attrs = [c for c in columns if c not in INDEX_COLUMNS]
        else:
            attrs = self.columns
        with tiledb.open(self.uri, "r", ctx=self._effective_ctx()) as arr:
            if value_filter is not None:
                result = arr.query(cond=value_filter, attrs=attrs)[:]
            else:
                result = arr.query(attrs=attrs)[:]
        data = {col: result[col] for col in attrs}
        if include_index:
            for idx_col in INDEX_COLUMNS:
                # Convert bytes to strings for index columns
                raw = result[idx_col]
                data[idx_col] = np.array(
                    [v.decode() if isinstance(v, bytes) else str(v) for v in raw]
                )
        df = pd.DataFrame(data)
        if include_index:
            col_order = list(INDEX_COLUMNS) + attrs
            df = df[col_order]
        return df

    @staticmethod
    def _validate_schema(schema: dict[str, np.dtype]) -> None:
        """Validate column names and types (for non-index attributes)."""
        for name, dtype in schema.items():
            if "\x00" in name:
                raise ValueError(f"Column name contains null byte: {name!r}")
            if name in INDEX_COLUMNS:
                raise ValueError(f"Column name conflicts with index column: {name!r}")
            if not isinstance(dtype, np.dtype):
                try:
                    np.dtype(dtype)
                except TypeError as e:
                    raise TypeError(f"Invalid dtype for column {name!r}: {dtype}") from e

    @classmethod
    def create(
        cls,
        uri: str,
        schema: dict[str, np.dtype],
        ctx: tiledb.Ctx | None = None,
    ) -> Dataframe:
        """Create an empty sparse Dataframe indexed by obs_subject_id and obs_id."""
        cls._validate_schema(schema)
        effective_ctx = ctx if ctx else global_ctx()

        dims = [
            tiledb.Dim(name=INDEX_COLUMNS[0], dtype="ascii", ctx=effective_ctx),
            tiledb.Dim(name=INDEX_COLUMNS[1], dtype="ascii", ctx=effective_ctx),
        ]
        domain = tiledb.Domain(*dims, ctx=effective_ctx)

        config = get_config()
        compression_filter = config.compression.as_filter()
        compression = (
            tiledb.FilterList([compression_filter]) if compression_filter else tiledb.FilterList()
        )
        attrs = [
            tiledb.Attr(name=name, dtype=dtype, filters=compression, ctx=effective_ctx)
            for name, dtype in schema.items()
        ]

        array_schema = tiledb.ArraySchema(
            domain=domain,
            attrs=attrs,
            sparse=True,
            ctx=effective_ctx,
        )
        tiledb.Array.create(uri, array_schema, ctx=effective_ctx)

        return cls(uri, ctx=ctx)

    @classmethod
    def from_pandas(
        cls,
        uri: str,
        df: pd.DataFrame,
        ctx: tiledb.Ctx | None = None,
    ) -> Dataframe:
        """Create a new Dataframe from a pandas DataFrame with obs_subject_id and obs_id columns."""
        for idx_col in INDEX_COLUMNS:
            if idx_col not in df.columns:
                raise ValueError(f"DataFrame must contain index column: {idx_col!r}")

        attr_cols = [col for col in df.columns if col not in INDEX_COLUMNS]
        schema = {col: df[col].to_numpy().dtype for col in attr_cols}
        dataframe = cls.create(uri, schema=schema, ctx=ctx)

        effective_ctx = ctx if ctx else global_ctx()
        with tiledb.open(uri, mode="w", ctx=effective_ctx) as arr:
            coords = {
                INDEX_COLUMNS[0]: df[INDEX_COLUMNS[0]].astype(str).to_numpy(),
                INDEX_COLUMNS[1]: df[INDEX_COLUMNS[1]].astype(str).to_numpy(),
            }
            data = {col: df[col].to_numpy() for col in attr_cols}
            arr[coords[INDEX_COLUMNS[0]], coords[INDEX_COLUMNS[1]]] = data

        return dataframe
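A round-trip sketch for the Dataframe API above; the URI is a hypothetical local path and the column values are made up, while the read() call mirrors the class docstring:

import pandas as pd

from radiobject.dataframe import Dataframe

obs = pd.DataFrame(
    {
        "obs_subject_id": ["sub-01", "sub-01", "sub-02"],  # both index columns are mandatory
        "obs_id": ["vol-001", "vol-002", "vol-003"],
        "age": [41, 41, 38],
    }
)
frame = Dataframe.from_pandas("/tmp/obs_meta", obs)
older = frame.read(columns=["age"], value_filter="age > 40")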