radiobject 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,387 @@
1
+ """Imaging metadata extraction for NIfTI and DICOM files."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Literal
7
+
8
+ import nibabel as nib
9
+ import numpy as np
10
+ from pydantic import BaseModel, Field
11
+
12
+ from radiobject.utils import affine_to_json
13
+
14
+ # BIDS-aligned series type identifiers
15
+ KNOWN_SERIES_TYPES: frozenset[str] = frozenset(
16
+ {
17
+ # Anatomical MRI
18
+ "T1w",
19
+ "T2w",
20
+ "T1rho",
21
+ "T1map",
22
+ "T2map",
23
+ "T2star",
24
+ "FLAIR",
25
+ "FLASH",
26
+ "PD",
27
+ "PDmap",
28
+ "PDT2",
29
+ "inplaneT1",
30
+ "inplaneT2",
31
+ "angio",
32
+ "T1gd",
33
+ # Functional MRI
34
+ "bold",
35
+ "cbv",
36
+ "phase",
37
+ # Diffusion MRI
38
+ "dwi",
39
+ # CT variants
40
+ "CT",
41
+ "HRCT",
42
+ "CTA",
43
+ "CTPA",
44
+ # Field maps
45
+ "phasediff",
46
+ "magnitude",
47
+ "fieldmap",
48
+ "epi",
49
+ }
50
+ )
51
+
52
+ # Filename pattern to series type mapping (ordered from most specific to least specific)
53
+ _FILENAME_PATTERNS: tuple[tuple[str, str], ...] = (
54
+ # Contrast-enhanced T1 (check before T1)
55
+ ("T1GD", "T1gd"),
56
+ ("T1CE", "T1gd"),
57
+ ("T1C", "T1gd"),
58
+ # Standard patterns
59
+ ("T1W", "T1w"),
60
+ ("T1", "T1w"),
61
+ ("T2W", "T2w"),
62
+ ("T2", "T2w"),
63
+ ("FLAIR", "FLAIR"),
64
+ ("DWI", "dwi"),
65
+ ("DTI", "dwi"),
66
+ ("BOLD", "bold"),
67
+ ("FUNC", "bold"),
68
+ # CT patterns (MSD datasets, common naming)
69
+ ("LUNG_", "CT"),
70
+ ("LIVER_", "CT"),
71
+ ("COLON_", "CT"),
72
+ ("PANCREAS_", "CT"),
73
+ ("SPLEEN_", "CT"),
74
+ ("HEPATIC", "CT"),
75
+ ("_CT_", "CT"),
76
+ ("_CT.", "CT"),
77
+ )
78
+
79
+ # Spatial unit mapping from NIfTI xyzt_units
80
+ _SPATIAL_UNIT_MAP: dict[int, str] = {
81
+ 0: "unknown",
82
+ 1: "m",
83
+ 2: "mm",
84
+ 3: "um",
85
+ }
86
+
87
+
88
def infer_series_type(path: Path, header: nib.Nifti1Header | None = None) -> str:
    """Infer the series type for a NIfTI file.

    Resolution order:
        1. BIDS-style suffix: sub-01_ses-01_T1w.nii.gz -> "T1w"
        2. Common substring patterns: T1_MPRAGE.nii.gz -> "T1w"
        3. Header ``descrip`` field, when a header is supplied
        4. Fallback: "unknown"
    """
    base = path.stem
    # Path.stem strips only one extension, so ".nii.gz" leaves a ".nii" tail.
    if base.endswith(".nii"):
        base = base[: -len(".nii")]

    # 1. BIDS suffix: the final underscore-delimited token, case-insensitive.
    tokens = base.split("_")
    if tokens:
        candidate = tokens[-1].lower()
        for series in KNOWN_SERIES_TYPES:
            if series.lower() == candidate:
                return series

    # 2. Substring patterns over the whole filename, most specific first.
    haystack = base.upper()
    matched = next(
        (stype for pat, stype in _FILENAME_PATTERNS if pat in haystack), None
    )
    if matched is not None:
        return matched

    # 3. Header description field, when available.
    if header is not None:
        raw = header.get("descrip", b"")
        text = raw.decode("utf-8", errors="ignore") if isinstance(raw, bytes) else str(raw)
        lowered = text.lower()
        for series in KNOWN_SERIES_TYPES:
            if series.lower() in lowered:
                return series

    return "unknown"
128
+
129
+
130
def _get_spatial_units(xyzt_units: int) -> str:
    """Decode the spatial-unit name from the low 3 bits of ``xyzt_units``."""
    return _SPATIAL_UNIT_MAP.get(xyzt_units & 0b111, "unknown")
134
+
135
+
136
class NiftiMetadata(BaseModel):
    """Validated NIfTI header metadata destined for the obs DataFrame."""

    # Voxel spacing in (x, y, z) order, taken from pixdim[1:4]
    voxel_spacing: tuple[float, float, float]

    # Original volume shape in (x, y, z) order
    dimensions: tuple[int, int, int]

    # NIfTI datatype code and bits per voxel
    datatype: int
    bitpix: int

    # Intensity scaling parameters from the header
    scl_slope: float
    scl_inter: float

    # Raw units code plus the decoded spatial unit name
    xyzt_units: int
    spatial_units: str  # "mm", "um", "m", or "unknown"

    # NIfTI coordinate-system codes
    qform_code: int
    sform_code: int

    # Orientation summary
    axcodes: str  # e.g., "RAS"
    affine_json: str  # 4x4 matrix as JSON
    orientation_source: Literal["nifti_sform", "nifti_qform", "identity"]

    # Provenance
    source_path: str

    model_config = {"frozen": True}

    def to_obs_dict(self, obs_id: str, obs_subject_id: str, series_type: str) -> dict:
        """Flatten this metadata into a row dict for the obs DataFrame."""
        row = self.model_dump()
        # TileDB cannot store tuples directly, so stringify them.
        for tuple_field in ("voxel_spacing", "dimensions"):
            row[tuple_field] = str(row[tuple_field])
        row["obs_id"] = obs_id
        row["obs_subject_id"] = obs_subject_id
        row["series_type"] = series_type
        return row
179
+
180
+
181
class DicomMetadata(BaseModel):
    """Validated DICOM series metadata destined for the obs DataFrame."""

    # Voxel spacing as an (x, y, z) tuple; z is the slice thickness
    voxel_spacing: tuple[float, float, float]

    # Volume shape as (rows, columns, n_slices)
    dimensions: tuple[int, int, int]

    # Modality code (CT, MR, PT, ...) and free-text series description
    modality: str
    series_description: str

    # Modality-specific acquisition parameters (None when not applicable)
    kvp: float | None = Field(default=None)  # CT tube voltage
    exposure: float | None = Field(default=None)  # CT exposure (mAs)
    repetition_time: float | None = Field(default=None)  # MRI TR
    echo_time: float | None = Field(default=None)  # MRI TE
    magnetic_field_strength: float | None = Field(default=None)  # MRI field strength

    # Orientation summary
    axcodes: str
    affine_json: str
    orientation_source: Literal["dicom_iop", "identity"]

    # Provenance
    source_path: str

    model_config = {"frozen": True}

    def to_obs_dict(self, obs_id: str, obs_subject_id: str) -> dict:
        """Flatten this metadata into a row dict for the obs DataFrame."""
        row = self.model_dump()
        # TileDB cannot store tuples directly, so stringify them.
        for tuple_field in ("voxel_spacing", "dimensions"):
            row[tuple_field] = str(row[tuple_field])
        row["obs_id"] = obs_id
        row["obs_subject_id"] = obs_subject_id
        return row
219
+
220
+
221
def extract_nifti_metadata(nifti_path: str | Path) -> NiftiMetadata:
    """Extract comprehensive metadata from a NIfTI header.

    Args:
        nifti_path: Path to a NIfTI file (``.nii`` or ``.nii.gz``).

    Returns:
        A frozen NiftiMetadata describing geometry, datatype, intensity
        scaling, units, and orientation of the volume.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    path = Path(nifti_path)
    if not path.exists():
        raise FileNotFoundError(f"NIfTI file not found: {path}")

    img = nib.load(path)
    header = img.header

    # Extract dimensions (dim[1:4] hold the spatial extents)
    dim = header.get("dim")
    dim_x = int(dim[1]) if len(dim) > 1 else 0
    dim_y = int(dim[2]) if len(dim) > 2 else 0
    dim_z = int(dim[3]) if len(dim) > 3 else 0

    # Extract voxel spacing (pixdim[1:4] hold the spatial step sizes)
    pixdim = header.get("pixdim")
    voxel_spacing_x = float(pixdim[1]) if len(pixdim) > 1 else 1.0
    voxel_spacing_y = float(pixdim[2]) if len(pixdim) > 2 else 1.0
    voxel_spacing_z = float(pixdim[3]) if len(pixdim) > 3 else 1.0

    # Data type info
    datatype = int(header.get("datatype", 0))
    bitpix = int(header.get("bitpix", 0))

    # Intensity scaling
    scl_slope = float(header.get("scl_slope", 1.0))
    scl_inter = float(header.get("scl_inter", 0.0))
    # Normalize degenerate values: nibabel reports NaN for an unset slope,
    # and the NIfTI-1 spec defines scl_slope == 0 as "no scaling" -> use 1.0.
    if np.isnan(scl_slope) or scl_slope == 0.0:
        scl_slope = 1.0
    if np.isnan(scl_inter):
        scl_inter = 0.0

    # Units
    xyzt_units = int(header.get("xyzt_units", 0))
    spatial_units = _get_spatial_units(xyzt_units)

    # Coordinate system codes
    sform_code = int(header.get("sform_code", 0))
    qform_code = int(header.get("qform_code", 0))

    # Determine orientation source and get affine: prefer sform over qform;
    # when neither code is set, fall back to nibabel's img.affine and tag the
    # source "identity".
    if sform_code > 0:
        affine = img.get_sform()
        orientation_source: Literal["nifti_sform", "nifti_qform", "identity"] = "nifti_sform"
    elif qform_code > 0:
        affine = img.get_qform()
        orientation_source = "nifti_qform"
    else:
        affine = img.affine
        orientation_source = "identity"

    # Derive axis codes (e.g. "RAS") from the chosen affine
    ornt = nib.orientations.io_orientation(affine)
    axcodes = "".join(nib.orientations.ornt2axcodes(ornt))

    return NiftiMetadata(
        voxel_spacing=(voxel_spacing_x, voxel_spacing_y, voxel_spacing_z),
        dimensions=(dim_x, dim_y, dim_z),
        datatype=datatype,
        bitpix=bitpix,
        scl_slope=scl_slope,
        scl_inter=scl_inter,
        xyzt_units=xyzt_units,
        spatial_units=spatial_units,
        qform_code=qform_code,
        sform_code=sform_code,
        axcodes=axcodes,
        affine_json=affine_to_json(affine),
        orientation_source=orientation_source,
        source_path=str(path.absolute()),
    )
294
+
295
+
296
def extract_dicom_metadata(dicom_dir: str | Path) -> DicomMetadata:
    """Extract comprehensive metadata from a DICOM series directory.

    Reads the first file of the series for per-slice tags; the slice count
    comes from the number of files in the directory.

    Raises:
        FileNotFoundError: If ``dicom_dir`` does not exist.
        ValueError: If the directory contains no candidate DICOM files.
    """
    import pydicom

    path = Path(dicom_dir)
    if not path.exists():
        raise FileNotFoundError(f"DICOM directory not found: {path}")

    # Find DICOM files; fall back to any non-hidden file when none end in .dcm
    dicom_files = sorted(path.glob("*.dcm"))
    if not dicom_files:
        dicom_files = sorted(
            f for f in path.iterdir() if f.is_file() and not f.name.startswith(".")
        )

    if not dicom_files:
        raise ValueError(f"No DICOM files found in {path}")

    # Read first DICOM for most metadata
    ds = pydicom.dcmread(dicom_files[0])

    # Pixel spacing
    # NOTE(review): DICOM PixelSpacing is (row spacing, column spacing), so
    # indexing [0] as "x" and [1] as "y" may be swapped relative to the
    # DICOM standard's axis naming — confirm against downstream consumers.
    pixel_spacing = getattr(ds, "PixelSpacing", [1.0, 1.0])
    pixel_spacing_x = float(pixel_spacing[0])
    pixel_spacing_y = float(pixel_spacing[1])
    slice_thickness = float(getattr(ds, "SliceThickness", 1.0))

    # Dimensions: in-plane shape from the header, slice count from file count
    rows = int(getattr(ds, "Rows", 0))
    columns = int(getattr(ds, "Columns", 0))
    n_slices = len(dicom_files)

    # Modality and description
    modality = str(getattr(ds, "Modality", "unknown"))
    series_description = str(getattr(ds, "SeriesDescription", ""))

    # Acquisition parameters (modality-specific)
    # NOTE(review): hasattr does not guard against a tag that is present but
    # empty/None, in which case float(...) would raise — confirm inputs.
    kvp = float(ds.KVP) if hasattr(ds, "KVP") else None
    exposure = float(ds.Exposure) if hasattr(ds, "Exposure") else None
    repetition_time = float(ds.RepetitionTime) if hasattr(ds, "RepetitionTime") else None
    echo_time = float(ds.EchoTime) if hasattr(ds, "EchoTime") else None
    magnetic_field_strength = (
        float(ds.MagneticFieldStrength) if hasattr(ds, "MagneticFieldStrength") else None
    )

    # Orientation tags: direction cosines and position of the first slice
    iop = getattr(ds, "ImageOrientationPatient", None)
    ipp = getattr(ds, "ImagePositionPatient", None)

    if iop is not None:
        # Build affine from DICOM tags: first three IOP values are the row
        # direction cosines, last three the column direction cosines; the
        # slice direction is their cross product.
        row_cosines = np.array([float(iop[0]), float(iop[1]), float(iop[2])])
        col_cosines = np.array([float(iop[3]), float(iop[4]), float(iop[5])])
        slice_cosines = np.cross(row_cosines, col_cosines)

        voxel_spacing = [pixel_spacing_y, pixel_spacing_x, slice_thickness]

        # Columns of the affine scale each direction by its voxel spacing
        affine = np.eye(4)
        affine[:3, 0] = row_cosines * voxel_spacing[0]
        affine[:3, 1] = col_cosines * voxel_spacing[1]
        affine[:3, 2] = slice_cosines * voxel_spacing[2]

        if ipp is not None:
            # Translation: position of the first voxel of the first slice
            affine[:3, 3] = [float(ipp[0]), float(ipp[1]), float(ipp[2])]

        # Convert DICOM LPS to RAS by flipping the first two axes
        lps_to_ras = np.diag([-1, -1, 1, 1])
        affine_ras = lps_to_ras @ affine

        ornt = nib.orientations.io_orientation(affine_ras)
        axcodes = "".join(nib.orientations.ornt2axcodes(ornt))
        orientation_source: Literal["dicom_iop", "identity"] = "dicom_iop"
    else:
        # No orientation info: assume identity affine in RAS
        affine_ras = np.eye(4)
        axcodes = "RAS"
        orientation_source = "identity"

    return DicomMetadata(
        voxel_spacing=(pixel_spacing_x, pixel_spacing_y, slice_thickness),
        dimensions=(rows, columns, n_slices),
        modality=modality,
        series_description=series_description,
        kvp=kvp,
        exposure=exposure,
        repetition_time=repetition_time,
        echo_time=echo_time,
        magnetic_field_strength=magnetic_field_strength,
        axcodes=axcodes,
        affine_json=affine_to_json(affine_ras),
        orientation_source=orientation_source,
        source_path=str(path.absolute()),
    )
radiobject/indexing.py ADDED
@@ -0,0 +1,45 @@
1
+ """Shared indexing utilities for RadiObject entities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import Counter
6
+ from dataclasses import dataclass
7
+
8
+
9
@dataclass(frozen=True)
class Index:
    """Immutable two-way mapping between string keys and integer positions."""

    keys: tuple[str, ...]
    key_to_idx: dict[str, int]

    @classmethod
    def build(cls, keys: list[str]) -> Index:
        """Construct an Index from unique string keys, rejecting duplicates."""
        counts = Counter(keys)
        if len(counts) != len(keys):
            duplicates = [k for k, count in counts.items() if count > 1]
            raise ValueError(f"Duplicate keys detected: {duplicates[:5]}")
        return cls(
            keys=tuple(keys),
            key_to_idx=dict(zip(keys, range(len(keys)))),
        )

    def __len__(self) -> int:
        return len(self.keys)

    def __contains__(self, key: str) -> bool:
        return key in self.key_to_idx

    def get_index(self, key: str) -> int:
        """Get integer index for a key. Raises KeyError if not found."""
        try:
            return self.key_to_idx[key]
        except KeyError:
            raise KeyError(f"Key '{key}' not found in index") from None

    def get_key(self, idx: int) -> str:
        """Get key at integer index. Raises IndexError if out of bounds."""
        if not 0 <= idx < len(self.keys):
            raise IndexError(f"Index {idx} out of range [0, {len(self.keys)})")
        return self.keys[idx]
radiobject/ingest.py ADDED
@@ -0,0 +1,132 @@
1
+ """NIfTI discovery utilities for bulk ingestion."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable, Sequence
6
+ from dataclasses import dataclass
7
+ from glob import glob
8
+ from pathlib import Path
9
+
10
+
11
@dataclass(frozen=True)
class NiftiSource:
    """A discovered image NIfTI, optionally paired with a label NIfTI."""

    image_path: Path
    subject_id: str
    label_path: Path | None = None

    @property
    def has_label(self) -> bool:
        """Whether a label file is paired with this image."""
        return self.label_path is not None
22
+
23
+
24
def discover_nifti_pairs(
    image_dir: str | Path,
    label_dir: str | Path | None = None,
    pattern: str = "*.nii.gz",
    subject_id_fn: Callable[[Path], str] | None = None,
) -> list[NiftiSource]:
    """Discover NIfTI images and pair each with a label of the same subject.

    Args:
        image_dir: Directory containing image NIfTIs.
        label_dir: Optional directory containing label NIfTIs, matched by
            subject ID derived from the filename.
        pattern: Glob pattern for locating NIfTI files.
        subject_id_fn: Extracts a subject ID from a path. Defaults to the
            filename stem with any trailing ``.nii`` removed.

    Returns:
        One NiftiSource per discovered image, in sorted filename order.

    Raises:
        FileNotFoundError: If ``image_dir`` or a given ``label_dir`` is missing.
        ValueError: If no images match ``pattern``.
    """
    image_dir = Path(image_dir)
    if not image_dir.exists():
        raise FileNotFoundError(f"Image directory not found: {image_dir}")

    if subject_id_fn is None:

        def subject_id_fn(p: Path) -> str:
            stem = p.stem
            return stem[:-4] if stem.endswith(".nii") else stem

    # Locate images; retry with the non-gzipped variant of the pattern.
    image_files = sorted(image_dir.glob(pattern))
    if not image_files:
        image_files = sorted(image_dir.glob(pattern.replace(".gz", "")))

    if not image_files:
        raise ValueError(f"No NIfTI files found in {image_dir} with pattern {pattern}")

    # Map subject ID -> label path when a label directory was provided.
    labels_by_subject: dict[str, Path] = {}
    if label_dir is not None:
        label_dir = Path(label_dir)
        if not label_dir.exists():
            raise FileNotFoundError(f"Label directory not found: {label_dir}")

        for candidate in label_dir.glob(pattern):
            labels_by_subject[subject_id_fn(candidate)] = candidate

        # Non-gzipped labels fill gaps only; primary-pattern matches win.
        for candidate in label_dir.glob(pattern.replace(".gz", "")):
            labels_by_subject.setdefault(subject_id_fn(candidate), candidate)

    sources: list[NiftiSource] = []
    for image_path in image_files:
        sid = subject_id_fn(image_path)
        sources.append(
            NiftiSource(
                image_path=image_path,
                subject_id=sid,
                label_path=labels_by_subject.get(sid),
            )
        )
    return sources
96
+
97
+
98
def resolve_nifti_source(
    source: str | Path | Sequence[tuple[str | Path, str]],
    subject_id_fn: Callable[[Path], str] | None = None,
) -> list[tuple[Path, str]]:
    """Normalize a NIfTI source specification into (path, subject_id) tuples.

    Accepted forms:
        * Glob pattern: "./imagesTr/*.nii.gz"
        * Directory path: "./imagesTr"
        * Pre-resolved list: [(path, subject_id), ...]
    """
    # Pre-resolved input passes through with paths coerced to Path.
    if isinstance(source, (list, tuple)) and source and isinstance(source[0], tuple):
        return [(Path(item), sid) for item, sid in source]

    if subject_id_fn is None:

        def subject_id_fn(p: Path) -> str:
            stem = p.stem
            return stem[:-4] if stem.endswith(".nii") else stem

    spec = str(source)

    # A spec containing glob metacharacters is expanded as a pattern.
    if any(ch in spec for ch in "*?["):
        hits = sorted(glob(spec, recursive=True))
        if not hits:
            raise ValueError(f"No files matched pattern: {source}")
        return [(Path(h), subject_id_fn(Path(h))) for h in hits]

    # Otherwise treat the spec as a directory of images.
    return [(s.image_path, s.subject_id) for s in discover_nifti_pairs(source)]
@@ -0,0 +1,26 @@
1
+ """PyTorch training system for RadiObject."""
2
+
3
+ from radiobject.ml.compat import Compose, VolumeCollectionSubjectsDataset
4
+ from radiobject.ml.config import DatasetConfig, LoadingMode
5
+ from radiobject.ml.datasets import SegmentationDataset, VolumeCollectionDataset
6
+ from radiobject.ml.factory import (
7
+ create_inference_dataloader,
8
+ create_segmentation_dataloader,
9
+ create_training_dataloader,
10
+ create_validation_dataloader,
11
+ )
12
+ from radiobject.ml.utils import LabelSource
13
+
14
# Public API of radiobject.ml; kept in sorted order.
__all__ = [
    "Compose",
    "DatasetConfig",
    "LabelSource",
    "LoadingMode",
    "SegmentationDataset",
    "VolumeCollectionDataset",
    "VolumeCollectionSubjectsDataset",
    "create_inference_dataloader",
    "create_segmentation_dataloader",
    "create_training_dataloader",
    "create_validation_dataloader",
]
radiobject/ml/cache.py ADDED
@@ -0,0 +1,53 @@
1
+ """Caching utilities for ML datasets."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any
7
+
8
+
9
class BaseCache(ABC):
    """Interface for caching dataset samples keyed by integer index.

    Subclasses implement get/set/clear and update the inherited hit/miss
    counters (``self._hits`` / ``self._misses``) as appropriate.
    """

    def __init__(self) -> None:
        self._hits = 0
        self._misses = 0

    @property
    def hits(self) -> int:
        """Count of successful cache lookups."""
        return self._hits

    @property
    def misses(self) -> int:
        """Count of failed cache lookups."""
        return self._misses

    @abstractmethod
    def get(self, key: int) -> Any | None:
        """Return the cached sample for ``key``, or None on a miss."""

    @abstractmethod
    def set(self, key: int, value: Any) -> None:
        """Store ``value`` under ``key``."""

    @abstractmethod
    def clear(self) -> None:
        """Drop every cached sample."""
40
+
41
+
42
class NoOpCache(BaseCache):
    """Pass-through cache: stores nothing, so every lookup is a miss."""

    def get(self, key: int) -> Any | None:
        # Every lookup misses by construction.
        self._misses += 1
        return None

    def set(self, key: int, value: Any) -> None:
        """Intentionally discard the sample."""
        return None

    def clear(self) -> None:
        """Nothing is ever stored, so there is nothing to clear."""
        return None
@@ -0,0 +1,33 @@
1
+ """MONAI/TorchIO compatibility module."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable, Sequence
6
+ from typing import Any
7
+
8
+ from radiobject.ml.compat.torchio import VolumeCollectionSubjectsDataset
9
+
10
# Resolve Compose at import time: prefer MONAI's implementation, fall back to
# TorchIO's, and as a last resort define a minimal sequential-apply class so
# this module imports cleanly without either optional dependency.
try:
    from monai.transforms import Compose
except ImportError:
    try:
        from torchio import Compose
    except ImportError:

        class Compose:
            """Minimal Compose fallback when MONAI/TorchIO unavailable."""

            def __init__(self, transforms: Sequence[Callable[[Any], Any]]):
                self.transforms = list(transforms)

            def __call__(self, data: Any) -> Any:
                # Apply each transform in order, feeding outputs forward.
                for t in self.transforms:
                    data = t(data)
                return data

            def __repr__(self) -> str:
                names = [t.__class__.__name__ for t in self.transforms]
                return f"Compose({names})"
31
+
32
+
33
# Public names re-exported from radiobject.ml.compat.
__all__ = ["VolumeCollectionSubjectsDataset", "Compose"]