torchio 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchio/__init__.py +48 -0
- torchio/cli/__init__.py +0 -0
- torchio/cli/apply_transform.py +116 -0
- torchio/cli/print_info.py +56 -0
- torchio/constants.py +36 -0
- torchio/data/__init__.py +29 -0
- torchio/data/dataset.py +158 -0
- torchio/data/image.py +922 -0
- torchio/data/inference/__init__.py +7 -0
- torchio/data/inference/aggregator.py +251 -0
- torchio/data/io.py +465 -0
- torchio/data/loader.py +63 -0
- torchio/data/queue.py +393 -0
- torchio/data/sampler/__init__.py +15 -0
- torchio/data/sampler/grid.py +174 -0
- torchio/data/sampler/label.py +151 -0
- torchio/data/sampler/sampler.py +115 -0
- torchio/data/sampler/uniform.py +32 -0
- torchio/data/sampler/weighted.py +249 -0
- torchio/data/subject.py +453 -0
- torchio/datasets/__init__.py +45 -0
- torchio/datasets/bite.py +93 -0
- torchio/datasets/episurg.py +147 -0
- torchio/datasets/fpg.py +239 -0
- torchio/datasets/itk_snap/__init__.py +9 -0
- torchio/datasets/itk_snap/itk_snap.py +83 -0
- torchio/datasets/ixi.py +242 -0
- torchio/datasets/medmnist.py +77 -0
- torchio/datasets/mni/__init__.py +11 -0
- torchio/datasets/mni/colin.py +136 -0
- torchio/datasets/mni/icbm.py +87 -0
- torchio/datasets/mni/mni.py +16 -0
- torchio/datasets/mni/pediatric.py +81 -0
- torchio/datasets/mni/sheep.py +33 -0
- torchio/datasets/rsna_miccai.py +120 -0
- torchio/datasets/rsna_spine_fracture.py +150 -0
- torchio/datasets/slicer.py +68 -0
- torchio/download.py +173 -0
- torchio/external/__init__.py +0 -0
- torchio/external/due.py +70 -0
- torchio/py.typed +0 -0
- torchio/reference.py +40 -0
- torchio/transforms/__init__.py +117 -0
- torchio/transforms/augmentation/__init__.py +5 -0
- torchio/transforms/augmentation/composition.py +158 -0
- torchio/transforms/augmentation/intensity/__init__.py +39 -0
- torchio/transforms/augmentation/intensity/random_bias_field.py +175 -0
- torchio/transforms/augmentation/intensity/random_blur.py +113 -0
- torchio/transforms/augmentation/intensity/random_gamma.py +155 -0
- torchio/transforms/augmentation/intensity/random_ghosting.py +253 -0
- torchio/transforms/augmentation/intensity/random_labels_to_image.py +484 -0
- torchio/transforms/augmentation/intensity/random_motion.py +300 -0
- torchio/transforms/augmentation/intensity/random_noise.py +119 -0
- torchio/transforms/augmentation/intensity/random_spike.py +172 -0
- torchio/transforms/augmentation/intensity/random_swap.py +218 -0
- torchio/transforms/augmentation/random_transform.py +59 -0
- torchio/transforms/augmentation/spatial/__init__.py +17 -0
- torchio/transforms/augmentation/spatial/random_affine.py +470 -0
- torchio/transforms/augmentation/spatial/random_anisotropy.py +122 -0
- torchio/transforms/augmentation/spatial/random_elastic_deformation.py +342 -0
- torchio/transforms/augmentation/spatial/random_flip.py +135 -0
- torchio/transforms/data_parser.py +149 -0
- torchio/transforms/fourier.py +34 -0
- torchio/transforms/intensity_transform.py +42 -0
- torchio/transforms/interpolation.py +59 -0
- torchio/transforms/lambda_transform.py +71 -0
- torchio/transforms/preprocessing/__init__.py +41 -0
- torchio/transforms/preprocessing/intensity/__init__.py +5 -0
- torchio/transforms/preprocessing/intensity/clamp.py +60 -0
- torchio/transforms/preprocessing/intensity/histogram_standardization.py +309 -0
- torchio/transforms/preprocessing/intensity/mask.py +101 -0
- torchio/transforms/preprocessing/intensity/normalization_transform.py +60 -0
- torchio/transforms/preprocessing/intensity/rescale.py +131 -0
- torchio/transforms/preprocessing/intensity/z_normalization.py +55 -0
- torchio/transforms/preprocessing/label/__init__.py +0 -0
- torchio/transforms/preprocessing/label/contour.py +26 -0
- torchio/transforms/preprocessing/label/keep_largest_component.py +35 -0
- torchio/transforms/preprocessing/label/label_transform.py +26 -0
- torchio/transforms/preprocessing/label/one_hot.py +45 -0
- torchio/transforms/preprocessing/label/remap_labels.py +187 -0
- torchio/transforms/preprocessing/label/remove_labels.py +74 -0
- torchio/transforms/preprocessing/label/sequential_labels.py +61 -0
- torchio/transforms/preprocessing/spatial/__init__.py +0 -0
- torchio/transforms/preprocessing/spatial/bounds_transform.py +20 -0
- torchio/transforms/preprocessing/spatial/copy_affine.py +86 -0
- torchio/transforms/preprocessing/spatial/crop.py +58 -0
- torchio/transforms/preprocessing/spatial/crop_or_pad.py +287 -0
- torchio/transforms/preprocessing/spatial/ensure_shape_multiple.py +138 -0
- torchio/transforms/preprocessing/spatial/pad.py +112 -0
- torchio/transforms/preprocessing/spatial/resample.py +323 -0
- torchio/transforms/preprocessing/spatial/resize.py +78 -0
- torchio/transforms/preprocessing/spatial/to_canonical.py +48 -0
- torchio/transforms/spatial_transform.py +17 -0
- torchio/transforms/transform.py +573 -0
- torchio/typing.py +45 -0
- torchio/utils.py +440 -0
- torchio/visualization.py +258 -0
- torchio-0.20.1.dist-info/METADATA +513 -0
- torchio-0.20.1.dist-info/RECORD +102 -0
- torchio-0.20.1.dist-info/WHEEL +4 -0
- torchio-0.20.1.dist-info/entry_points.txt +4 -0
- torchio-0.20.1.dist-info/licenses/LICENSE +201 -0
torchio/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Top-level package for torchio."""
|
|
2
|
+
|
|
3
|
+
# Package metadata
__author__ = """Fernando Perez-Garcia"""
__email__ = 'fepegar@gmail.com'
__version__ = '0.20.1'


# Submodules made available as attributes of the package
# (``torchio.datasets``, ``torchio.reference``, ``torchio.utils``).
from . import datasets
from . import reference
from . import utils
# The names exported by the constants module are re-exported at package level.
from .constants import *  # noqa: F401, F403
# Core data-handling classes, re-exported from ``torchio.data``.
from .data import GridAggregator
from .data import GridSampler
from .data import Image
from .data import LabelMap
from .data import LabelSampler
from .data import Queue
from .data import ScalarImage
from .data import Subject
from .data import SubjectsDataset
from .data import SubjectsLoader
from .data import UniformSampler
from .data import WeightedSampler
from .data import inference
from .data import io
from .data import sampler
# The names exported by the transforms subpackage are re-exported at package
# level.
from .transforms import *  # noqa: F401, F403

# NOTE: ``__all__`` only controls ``from torchio import *``. The names pulled
# in by the two star imports above (constants and transforms) are not listed
# here, so they are reachable as ``torchio.<name>`` but are NOT exported by a
# star import of this package.
__all__ = [
    'utils',
    'io',
    'sampler',
    'inference',
    'SubjectsDataset',
    'SubjectsLoader',
    'Image',
    'ScalarImage',
    'LabelMap',
    'Queue',
    'Subject',
    'datasets',
    'reference',
    'WeightedSampler',
    'UniformSampler',
    'LabelSampler',
    'GridSampler',
    'GridAggregator',
]
|
torchio/cli/__init__.py
ADDED
|
File without changes
|
|
torchio/cli/apply_transform.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# pylint: disable=import-outside-toplevel
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import typer
|
|
6
|
+
from rich.progress import Progress
|
|
7
|
+
from rich.progress import SpinnerColumn
|
|
8
|
+
from rich.progress import TextColumn
|
|
9
|
+
|
|
10
|
+
# Typer application object; ``main`` below is registered on it through the
# ``@app.command()`` decorator.
app = typer.Typer()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        exists=True,
        file_okay=True,
        dir_okay=True,
        readable=True,
    ),
    transform_name: str = typer.Argument(...),  # noqa: B008
    output_path: Path = typer.Argument(  # noqa: B008
        ...,
        file_okay=True,
        dir_okay=False,
        writable=True,
    ),
    kwargs: str = typer.Option(  # noqa: B008
        None,
        '--kwargs',
        '-k',
        help='String of kwargs, e.g. "degrees=(-5,15) num_transforms=3".',
    ),
    imclass: str = typer.Option(  # noqa: B008
        'ScalarImage',
        '--imclass',
        '-c',
        help=(
            'Name of the subclass of torchio.Image'
            ' that will be used to instantiate the image.'
        ),
    ),
    seed: int = typer.Option(  # noqa: B008
        None,
        '--seed',
        '-s',
        help='Seed for PyTorch random number generator.',
    ),
    verbose: bool = typer.Option(  # noqa: B008
        False,
        help='Print random transform parameters.',
    ),
    show_progress: bool = typer.Option(  # noqa: B008
        True,
        '--show-progress/--hide-progress',
        '-p/-P',
        help='Show animations indicating progress.',
    ),
):
    """Apply transform to an image.

    Example:
    $ tiotr input.nrrd RandomMotion output.nii -k "degrees=(-5,15) num_transforms=3" --verbose
    """
    # NOTE: this docstring is also the command's --help text. The transform
    # arguments go through the --kwargs/-k option (they are not positional),
    # and verbosity is enabled with --verbose (no short flag is declared).

    # Imports are placed here so that the tool loads faster if not being run
    import torch

    import torchio.transforms as transforms
    from torchio.utils import apply_transform_to_file

    # Resolve the transform class by its name, e.g. 'RandomMotion'
    try:
        transform_class = getattr(transforms, transform_name)
    except AttributeError as error:
        message = f'Transform "{transform_name}" not found in torchio'
        raise ValueError(message) from error

    params_dict = get_params_dict_from_kwargs(kwargs)
    transform = transform_class(**params_dict)
    # The seed is set after instantiating the transform and before applying it
    if seed is not None:
        torch.manual_seed(seed)
    # Transient spinner shown while the transform runs (disabled by -P)
    with Progress(
        SpinnerColumn(),
        TextColumn('[progress.description]{task.description}'),
        transient=True,
        disable=not show_progress,
    ) as progress:
        progress.add_task('Applying transform', total=1)
        apply_transform_to_file(
            input_path,
            transform,
            output_path,
            verbose=verbose,
            class_=imclass,
        )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_params_dict_from_kwargs(kwargs):
    """Parse the ``--kwargs`` command-line string into keyword arguments.

    The string is split on whitespace and each token must have the form
    ``key=value``. Everything after the first ``=`` is the value, which is
    converted with :func:`torchio.utils.guess_type`.

    Args:
        kwargs: String such as ``"degrees=(-5,15) num_transforms=3"``, or
            ``None``.

    Returns:
        Dictionary mapping each key to its converted value. It is empty if
        ``kwargs`` is ``None`` or contains no tokens. If a key is repeated,
        the last occurrence wins.

    Raises:
        ValueError: If a token has no ``=`` or has an empty key.
    """
    # Validate every token before importing torchio, so that malformed input
    # fails fast with a clear message.
    pairs = []
    for substring in (kwargs or '').split():
        # Split on the first '=' only, so values may themselves contain '='
        key, separator, value_string = substring.partition('=')
        if not separator or not key:
            message = f'Arguments string "{kwargs}" not valid'
            raise ValueError(message)
        pairs.append((key, value_string))

    if not pairs:
        return {}

    from torchio.utils import guess_type

    return {key: guess_type(value_string) for key, value_string in pairs}
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Run the Typer application when this file is executed directly as a script.
if __name__ == '__main__':
    app()
|
|
torchio/cli/print_info.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# pylint: disable=import-outside-toplevel
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import typer
|
|
5
|
+
|
|
6
|
+
# Typer application object; ``main`` below is registered on it through the
# ``@app.command()`` decorator.
app = typer.Typer()
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        exists=True,
        file_okay=True,
        dir_okay=True,
        readable=True,
    ),
    plot: bool = typer.Option(  # noqa: B008
        False,
        '--plot/--no-plot',
        '-p/-P',
        help='Plot the image using Matplotlib or Pillow.',
    ),
    show: bool = typer.Option(  # noqa: B008
        False,
        '--show/--no-show',
        '-s/-S',
        help='Show the image using specialized visualisation software.',
    ),
    # The short flags here must not reuse '-s', which already belongs to
    # --show: when two options declare the same short flag, click keeps only
    # one of them, so '-s' silently stopped meaning --show.
    label: bool = typer.Option(  # noqa: B008
        False,
        '--label/--scalar',
        '-l/-L',
        help='Use torchio.LabelMap to instantiate the image.',
    ),
):
    """Print information about an image and, optionally, show it.

    Example:
    $ tiohd input.nii.gz
    """
    # Imports are placed here so that the tool loads faster if not being run
    import torchio as tio

    # LabelMap is used only when --label is passed; ScalarImage otherwise
    class_ = tio.LabelMap if label else tio.ScalarImage
    image = class_(input_path)
    image.load()
    print(image)  # noqa: T201
    if plot:
        image.plot()
    if show:
        image.show()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Run the Typer application when this file is executed directly as a script.
if __name__ == '__main__':
    app()
|
torchio/constants.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
# Image types
INTENSITY = 'intensity'
LABEL = 'label'
SAMPLING_MAP = 'sampling_map'

# Keys for dataset samples
PATH = 'path'
TYPE = 'type'
STEM = 'stem'
DATA = 'data'
AFFINE = 'affine'
TENSOR = 'tensor'

# For aggregator
IMAGE = 'image'
LOCATION = 'location'

# For special collate function
HISTORY = 'history'

# In PyTorch convention
# NOTE(review): presumably the channels axis of batched (N, C, ...) tensors;
# confirm against the call sites before relying on it for unbatched data.
CHANNELS_DIMENSION = 1

# Code repository
REPO_URL = 'https://github.com/fepegar/torchio/'

# Data repository
DATA_REPO = 'https://github.com/fepegar/torchio-data/raw/main/data/'

# Floating point error
# NOTE: despite its name, this is the float32 machine epsilon (the smallest
# value such that 1.0 + eps != 1.0, about 1.19e-7). It is NOT the smallest
# positive float32 value.
MIN_FLOAT_32 = torch.finfo(torch.float32).eps

# For the queue
NUM_SAMPLES = 'num_samples'
|
torchio/data/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Re-export the main data-handling classes so that they can be imported
# directly from ``torchio.data``.
from .dataset import SubjectsDataset
from .image import Image
from .image import LabelMap
from .image import ScalarImage
from .inference import GridAggregator
from .loader import SubjectsLoader
from .queue import Queue
from .sampler import GridSampler
from .sampler import LabelSampler
from .sampler import PatchSampler
from .sampler import UniformSampler
from .sampler import WeightedSampler
from .subject import Subject

# Names exported by ``from torchio.data import *``.
__all__ = [
    'Queue',
    'Subject',
    'SubjectsDataset',
    'SubjectsLoader',
    'Image',
    'ScalarImage',
    'LabelMap',
    'GridSampler',
    'GridAggregator',
    'PatchSampler',
    'LabelSampler',
    'WeightedSampler',
    'UniformSampler',
]
|
torchio/data/dataset.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
from typing import Callable
|
|
5
|
+
from typing import Iterable
|
|
6
|
+
from typing import List
|
|
7
|
+
from typing import Optional
|
|
8
|
+
from typing import Sequence
|
|
9
|
+
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
from ..utils import get_subjects_from_batch
|
|
13
|
+
from .subject import Subject
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SubjectsDataset(Dataset):
    """Base TorchIO dataset.

    Reader of 3D medical images that directly inherits from the PyTorch
    :class:`~torch.utils.data.Dataset`. It can be used with a
    :class:`~torchio.SubjectsLoader` for efficient loading and
    augmentation. It receives a list of instances of :class:`~torchio.Subject`
    and an optional transform applied to the volumes after loading.

    Args:
        subjects: List of instances of :class:`~torchio.Subject`. Iterables
            that are not sequences (e.g. generators) are copied into a list
            first; sequences such as lists are stored as given.
        transform: An instance of :class:`~torchio.transforms.Transform`
            that will be applied to each subject.
        load_getitem: Load all subject images before returning it in
            :meth:`__getitem__`. Set it to ``False`` if some of the images will
            not be needed during training.

    Raises:
        TypeError: If ``subjects`` is not iterable or contains elements that
            are not instances of :class:`~torchio.Subject`.
        ValueError: If ``subjects`` is empty or ``transform`` is neither
            ``None`` nor callable.

    Example:
        >>> import torchio as tio
        >>> subject_a = tio.Subject(
        ...     t1=tio.ScalarImage('t1.nrrd',),
        ...     t2=tio.ScalarImage('t2.mha',),
        ...     label=tio.LabelMap('t1_seg.nii.gz'),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = tio.Subject(
        ...     t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
        ...     t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
        ...     label=tio.LabelMap('colin27_seg1.nii.gz'),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     tio.RescaleIntensity(out_min_max=(0, 1)),
        ...     tio.RandomAffine(),
        ... ]
        >>> transform = tio.Compose(transforms)
        >>> subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
        >>> subject = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html

    .. tip:: To quickly iterate over the subjects without loading the images,
        use :meth:`dry_iter()`.
    """

    def __init__(
        self,
        subjects: Sequence[Subject],
        transform: Optional[Callable] = None,
        load_getitem: bool = True,
    ):
        # One-shot iterables (e.g. generators) would be exhausted by the
        # validation below and support neither len() nor indexing, so they
        # are materialized. Sequences such as lists keep their identity.
        if isinstance(subjects, Iterable) and not isinstance(subjects, Sequence):
            subjects = list(subjects)
        self._parse_subjects_list(subjects)
        self._subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)
        self.load_getitem = load_getitem

    def __len__(self):
        return len(self._subjects)

    def __getitem__(self, index: int) -> Subject:
        # Accept anything convertible to int (e.g. NumPy or PyTorch integer
        # scalars). ValueError covers non-numeric strings such as 'abc';
        # RuntimeError presumably comes from tensors that are not scalars.
        try:
            index = int(index)
        except (RuntimeError, TypeError, ValueError) as err:
            message = (
                f'Index "{index}" must be int or compatible dtype,'
                f' but an object of type "{type(index)}" was passed'
            )
            raise ValueError(message) from err

        subject = self._subjects[index]
        # Work on a copy so that loading/transforming never mutates the stored
        # subject; this is cheap as long as its images have not been loaded
        subject = copy.deepcopy(subject)
        if self.load_getitem:
            subject.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject

    @classmethod
    def from_batch(cls, batch: dict) -> SubjectsDataset:
        """Instantiate a dataset from a batch generated by a data loader.

        Args:
            batch: Dictionary generated by a data loader, containing data that
                can be converted to instances of :class:`~torchio.Subject`.
        """
        subjects: List[Subject] = get_subjects_from_batch(batch)
        return cls(subjects)

    def dry_iter(self):
        """Return the internal list of subjects.

        This can be used to iterate over the subjects without loading the data
        and applying any transforms::

        >>> names = [subject.name for subject in dataset.dry_iter()]
        """
        return self._subjects

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the :attr:`transform` attribute.

        Args:
            transform: Callable object, typically an instance of a subclass of
                :class:`torchio.transforms.Transform`, or ``None`` to disable
                transforms.

        Raises:
            ValueError: If ``transform`` is neither ``None`` nor callable.
        """
        if transform is not None and not callable(transform):
            message = (
                'The transform must be a callable object,'
                f' but it has type {type(transform)}'
            )
            raise ValueError(message)
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Iterable[Subject]) -> None:
        # Check that it's an iterable
        try:
            iter(subjects_list)
        except TypeError as e:
            message = f'Subject list must be an iterable, not {type(subjects_list)}'
            raise TypeError(message) from e

        # Check that it's not empty
        if not subjects_list:
            raise ValueError('Subjects list is empty')

        # Check each element
        for subject in subjects_list:
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)
|