radiobject 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- radiobject/__init__.py +24 -0
- radiobject/_types.py +19 -0
- radiobject/ctx.py +359 -0
- radiobject/dataframe.py +186 -0
- radiobject/imaging_metadata.py +387 -0
- radiobject/indexing.py +45 -0
- radiobject/ingest.py +132 -0
- radiobject/ml/__init__.py +26 -0
- radiobject/ml/cache.py +53 -0
- radiobject/ml/compat/__init__.py +33 -0
- radiobject/ml/compat/torchio.py +99 -0
- radiobject/ml/config.py +42 -0
- radiobject/ml/datasets/__init__.py +12 -0
- radiobject/ml/datasets/collection_dataset.py +198 -0
- radiobject/ml/datasets/multimodal.py +129 -0
- radiobject/ml/datasets/patch_dataset.py +158 -0
- radiobject/ml/datasets/segmentation_dataset.py +219 -0
- radiobject/ml/datasets/volume_dataset.py +233 -0
- radiobject/ml/distributed.py +82 -0
- radiobject/ml/factory.py +249 -0
- radiobject/ml/utils/__init__.py +13 -0
- radiobject/ml/utils/labels.py +106 -0
- radiobject/ml/utils/validation.py +85 -0
- radiobject/ml/utils/worker_init.py +10 -0
- radiobject/orientation.py +270 -0
- radiobject/parallel.py +65 -0
- radiobject/py.typed +0 -0
- radiobject/query.py +788 -0
- radiobject/radi_object.py +1665 -0
- radiobject/streaming.py +389 -0
- radiobject/utils.py +17 -0
- radiobject/volume.py +438 -0
- radiobject/volume_collection.py +1182 -0
- radiobject-0.1.0.dist-info/METADATA +139 -0
- radiobject-0.1.0.dist-info/RECORD +37 -0
- radiobject-0.1.0.dist-info/WHEEL +4 -0
- radiobject-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
"""Imaging metadata extraction for NIfTI and DICOM files."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Literal
|
|
7
|
+
|
|
8
|
+
import nibabel as nib
|
|
9
|
+
import numpy as np
|
|
10
|
+
from pydantic import BaseModel, Field
|
|
11
|
+
|
|
12
|
+
from radiobject.utils import affine_to_json
|
|
13
|
+
|
|
14
|
+
# BIDS-aligned series type identifiers
# Canonical set of series-type labels recognized by infer_series_type().
# Names follow BIDS suffix conventions (e.g. "T1w", "bold") plus a few CT
# variants that BIDS does not define but common datasets use. Matching
# against this set is case-insensitive; the canonical casing here is what
# gets returned.
KNOWN_SERIES_TYPES: frozenset[str] = frozenset(
    {
        # Anatomical MRI
        "T1w",
        "T2w",
        "T1rho",
        "T1map",
        "T2map",
        "T2star",
        "FLAIR",
        "FLASH",
        "PD",
        "PDmap",
        "PDT2",
        "inplaneT1",
        "inplaneT2",
        "angio",
        "T1gd",
        # Functional MRI
        "bold",
        "cbv",
        "phase",
        # Diffusion MRI
        "dwi",
        # CT variants
        "CT",
        "HRCT",
        "CTA",
        "CTPA",
        # Field maps
        "phasediff",
        "magnitude",
        "fieldmap",
        "epi",
    }
)
|
|
51
|
+
|
|
52
|
+
# Filename pattern to series type mapping (ordered from most specific to least specific)
# Matched by infer_series_type() as substrings of the upper-cased filename;
# the FIRST hit wins, so more specific fragments (e.g. "T1GD") must precede
# the generic fragments they contain (e.g. "T1").
_FILENAME_PATTERNS: tuple[tuple[str, str], ...] = (
    # Contrast-enhanced T1 (check before T1)
    ("T1GD", "T1gd"),
    ("T1CE", "T1gd"),
    ("T1C", "T1gd"),
    # Standard patterns
    ("T1W", "T1w"),
    ("T1", "T1w"),
    ("T2W", "T2w"),
    ("T2", "T2w"),
    ("FLAIR", "FLAIR"),
    ("DWI", "dwi"),
    ("DTI", "dwi"),
    ("BOLD", "bold"),
    ("FUNC", "bold"),
    # CT patterns (MSD datasets, common naming)
    ("LUNG_", "CT"),
    ("LIVER_", "CT"),
    ("COLON_", "CT"),
    ("PANCREAS_", "CT"),
    ("SPLEEN_", "CT"),
    ("HEPATIC", "CT"),
    ("_CT_", "CT"),
    ("_CT.", "CT"),
)
|
|
78
|
+
|
|
79
|
+
# Spatial unit mapping from NIfTI xyzt_units
# Keys are the spatial-unit codes stored in the low 3 bits of the header's
# xyzt_units field (see _get_spatial_units); values are human-readable
# unit strings used in NiftiMetadata.spatial_units.
_SPATIAL_UNIT_MAP: dict[int, str] = {
    0: "unknown",
    1: "m",
    2: "mm",
    3: "um",
}
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def infer_series_type(path: Path, header: nib.Nifti1Header | None = None) -> str:
    """Infer series type from filename patterns and header.

    Priority:
    1. BIDS-style suffix: sub-01_ses-01_T1w.nii.gz -> "T1w"
    2. Common patterns: T1_MPRAGE.nii.gz -> "T1w"
    3. Header description field
    4. Fallback: "unknown"
    """
    # Path.stem on "x.nii.gz" leaves a trailing ".nii"; strip it.
    base = path.stem
    if base.endswith(".nii"):
        base = base[:-4]

    # 1) BIDS-style suffix: the last underscore-separated token,
    #    compared case-insensitively against the known series types.
    tokens = base.split("_")
    if tokens:
        candidate = tokens[-1].lower()
        for known in KNOWN_SERIES_TYPES:
            if known.lower() == candidate:
                return known

    # 2) Ordered substring patterns over the upper-cased name; first hit wins.
    upper_name = base.upper()
    for fragment, series_type in _FILENAME_PATTERNS:
        if fragment in upper_name:
            return series_type

    # 3) Free-text description stored in the NIfTI header, if provided.
    if header is not None:
        raw = header.get("descrip", b"")
        text = raw.decode("utf-8", errors="ignore") if isinstance(raw, bytes) else str(raw)
        lowered = text.lower()
        for known in KNOWN_SERIES_TYPES:
            if known.lower() in lowered:
                return known

    # 4) Nothing matched.
    return "unknown"
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _get_spatial_units(xyzt_units: int) -> str:
    """Decode the spatial unit name from a NIfTI xyzt_units field.

    The spatial-unit code occupies the low three bits; temporal-unit
    bits live higher up and are ignored here.
    """
    return _SPATIAL_UNIT_MAP.get(xyzt_units & 0b111, "unknown")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class NiftiMetadata(BaseModel):
    """NIfTI header metadata for obs DataFrame.

    Immutable (``model_config`` freezes the model). Instances are
    populated from a NIfTI-1 header by :func:`extract_nifti_metadata`.
    """

    # Voxel spacing (from pixdim[1:4]) as (x, y, z) tuple
    voxel_spacing: tuple[float, float, float]

    # Original dimensions as (x, y, z) tuple
    dimensions: tuple[int, int, int]

    # Data type
    datatype: int  # NIfTI datatype code from the header
    bitpix: int  # bits per voxel from the header

    # Scaling (intensity transform: slope * stored_value + inter)
    scl_slope: float
    scl_inter: float

    # Units
    xyzt_units: int  # raw header field (spatial + temporal unit codes)
    spatial_units: str  # "mm", "um", "m", or "unknown"

    # Coordinate system codes
    qform_code: int
    sform_code: int

    # Orientation
    axcodes: str  # e.g., "RAS"
    affine_json: str  # 4x4 matrix as JSON
    orientation_source: Literal["nifti_sform", "nifti_qform", "identity"]

    # Provenance
    source_path: str  # absolute path of the source NIfTI file

    model_config = {"frozen": True}

    def to_obs_dict(self, obs_id: str, obs_subject_id: str, series_type: str) -> dict:
        """Convert to dictionary for obs DataFrame row.

        Args:
            obs_id: Unique observation identifier.
            obs_subject_id: Subject identifier the observation belongs to.
            series_type: Series type label (e.g. "T1w").

        Returns:
            Flat dict of all model fields plus the three identifiers,
            with tuple fields serialized to strings for TileDB storage.
        """
        data = self.model_dump()
        # Serialize tuples as strings for TileDB storage
        data["voxel_spacing"] = str(data["voxel_spacing"])
        data["dimensions"] = str(data["dimensions"])
        data.update(obs_id=obs_id, obs_subject_id=obs_subject_id, series_type=series_type)
        return data
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class DicomMetadata(BaseModel):
    """DICOM header metadata for obs DataFrame.

    Immutable (``model_config`` freezes the model). Instances are
    populated from the first slice of a series by
    :func:`extract_dicom_metadata`.
    """

    # Voxel spacing as (x, y, z) tuple - z is slice_thickness
    voxel_spacing: tuple[float, float, float]

    # Dimensions as (rows, columns, n_slices) tuple
    dimensions: tuple[int, int, int]

    # Patient/Study info (anonymized identifiers only)
    modality: str  # CT, MR, PT, etc.
    series_description: str  # free-text SeriesDescription tag

    # Acquisition parameters (None if not applicable)
    kvp: float | None = Field(default=None)  # CT tube voltage
    exposure: float | None = Field(default=None)  # CT exposure (mAs)
    repetition_time: float | None = Field(default=None)  # MRI TR
    echo_time: float | None = Field(default=None)  # MRI TE
    magnetic_field_strength: float | None = Field(default=None)  # MRI field strength

    # Orientation
    axcodes: str  # e.g. "RAS"; derived from ImageOrientationPatient when present
    affine_json: str  # 4x4 RAS affine serialized as JSON
    orientation_source: Literal["dicom_iop", "identity"]

    # Provenance
    source_path: str  # absolute path of the series directory

    model_config = {"frozen": True}

    def to_obs_dict(self, obs_id: str, obs_subject_id: str) -> dict:
        """Convert to dictionary for obs DataFrame row.

        Args:
            obs_id: Unique observation identifier.
            obs_subject_id: Subject identifier the observation belongs to.

        Returns:
            Flat dict of all model fields plus the two identifiers,
            with tuple fields serialized to strings for TileDB storage.
        """
        data = self.model_dump()
        # Serialize tuples as strings for TileDB storage
        data["voxel_spacing"] = str(data["voxel_spacing"])
        data["dimensions"] = str(data["dimensions"])
        data.update(obs_id=obs_id, obs_subject_id=obs_subject_id)
        return data
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def extract_nifti_metadata(nifti_path: str | Path) -> NiftiMetadata:
    """Extract comprehensive metadata from NIfTI header.

    Args:
        nifti_path: Path to a ``.nii`` or ``.nii.gz`` file.

    Returns:
        A frozen :class:`NiftiMetadata` capturing voxel spacing,
        dimensions, datatype, intensity scaling, units, coordinate-system
        codes, orientation, and the source path.

    Raises:
        FileNotFoundError: If ``nifti_path`` does not exist.
    """
    path = Path(nifti_path)
    if not path.exists():
        raise FileNotFoundError(f"NIfTI file not found: {path}")

    img = nib.load(path)
    header = img.header

    # Extract dimensions: data axes start at dim[1] (dim[0] holds the rank).
    dim = header.get("dim")
    dim_x = int(dim[1]) if len(dim) > 1 else 0
    dim_y = int(dim[2]) if len(dim) > 2 else 0
    dim_z = int(dim[3]) if len(dim) > 3 else 0

    # Extract voxel spacing from pixdim[1:4]; default to 1.0 per axis.
    pixdim = header.get("pixdim")
    voxel_spacing_x = float(pixdim[1]) if len(pixdim) > 1 else 1.0
    voxel_spacing_y = float(pixdim[2]) if len(pixdim) > 2 else 1.0
    voxel_spacing_z = float(pixdim[3]) if len(pixdim) > 3 else 1.0

    # Data type info
    datatype = int(header.get("datatype", 0))
    bitpix = int(header.get("bitpix", 0))

    # Intensity scaling. Normalize degenerate values: nibabel reports an
    # unset slope/intercept as NaN, and the NIfTI-1 spec defines
    # scl_slope == 0 as "no scaling" -- both collapse to the identity
    # transform (slope 1, inter 0) so downstream code can always apply
    # slope * value + inter safely.
    scl_slope = float(header.get("scl_slope", 1.0))
    scl_inter = float(header.get("scl_inter", 0.0))
    if np.isnan(scl_slope) or scl_slope == 0.0:
        scl_slope = 1.0
    if np.isnan(scl_inter):
        scl_inter = 0.0

    # Units
    xyzt_units = int(header.get("xyzt_units", 0))
    spatial_units = _get_spatial_units(xyzt_units)

    # Coordinate system codes
    sform_code = int(header.get("sform_code", 0))
    qform_code = int(header.get("qform_code", 0))

    # Determine orientation source and get affine: prefer sform, then
    # qform, else fall back to the image's default affine.
    if sform_code > 0:
        affine = img.get_sform()
        orientation_source: Literal["nifti_sform", "nifti_qform", "identity"] = "nifti_sform"
    elif qform_code > 0:
        affine = img.get_qform()
        orientation_source = "nifti_qform"
    else:
        affine = img.affine
        orientation_source = "identity"

    # Axis codes (e.g. "RAS") derived from the chosen affine.
    ornt = nib.orientations.io_orientation(affine)
    axcodes = "".join(nib.orientations.ornt2axcodes(ornt))

    return NiftiMetadata(
        voxel_spacing=(voxel_spacing_x, voxel_spacing_y, voxel_spacing_z),
        dimensions=(dim_x, dim_y, dim_z),
        datatype=datatype,
        bitpix=bitpix,
        scl_slope=scl_slope,
        scl_inter=scl_inter,
        xyzt_units=xyzt_units,
        spatial_units=spatial_units,
        qform_code=qform_code,
        sform_code=sform_code,
        axcodes=axcodes,
        affine_json=affine_to_json(affine),
        orientation_source=orientation_source,
        source_path=str(path.absolute()),
    )
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def extract_dicom_metadata(dicom_dir: str | Path) -> DicomMetadata:
    """Extract comprehensive metadata from DICOM series.

    Only the first file in the directory is read, so geometry and
    acquisition tags are assumed uniform across the series.

    Args:
        dicom_dir: Directory containing the DICOM files of one series.

    Returns:
        A frozen :class:`DicomMetadata` for the series.

    Raises:
        FileNotFoundError: If the directory does not exist.
        ValueError: If no DICOM files are found in it.
    """
    # Imported lazily so pydicom is only required when ingesting DICOM.
    import pydicom

    path = Path(dicom_dir)
    if not path.exists():
        raise FileNotFoundError(f"DICOM directory not found: {path}")

    # Find DICOM files: prefer *.dcm; otherwise fall back to any
    # non-hidden regular file in the directory.
    dicom_files = sorted(path.glob("*.dcm"))
    if not dicom_files:
        dicom_files = sorted(
            f for f in path.iterdir() if f.is_file() and not f.name.startswith(".")
        )

    if not dicom_files:
        raise ValueError(f"No DICOM files found in {path}")

    # Read first DICOM for most metadata
    ds = pydicom.dcmread(dicom_files[0])

    # Pixel spacing
    # NOTE(review): DICOM PixelSpacing is (row spacing, column spacing);
    # index 0 is used as "x" below -- confirm the intended axis
    # convention against consumers of voxel_spacing.
    pixel_spacing = getattr(ds, "PixelSpacing", [1.0, 1.0])
    pixel_spacing_x = float(pixel_spacing[0])
    pixel_spacing_y = float(pixel_spacing[1])
    slice_thickness = float(getattr(ds, "SliceThickness", 1.0))

    # Dimensions; the slice count assumes one image file per slice.
    rows = int(getattr(ds, "Rows", 0))
    columns = int(getattr(ds, "Columns", 0))
    n_slices = len(dicom_files)

    # Modality and description
    modality = str(getattr(ds, "Modality", "unknown"))
    series_description = str(getattr(ds, "SeriesDescription", ""))

    # Acquisition parameters (modality-specific); None when the tag is absent.
    kvp = float(ds.KVP) if hasattr(ds, "KVP") else None
    exposure = float(ds.Exposure) if hasattr(ds, "Exposure") else None
    repetition_time = float(ds.RepetitionTime) if hasattr(ds, "RepetitionTime") else None
    echo_time = float(ds.EchoTime) if hasattr(ds, "EchoTime") else None
    magnetic_field_strength = (
        float(ds.MagneticFieldStrength) if hasattr(ds, "MagneticFieldStrength") else None
    )

    # Orientation
    iop = getattr(ds, "ImageOrientationPatient", None)
    ipp = getattr(ds, "ImagePositionPatient", None)

    if iop is not None:
        # Build affine from DICOM tags: the in-plane direction cosines
        # come from ImageOrientationPatient; the through-plane direction
        # is their cross product.
        row_cosines = np.array([float(iop[0]), float(iop[1]), float(iop[2])])
        col_cosines = np.array([float(iop[3]), float(iop[4]), float(iop[5])])
        slice_cosines = np.cross(row_cosines, col_cosines)

        # NOTE(review): spacing order (y, x, thickness) here differs from
        # the (x, y, z) order returned in voxel_spacing below -- verify
        # against the orientation conventions used elsewhere.
        voxel_spacing = [pixel_spacing_y, pixel_spacing_x, slice_thickness]

        affine = np.eye(4)
        affine[:3, 0] = row_cosines * voxel_spacing[0]
        affine[:3, 1] = col_cosines * voxel_spacing[1]
        affine[:3, 2] = slice_cosines * voxel_spacing[2]

        # Translation from the first slice's position, when available.
        if ipp is not None:
            affine[:3, 3] = [float(ipp[0]), float(ipp[1]), float(ipp[2])]

        # Convert DICOM LPS to RAS (negates x and y, translation included)
        lps_to_ras = np.diag([-1, -1, 1, 1])
        affine_ras = lps_to_ras @ affine

        ornt = nib.orientations.io_orientation(affine_ras)
        axcodes = "".join(nib.orientations.ornt2axcodes(ornt))
        orientation_source: Literal["dicom_iop", "identity"] = "dicom_iop"
    else:
        # No orientation tags: fall back to an identity affine in RAS.
        affine_ras = np.eye(4)
        axcodes = "RAS"
        orientation_source = "identity"

    return DicomMetadata(
        voxel_spacing=(pixel_spacing_x, pixel_spacing_y, slice_thickness),
        dimensions=(rows, columns, n_slices),
        modality=modality,
        series_description=series_description,
        kvp=kvp,
        exposure=exposure,
        repetition_time=repetition_time,
        echo_time=echo_time,
        magnetic_field_strength=magnetic_field_strength,
        axcodes=axcodes,
        affine_json=affine_to_json(affine_ras),
        orientation_source=orientation_source,
        source_path=str(path.absolute()),
    )
|
radiobject/indexing.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Shared indexing utilities for RadiObject entities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class Index:
    """Bidirectional index mapping string keys to integer positions."""

    keys: tuple[str, ...]
    key_to_idx: dict[str, int]

    @classmethod
    def build(cls, keys: list[str]) -> Index:
        """Build an index from a list of string keys.

        Raises:
            ValueError: If any key appears more than once.
        """
        counts = Counter(keys)
        if len(counts) != len(keys):
            dupes = [key for key, n in counts.items() if n > 1]
            raise ValueError(f"Duplicate keys detected: {dupes[:5]}")
        return cls(
            keys=tuple(keys),
            key_to_idx={k: i for i, k in enumerate(keys)},
        )

    def __len__(self) -> int:
        return len(self.keys)

    def __contains__(self, key: str) -> bool:
        return key in self.key_to_idx

    def get_index(self, key: str) -> int:
        """Get integer index for a key. Raises KeyError if not found."""
        try:
            return self.key_to_idx[key]
        except KeyError:
            raise KeyError(f"Key '{key}' not found in index") from None

    def get_key(self, idx: int) -> str:
        """Get key at integer index. Raises IndexError if out of bounds."""
        if not 0 <= idx < len(self.keys):
            raise IndexError(f"Index {idx} out of range [0, {len(self.keys)})")
        return self.keys[idx]
|
radiobject/ingest.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""NIfTI discovery utilities for bulk ingestion."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable, Sequence
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from glob import glob
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class NiftiSource:
    """Source NIfTI file with optional paired label."""

    image_path: Path
    subject_id: str
    label_path: Path | None = None

    @property
    def has_label(self) -> bool:
        """True when a label file was paired with this image."""
        return self.label_path is not None


def discover_nifti_pairs(
    image_dir: str | Path,
    label_dir: str | Path | None = None,
    pattern: str = "*.nii.gz",
    subject_id_fn: Callable[[Path], str] | None = None,
) -> list[NiftiSource]:
    """Discover NIfTI files and optionally pair each with a label file.

    Args:
        image_dir: Directory containing image NIfTIs.
        label_dir: Optional directory of label NIfTIs, matched by subject ID.
        pattern: Glob pattern used to locate NIfTI files.
        subject_id_fn: Maps a path to its subject ID; defaults to the
            filename stem with any trailing ".nii" removed.

    Returns:
        One NiftiSource per discovered image, in sorted filename order.

    Raises:
        FileNotFoundError: If image_dir or label_dir does not exist.
        ValueError: If no NIfTI files match the pattern.
    """
    image_root = Path(image_dir)
    if not image_root.exists():
        raise FileNotFoundError(f"Image directory not found: {image_root}")

    if subject_id_fn is None:

        def subject_id_fn(p: Path) -> str:
            stem = p.stem
            return stem[:-4] if stem.endswith(".nii") else stem

    # Collect images; fall back to the uncompressed pattern when the
    # .gz pattern finds nothing.
    image_files = sorted(image_root.glob(pattern))
    if not image_files:
        image_files = sorted(image_root.glob(pattern.replace(".gz", "")))

    if not image_files:
        raise ValueError(f"No NIfTI files found in {image_root} with pattern {pattern}")

    # Build the subject-id -> label-path lookup when labels are requested.
    label_lookup: dict[str, Path] = {}
    if label_dir is not None:
        label_root = Path(label_dir)
        if not label_root.exists():
            raise FileNotFoundError(f"Label directory not found: {label_root}")

        for candidate in label_root.glob(pattern):
            label_lookup[subject_id_fn(candidate)] = candidate

        # Uncompressed labels fill in only where no .gz match exists.
        for candidate in label_root.glob(pattern.replace(".gz", "")):
            label_lookup.setdefault(subject_id_fn(candidate), candidate)

    # Pair each image with its label (if any) by shared subject ID.
    sources: list[NiftiSource] = []
    for image_path in image_files:
        sid = subject_id_fn(image_path)
        sources.append(
            NiftiSource(
                image_path=image_path,
                subject_id=sid,
                label_path=label_lookup.get(sid),
            )
        )
    return sources
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def resolve_nifti_source(
    source: str | Path | Sequence[tuple[str | Path, str]],
    subject_id_fn: Callable[[Path], str] | None = None,
) -> list[tuple[Path, str]]:
    """Resolve various NIfTI source formats to (path, subject_id) tuples.

    Supports:
    - Glob pattern: "./imagesTr/*.nii.gz"
    - Directory path: "./imagesTr"
    - Pre-resolved list: [(path, subject_id), ...]

    Args:
        source: Glob pattern, directory, or pre-resolved (path, id) pairs.
        subject_id_fn: Optional override for deriving subject IDs from
            paths; defaults to the filename stem without ".nii"/".nii.gz".

    Returns:
        List of (path, subject_id) tuples.

    Raises:
        ValueError: If a glob pattern matches no files.
    """
    # Already resolved - return as-is. An empty sequence resolves to no
    # sources; previously it fell through to str(source) == "[]", whose
    # "[" was mistaken for a glob metacharacter and raised a confusing
    # "No files matched pattern" error.
    if isinstance(source, (list, tuple)):
        if not source:
            return []
        if isinstance(source[0], tuple):
            return [(Path(p), sid) for p, sid in source]

    source_str = str(source)

    if subject_id_fn is None:

        def subject_id_fn(p: Path) -> str:
            name = p.stem
            if name.endswith(".nii"):
                name = name[:-4]
            return name

    # Glob pattern: any glob metacharacter triggers pattern matching.
    if any(c in source_str for c in "*?["):
        matched = sorted(glob(source_str, recursive=True))
        if not matched:
            raise ValueError(f"No files matched pattern: {source}")
        return [(Path(f), subject_id_fn(Path(f))) for f in matched]

    # Directory path
    sources = discover_nifti_pairs(source)
    return [(s.image_path, s.subject_id) for s in sources]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""PyTorch training system for RadiObject."""
|
|
2
|
+
|
|
3
|
+
from radiobject.ml.compat import Compose, VolumeCollectionSubjectsDataset
|
|
4
|
+
from radiobject.ml.config import DatasetConfig, LoadingMode
|
|
5
|
+
from radiobject.ml.datasets import SegmentationDataset, VolumeCollectionDataset
|
|
6
|
+
from radiobject.ml.factory import (
|
|
7
|
+
create_inference_dataloader,
|
|
8
|
+
create_segmentation_dataloader,
|
|
9
|
+
create_training_dataloader,
|
|
10
|
+
create_validation_dataloader,
|
|
11
|
+
)
|
|
12
|
+
from radiobject.ml.utils import LabelSource
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"Compose",
|
|
16
|
+
"DatasetConfig",
|
|
17
|
+
"LabelSource",
|
|
18
|
+
"LoadingMode",
|
|
19
|
+
"SegmentationDataset",
|
|
20
|
+
"VolumeCollectionDataset",
|
|
21
|
+
"VolumeCollectionSubjectsDataset",
|
|
22
|
+
"create_inference_dataloader",
|
|
23
|
+
"create_segmentation_dataloader",
|
|
24
|
+
"create_training_dataloader",
|
|
25
|
+
"create_validation_dataloader",
|
|
26
|
+
]
|
radiobject/ml/cache.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Caching utilities for ML datasets."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseCache(ABC):
    """Abstract base class for sample caching."""

    def __init__(self) -> None:
        # Hit/miss counters; concrete caches update these in get().
        self._hits = 0
        self._misses = 0

    @property
    def hits(self) -> int:
        """Number of cache hits."""
        return self._hits

    @property
    def misses(self) -> int:
        """Number of cache misses."""
        return self._misses

    @abstractmethod
    def get(self, key: int) -> Any | None:
        """Get cached sample by key, or None if not cached."""

    @abstractmethod
    def set(self, key: int, value: Any) -> None:
        """Cache a sample."""

    @abstractmethod
    def clear(self) -> None:
        """Clear all cached samples."""


class NoOpCache(BaseCache):
    """Cache that doesn't cache (passthrough)."""

    def get(self, key: int) -> Any | None:
        """Always a miss: nothing is ever stored."""
        self._misses += 1
        return None

    def set(self, key: int, value: Any) -> None:
        """Discard the value."""

    def clear(self) -> None:
        """Nothing to clear."""
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""MONAI/TorchIO compatibility module."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable, Sequence
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from radiobject.ml.compat.torchio import VolumeCollectionSubjectsDataset
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from monai.transforms import Compose
|
|
12
|
+
except ImportError:
|
|
13
|
+
try:
|
|
14
|
+
from torchio import Compose
|
|
15
|
+
except ImportError:
|
|
16
|
+
|
|
17
|
+
class Compose:
|
|
18
|
+
"""Minimal Compose fallback when MONAI/TorchIO unavailable."""
|
|
19
|
+
|
|
20
|
+
def __init__(self, transforms: Sequence[Callable[[Any], Any]]):
|
|
21
|
+
self.transforms = list(transforms)
|
|
22
|
+
|
|
23
|
+
def __call__(self, data: Any) -> Any:
|
|
24
|
+
for t in self.transforms:
|
|
25
|
+
data = t(data)
|
|
26
|
+
return data
|
|
27
|
+
|
|
28
|
+
def __repr__(self) -> str:
|
|
29
|
+
names = [t.__class__.__name__ for t in self.transforms]
|
|
30
|
+
return f"Compose({names})"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
__all__ = ["VolumeCollectionSubjectsDataset", "Compose"]
|