torchio 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. torchio/__init__.py +48 -0
  2. torchio/cli/__init__.py +0 -0
  3. torchio/cli/apply_transform.py +116 -0
  4. torchio/cli/print_info.py +56 -0
  5. torchio/constants.py +36 -0
  6. torchio/data/__init__.py +29 -0
  7. torchio/data/dataset.py +158 -0
  8. torchio/data/image.py +922 -0
  9. torchio/data/inference/__init__.py +7 -0
  10. torchio/data/inference/aggregator.py +251 -0
  11. torchio/data/io.py +465 -0
  12. torchio/data/loader.py +63 -0
  13. torchio/data/queue.py +393 -0
  14. torchio/data/sampler/__init__.py +15 -0
  15. torchio/data/sampler/grid.py +174 -0
  16. torchio/data/sampler/label.py +151 -0
  17. torchio/data/sampler/sampler.py +115 -0
  18. torchio/data/sampler/uniform.py +32 -0
  19. torchio/data/sampler/weighted.py +249 -0
  20. torchio/data/subject.py +453 -0
  21. torchio/datasets/__init__.py +45 -0
  22. torchio/datasets/bite.py +93 -0
  23. torchio/datasets/episurg.py +147 -0
  24. torchio/datasets/fpg.py +239 -0
  25. torchio/datasets/itk_snap/__init__.py +9 -0
  26. torchio/datasets/itk_snap/itk_snap.py +83 -0
  27. torchio/datasets/ixi.py +242 -0
  28. torchio/datasets/medmnist.py +77 -0
  29. torchio/datasets/mni/__init__.py +11 -0
  30. torchio/datasets/mni/colin.py +136 -0
  31. torchio/datasets/mni/icbm.py +87 -0
  32. torchio/datasets/mni/mni.py +16 -0
  33. torchio/datasets/mni/pediatric.py +81 -0
  34. torchio/datasets/mni/sheep.py +33 -0
  35. torchio/datasets/rsna_miccai.py +120 -0
  36. torchio/datasets/rsna_spine_fracture.py +150 -0
  37. torchio/datasets/slicer.py +68 -0
  38. torchio/download.py +173 -0
  39. torchio/external/__init__.py +0 -0
  40. torchio/external/due.py +70 -0
  41. torchio/py.typed +0 -0
  42. torchio/reference.py +40 -0
  43. torchio/transforms/__init__.py +117 -0
  44. torchio/transforms/augmentation/__init__.py +5 -0
  45. torchio/transforms/augmentation/composition.py +158 -0
  46. torchio/transforms/augmentation/intensity/__init__.py +39 -0
  47. torchio/transforms/augmentation/intensity/random_bias_field.py +175 -0
  48. torchio/transforms/augmentation/intensity/random_blur.py +113 -0
  49. torchio/transforms/augmentation/intensity/random_gamma.py +155 -0
  50. torchio/transforms/augmentation/intensity/random_ghosting.py +253 -0
  51. torchio/transforms/augmentation/intensity/random_labels_to_image.py +484 -0
  52. torchio/transforms/augmentation/intensity/random_motion.py +300 -0
  53. torchio/transforms/augmentation/intensity/random_noise.py +119 -0
  54. torchio/transforms/augmentation/intensity/random_spike.py +172 -0
  55. torchio/transforms/augmentation/intensity/random_swap.py +218 -0
  56. torchio/transforms/augmentation/random_transform.py +59 -0
  57. torchio/transforms/augmentation/spatial/__init__.py +17 -0
  58. torchio/transforms/augmentation/spatial/random_affine.py +470 -0
  59. torchio/transforms/augmentation/spatial/random_anisotropy.py +122 -0
  60. torchio/transforms/augmentation/spatial/random_elastic_deformation.py +342 -0
  61. torchio/transforms/augmentation/spatial/random_flip.py +135 -0
  62. torchio/transforms/data_parser.py +149 -0
  63. torchio/transforms/fourier.py +34 -0
  64. torchio/transforms/intensity_transform.py +42 -0
  65. torchio/transforms/interpolation.py +59 -0
  66. torchio/transforms/lambda_transform.py +71 -0
  67. torchio/transforms/preprocessing/__init__.py +41 -0
  68. torchio/transforms/preprocessing/intensity/__init__.py +5 -0
  69. torchio/transforms/preprocessing/intensity/clamp.py +60 -0
  70. torchio/transforms/preprocessing/intensity/histogram_standardization.py +309 -0
  71. torchio/transforms/preprocessing/intensity/mask.py +101 -0
  72. torchio/transforms/preprocessing/intensity/normalization_transform.py +60 -0
  73. torchio/transforms/preprocessing/intensity/rescale.py +131 -0
  74. torchio/transforms/preprocessing/intensity/z_normalization.py +55 -0
  75. torchio/transforms/preprocessing/label/__init__.py +0 -0
  76. torchio/transforms/preprocessing/label/contour.py +26 -0
  77. torchio/transforms/preprocessing/label/keep_largest_component.py +35 -0
  78. torchio/transforms/preprocessing/label/label_transform.py +26 -0
  79. torchio/transforms/preprocessing/label/one_hot.py +45 -0
  80. torchio/transforms/preprocessing/label/remap_labels.py +187 -0
  81. torchio/transforms/preprocessing/label/remove_labels.py +74 -0
  82. torchio/transforms/preprocessing/label/sequential_labels.py +61 -0
  83. torchio/transforms/preprocessing/spatial/__init__.py +0 -0
  84. torchio/transforms/preprocessing/spatial/bounds_transform.py +20 -0
  85. torchio/transforms/preprocessing/spatial/copy_affine.py +86 -0
  86. torchio/transforms/preprocessing/spatial/crop.py +58 -0
  87. torchio/transforms/preprocessing/spatial/crop_or_pad.py +287 -0
  88. torchio/transforms/preprocessing/spatial/ensure_shape_multiple.py +138 -0
  89. torchio/transforms/preprocessing/spatial/pad.py +112 -0
  90. torchio/transforms/preprocessing/spatial/resample.py +323 -0
  91. torchio/transforms/preprocessing/spatial/resize.py +78 -0
  92. torchio/transforms/preprocessing/spatial/to_canonical.py +48 -0
  93. torchio/transforms/spatial_transform.py +17 -0
  94. torchio/transforms/transform.py +573 -0
  95. torchio/typing.py +45 -0
  96. torchio/utils.py +440 -0
  97. torchio/visualization.py +258 -0
  98. torchio-0.20.1.dist-info/METADATA +513 -0
  99. torchio-0.20.1.dist-info/RECORD +102 -0
  100. torchio-0.20.1.dist-info/WHEEL +4 -0
  101. torchio-0.20.1.dist-info/entry_points.txt +4 -0
  102. torchio-0.20.1.dist-info/licenses/LICENSE +201 -0
torchio/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ """Top-level package for torchio."""
2
+
3
+ __author__ = """Fernando Perez-Garcia"""
4
+ __email__ = 'fepegar@gmail.com'
5
+ __version__ = '0.20.1'
6
+
7
+
8
+ from . import datasets
9
+ from . import reference
10
+ from . import utils
11
+ from .constants import * # noqa: F401, F403
12
+ from .data import GridAggregator
13
+ from .data import GridSampler
14
+ from .data import Image
15
+ from .data import LabelMap
16
+ from .data import LabelSampler
17
+ from .data import Queue
18
+ from .data import ScalarImage
19
+ from .data import Subject
20
+ from .data import SubjectsDataset
21
+ from .data import SubjectsLoader
22
+ from .data import UniformSampler
23
+ from .data import WeightedSampler
24
+ from .data import inference
25
+ from .data import io
26
+ from .data import sampler
27
+ from .transforms import * # noqa: F401, F403
28
+
29
+ __all__ = [
30
+ 'utils',
31
+ 'io',
32
+ 'sampler',
33
+ 'inference',
34
+ 'SubjectsDataset',
35
+ 'SubjectsLoader',
36
+ 'Image',
37
+ 'ScalarImage',
38
+ 'LabelMap',
39
+ 'Queue',
40
+ 'Subject',
41
+ 'datasets',
42
+ 'reference',
43
+ 'WeightedSampler',
44
+ 'UniformSampler',
45
+ 'LabelSampler',
46
+ 'GridSampler',
47
+ 'GridAggregator',
48
+ ]
File without changes
@@ -0,0 +1,116 @@
1
+ # pylint: disable=import-outside-toplevel
2
+
3
+ from pathlib import Path
4
+
5
+ import typer
6
+ from rich.progress import Progress
7
+ from rich.progress import SpinnerColumn
8
+ from rich.progress import TextColumn
9
+
10
+ app = typer.Typer()
11
+
12
+
13
@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        exists=True,
        file_okay=True,
        dir_okay=True,
        readable=True,
    ),
    transform_name: str = typer.Argument(...),  # noqa: B008
    output_path: Path = typer.Argument(  # noqa: B008
        ...,
        file_okay=True,
        dir_okay=False,
        writable=True,
    ),
    kwargs: str = typer.Option(  # noqa: B008
        None,
        '--kwargs',
        '-k',
        help='String of kwargs, e.g. "degrees=(-5,15) num_transforms=3".',
    ),
    imclass: str = typer.Option(  # noqa: B008
        'ScalarImage',
        '--imclass',
        '-c',
        help=(
            'Name of the subclass of torchio.Image'
            ' that will be used to instantiate the image.'
        ),
    ),
    seed: int = typer.Option(  # noqa: B008
        None,
        '--seed',
        '-s',
        help='Seed for PyTorch random number generator.',
    ),
    verbose: bool = typer.Option(  # noqa: B008
        False,
        help='Print random transform parameters.',
    ),
    show_progress: bool = typer.Option(  # noqa: B008
        True,
        '--show-progress/--hide-progress',
        '-p/-P',
        help='Show animations indicating progress.',
    ),
):
    """Apply transform to an image.

    Example:
    $ tiotr input.nrrd RandomMotion output.nii --kwargs "degrees=(-5,15) num_transforms=3" --verbose
    """
    # Imports are placed here so that the tool loads faster if not being run
    import torch

    import torchio.transforms as transforms
    from torchio.utils import apply_transform_to_file

    # Resolve the transform class by name; report unknown names as a
    # ValueError that mentions the requested transform
    try:
        transform_class = getattr(transforms, transform_name)
    except AttributeError as error:
        message = f'Transform "{transform_name}" not found in torchio'
        raise ValueError(message) from error

    # The --kwargs string is parsed into keyword arguments for the constructor
    params_dict = get_params_dict_from_kwargs(kwargs)
    transform = transform_class(**params_dict)

    # Seed only when requested so that unseeded runs are left untouched
    if seed is not None:
        torch.manual_seed(seed)

    # The spinner is transient and can be disabled with --hide-progress
    with Progress(
        SpinnerColumn(),
        TextColumn('[progress.description]{task.description}'),
        transient=True,
        disable=not show_progress,
    ) as progress:
        progress.add_task('Applying transform', total=1)
        apply_transform_to_file(
            input_path,
            transform,
            output_path,
            verbose=verbose,
            class_=imclass,
        )
96
+
97
+
98
def get_params_dict_from_kwargs(kwargs):
    """Parse a string of whitespace-separated ``key=value`` pairs into a dict.

    Args:
        kwargs: String such as ``"degrees=(-5,15) num_transforms=3"`` or
            ``None``. Pairs are separated by whitespace, so a value must not
            contain spaces.

    Returns:
        Dictionary mapping each key to its value converted with
        :func:`torchio.utils.guess_type`. Empty if ``kwargs`` is ``None`` or
        contains no pairs.

    Raises:
        ValueError: If any substring is not of the form ``key=value`` with a
            non-empty key and exactly one ``=``.
    """
    if kwargs is None:
        return {}

    # Validate the whole string first so that malformed input fails fast,
    # before torchio.utils is imported
    pairs = []
    for substring in kwargs.split():
        key, separator, value_string = substring.partition('=')
        has_single_separator = bool(separator) and '=' not in value_string
        if not has_single_separator or not key:
            message = f'Arguments string "{kwargs}" not valid'
            raise ValueError(message)
        pairs.append((key, value_string))

    if not pairs:
        return {}

    # Imported here so that the tool loads faster if not being run
    from torchio.utils import guess_type

    return {key: guess_type(value_string) for key, value_string in pairs}
113
+
114
+
115
+ if __name__ == '__main__':
116
+ app()
@@ -0,0 +1,56 @@
1
+ # pylint: disable=import-outside-toplevel
2
+ from pathlib import Path
3
+
4
+ import typer
5
+
6
+ app = typer.Typer()
7
+
8
+
9
@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        exists=True,
        file_okay=True,
        dir_okay=True,
        readable=True,
    ),
    plot: bool = typer.Option(  # noqa: B008
        False,
        '--plot/--no-plot',
        '-p/-P',
        help='Plot the image using Matplotlib or Pillow.',
    ),
    show: bool = typer.Option(  # noqa: B008
        False,
        '--show/--no-show',
        '-s/-S',
        help='Show the image using specialized visualisation software.',
    ),
    label: bool = typer.Option(  # noqa: B008
        False,
        '--label/--scalar',
        # "-s" already belongs to "--show", so the negative short flag uses
        # the same lower/upper-case pattern as "-p/-P" and "-s/-S"
        '-l/-L',
        help='Use torchio.LabelMap to instantiate the image.',
    ),
):
    """Print information about an image and, optionally, show it.

    Example:
    $ tiohd input.nii.gz
    """
    # Imports are placed here so that the tool loads faster if not being run
    import torchio as tio

    # --label selects LabelMap; the default (--scalar) selects ScalarImage
    class_ = tio.LabelMap if label else tio.ScalarImage
    image = class_(input_path)
    # Load explicitly before printing the image
    image.load()
    print(image)  # noqa: T201
    if plot:
        image.plot()
    if show:
        image.show()
53
+
54
+
55
+ if __name__ == '__main__':
56
+ app()
torchio/constants.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch
2
+
3
+ # Image types
4
+ INTENSITY = 'intensity'
5
+ LABEL = 'label'
6
+ SAMPLING_MAP = 'sampling_map'
7
+
8
+ # Keys for dataset samples
9
+ PATH = 'path'
10
+ TYPE = 'type'
11
+ STEM = 'stem'
12
+ DATA = 'data'
13
+ AFFINE = 'affine'
14
+ TENSOR = 'tensor'
15
+
16
+ # For aggregator
17
+ IMAGE = 'image'
18
+ LOCATION = 'location'
19
+
20
+ # For special collate function
21
+ HISTORY = 'history'
22
+
23
+ # In PyTorch convention
24
+ CHANNELS_DIMENSION = 1
25
+
26
+ # Code repository
27
+ REPO_URL = 'https://github.com/fepegar/torchio/'
28
+
29
+ # Data repository
30
+ DATA_REPO = 'https://github.com/fepegar/torchio-data/raw/main/data/'
31
+
32
+ # Floating point error (machine epsilon for float32)
33
+ MIN_FLOAT_32 = torch.finfo(torch.float32).eps
34
+
35
+ # For the queue
36
+ NUM_SAMPLES = 'num_samples'
@@ -0,0 +1,29 @@
1
+ from .dataset import SubjectsDataset
2
+ from .image import Image
3
+ from .image import LabelMap
4
+ from .image import ScalarImage
5
+ from .inference import GridAggregator
6
+ from .loader import SubjectsLoader
7
+ from .queue import Queue
8
+ from .sampler import GridSampler
9
+ from .sampler import LabelSampler
10
+ from .sampler import PatchSampler
11
+ from .sampler import UniformSampler
12
+ from .sampler import WeightedSampler
13
+ from .subject import Subject
14
+
15
+ __all__ = [
16
+ 'Queue',
17
+ 'Subject',
18
+ 'SubjectsDataset',
19
+ 'SubjectsLoader',
20
+ 'Image',
21
+ 'ScalarImage',
22
+ 'LabelMap',
23
+ 'GridSampler',
24
+ 'GridAggregator',
25
+ 'PatchSampler',
26
+ 'LabelSampler',
27
+ 'WeightedSampler',
28
+ 'UniformSampler',
29
+ ]
@@ -0,0 +1,158 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import Callable
5
+ from typing import Iterable
6
+ from typing import List
7
+ from typing import Optional
8
+ from typing import Sequence
9
+
10
+ from torch.utils.data import Dataset
11
+
12
+ from ..utils import get_subjects_from_batch
13
+ from .subject import Subject
14
+
15
+
16
class SubjectsDataset(Dataset):
    """Base TorchIO dataset.

    Reader of 3D medical images that directly inherits from the PyTorch
    :class:`~torch.utils.data.Dataset`. It can be used with a
    :class:`~torchio.SubjectsLoader` for efficient loading and
    augmentation. It receives a list of instances of :class:`~torchio.Subject`
    and an optional transform applied to the volumes after loading.

    Args:
        subjects: List of instances of :class:`~torchio.Subject`. Iterables
            that cannot be indexed (e.g. generators) are converted to a list.
        transform: An instance of :class:`~torchio.transforms.Transform`
            that will be applied to each subject.
        load_getitem: Load all subject images before returning it in
            :meth:`__getitem__`. Set it to ``False`` if some of the images will
            not be needed during training.

    Example:
        >>> import torchio as tio
        >>> subject_a = tio.Subject(
        ...     t1=tio.ScalarImage('t1.nrrd',),
        ...     t2=tio.ScalarImage('t2.mha',),
        ...     label=tio.LabelMap('t1_seg.nii.gz'),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = tio.Subject(
        ...     t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
        ...     t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
        ...     label=tio.LabelMap('colin27_seg1.nii.gz'),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     tio.RescaleIntensity(out_min_max=(0, 1)),
        ...     tio.RandomAffine(),
        ... ]
        >>> transform = tio.Compose(transforms)
        >>> subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
        >>> subject = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html

    .. tip:: To quickly iterate over the subjects without loading the images,
        use :meth:`dry_iter()`.
    """

    def __init__(
        self,
        subjects: Sequence[Subject],
        transform: Optional[Callable] = None,
        load_getitem: bool = True,
    ):
        # Validation iterates over the subjects, which would exhaust a
        # one-shot iterable, so make sure the collection is reusable first
        subjects = self._ensure_indexable(subjects)
        self._parse_subjects_list(subjects)
        self._subjects = subjects
        self._transform: Optional[Callable]
        self.set_transform(transform)
        self.load_getitem = load_getitem

    def __len__(self):
        return len(self._subjects)

    def __getitem__(self, index: int) -> Subject:
        """Return a copy of the subject at ``index``, loaded and transformed.

        Args:
            index: Integer or any object that can be converted with ``int()``.

        Raises:
            ValueError: If ``index`` cannot be converted to ``int``.
        """
        try:
            index = int(index)
        except (RuntimeError, TypeError, ValueError) as err:
            # ValueError covers e.g. non-numeric strings, so that they get
            # the same descriptive message as the other failures
            message = (
                f'Index "{index}" must be int or compatible dtype,'
                f' but an object of type "{type(index)}" was passed'
            )
            raise ValueError(message) from err

        subject = self._subjects[index]
        subject = copy.deepcopy(subject)  # cheap since images not loaded yet
        if self.load_getitem:
            subject.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject

    @classmethod
    def from_batch(cls, batch: dict) -> SubjectsDataset:
        """Instantiate a dataset from a batch generated by a data loader.

        Args:
            batch: Dictionary generated by a data loader, containing data that
                can be converted to instances of :class:`~.torchio.Subject`.
        """
        subjects: List[Subject] = get_subjects_from_batch(batch)
        return cls(subjects)

    def dry_iter(self):
        """Return the internal list of subjects.

        This can be used to iterate over the subjects without loading the data
        and applying any transforms::

        >>> names = [subject.name for subject in dataset.dry_iter()]
        """
        return self._subjects

    def set_transform(self, transform: Optional[Callable]) -> None:
        """Set the :attr:`transform` attribute.

        Args:
            transform: Callable object, typically an instance of a subclass of
                :class:`torchio.transforms.Transform`, or ``None``.

        Raises:
            ValueError: If ``transform`` is neither ``None`` nor callable.
        """
        if transform is not None and not callable(transform):
            message = (
                'The transform must be a callable object,'
                f' but it has type {type(transform)}'
            )
            raise ValueError(message)
        self._transform = transform

    @staticmethod
    def _ensure_indexable(subjects):
        """Return ``subjects`` unchanged if sized and indexable, else a list.

        Raises:
            TypeError: If ``subjects`` is not iterable at all.
        """
        if hasattr(subjects, '__len__') and hasattr(subjects, '__getitem__'):
            return subjects
        try:
            iterator = iter(subjects)
        except TypeError as e:
            message = f'Subject list must be an iterable, not {type(subjects)}'
            raise TypeError(message) from e
        return list(iterator)

    @staticmethod
    def _parse_subjects_list(subjects_list: Iterable[Subject]) -> None:
        """Check that the input is a non-empty iterable of subjects.

        Raises:
            TypeError: If the input is not iterable or contains an element
                that is not an instance of :class:`~torchio.Subject`.
            ValueError: If the input is empty.
        """
        # Check that it's an iterable
        try:
            iter(subjects_list)
        except TypeError as e:
            message = f'Subject list must be an iterable, not {type(subjects_list)}'
            raise TypeError(message) from e

        # Check that it's not empty
        if not subjects_list:
            raise ValueError('Subjects list is empty')

        # Check each element
        for subject in subjects_list:
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)