cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/contrast.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""This module contains transforms that operate on image contrasts."""
|
|
2
|
+
__all__ = [
|
|
3
|
+
'ContrastMixtureTransform',
|
|
4
|
+
'ContrastMixtureFinalTransform',
|
|
5
|
+
'ContrastLookupTransform',
|
|
6
|
+
'ContrastLookupFinalTransform'
|
|
7
|
+
]
|
|
8
|
+
# stdlib
|
|
9
|
+
from math import inf
|
|
10
|
+
|
|
11
|
+
# dependencies
|
|
12
|
+
import torch
|
|
13
|
+
from torch import Tensor
|
|
14
|
+
|
|
15
|
+
# internals
|
|
16
|
+
from .base import NonFinalTransform, FinalTransform, Transform
|
|
17
|
+
from .special import PerChannelTransform
|
|
18
|
+
from .utils.gmm import fit_gmm
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ContrastMixtureFinalTransform(FinalTransform):
    """Classwise shift and rescaling.

    For every class, intensities are whitened with the original class
    statistics, re-colored with the new class statistics, and the
    per-class results are blended using the class posteriors.
    """

    def __init__(
        self, z: Tensor,
        mu0: Tensor, sigma0: Tensor, mu: Tensor, sigma: Tensor,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        z : (K, *spatial) tensor
            Probability that each voxel belongs to a given class.
        mu0 : (K, C) tensor
            Original means for each class.
        sigma0 : (K, C, C) tensor
            Original covariances for each class.
        mu : (K, C) tensor
            New means for each class.
        sigma : (K, C, C) tensor
            New covariances for each class.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.z = z
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.mu = mu
        self.sigma = sigma

    def _xform(self, x: Tensor) -> Tensor:
        """Apply the classwise whitening/re-coloring to `x`.

        The shape comments below assume `x` is `(C, *spatial)`, which is
        what the parameter shapes documented in `__init__` imply.
        """
        # Move every stored parameter to the dtype/device of the input
        resp, mean_old, cov_old, mean_new, cov_new = (
            prm.to(x) for prm in
            (self.z, self.mu0, self.sigma0, self.mu, self.sigma)
        )

        # Whiten using fitted parameters:
        # inverse Cholesky factor of the original covariances
        whiten = torch.linalg.cholesky(cov_old).inverse()   # (K, C, C)
        feat = x.movedim(0, -1)                             # (*spatial, C)
        feat = feat.unsqueeze(-2) - mean_old                # (*spatial, K, C)
        feat = whiten @ feat.unsqueeze(-1)                  # (*spatial, K, C, 1)

        # Color using new parameters:
        # Cholesky factor of the new covariances
        color = torch.linalg.cholesky(cov_new)              # (K, C, C)
        feat = color @ feat                                 # (*spatial, K, C, 1)
        feat = feat.squeeze(-1) + mean_new                  # (*spatial, K, C)
        feat = feat.movedim(-1, 0)                          # (C, *spatial, K)

        # Weight using posterior: dot product over the class axis
        resp = resp.movedim(0, -1)                          # (*spatial, K)
        out = feat.unsqueeze(-2) @ resp.unsqueeze(-1)       # (C, *spatial, 1, 1)
        return out[..., 0, 0]                               # (C, *spatial)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ContrastMixtureTransform(NonFinalTransform):
    """
    Find intensity modes using a GMM and change their means and covariances.

    ??? reference
        Meyer, M.I., de la Rosa, E., Pedrosa de Barros, N., Paolella, R.,
        Van Leemput, K. and Sima, D.M., 2021.
        [**A contrast augmentation approach to improve multi-scanner
        generalization in MRI.**](https://www.frontiersin.org/articles/10.3389/fnins.2021.708196)
        Frontiers in Neuroscience, 15, p.708196.

            @article{meyer2021,
              title     = {A contrast augmentation approach to improve multi-scanner generalization in MRI},
              author    = {Meyer, Maria Ines and de la Rosa, Ezequiel and Pedrosa de Barros, Nuno and Paolella, Roberto and Van Leemput, Koen and Sima, Diana M},
              journal   = {Frontiers in Neuroscience},
              volume    = {15},
              pages     = {708196},
              year      = {2021},
              publisher = {Frontiers Media SA},
              url       = {https://www.frontiersin.org/articles/10.3389/fnins.2021.708196}
            }

    """

    Final = Next = ContrastMixtureFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        nk: int = 16,
        keep_background: bool = True,
        *,
        shared: bool = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        nk : int
            Number of classes
        keep_background : bool
            Do not change background mean/cov.
            The background class is the class whose mean vector has
            the smallest Euclidean norm.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.keep_background = keep_background
        self.nk = nk

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Fit a mixture to `x`, sample new statistics, and build the
        final transform (unless `max_depth` is exhausted)."""
        if max_depth == 0:
            return self
        # The fourth output of `fit_gmm` is not needed here
        posterior, fitted_mu, fitted_sigma, _ = fit_gmm(x, self.nk)
        new_mu, new_sigma = self._make_parameters(fitted_mu, fitted_sigma)
        final = self.Next(
            posterior, fitted_mu, fitted_sigma, new_mu, new_sigma,
            **self.get_prm()
        )
        return final.unroll(x, max_depth-1)

    def _make_parameters(self, old_mu, old_sigma):
        """Sample new class means and covariances.

        Parameters
        ----------
        old_mu : (K, C) tensor
            Fitted class means.
        old_sigma : (K, C, C) tensor
            Fitted class covariances.

        Returns
        -------
        mu : (K, C) tensor
            New means, uniform within the per-channel range of `old_mu`.
        fullsigma : (K, C, C) tensor
            New covariances, built from standard deviations sampled
            uniformly within the per-channel range of the fitted ones
            and correlations sampled uniformly in [0, 0.5).
        """
        dtype, device = old_mu.dtype, old_mu.device
        nk, nc = old_mu.shape

        # Per-channel range (across classes) of the fitted means
        mu_lo = torch.min(old_mu, dim=0).values
        mu_hi = torch.max(old_mu, dim=0).values

        # Per-channel range (across classes) of the fitted standard deviations
        old_var = old_sigma.diagonal(0, -1, -2)
        sd_lo = torch.min(old_var, dim=0).values.sqrt()
        sd_hi = torch.max(old_var, dim=0).values.sqrt()

        # Random draws: means, then standard deviations, then correlations
        mu = torch.rand_like(old_mu) * (mu_hi - mu_lo) + mu_lo
        sd = torch.rand_like(old_var) * (sd_hi - sd_lo) + sd_lo
        npairs = nc * (nc - 1) // 2
        corr = 0.5 * torch.rand([nk, npairs], dtype=dtype, device=device)

        # Symmetric correlation matrices: unit diagonal, sampled off-diagonal
        fullsigma = torch.eye(nc, dtype=dtype, device=device)
        fullsigma = fullsigma.expand([nk, nc, nc]).clone()
        pairs = [(i, j) for i in range(nc) for j in range(i+1, nc)]
        for col, (i, j) in enumerate(pairs):
            fullsigma[:, i, j] = corr[:, col]
            fullsigma[:, j, i] = corr[:, col]

        # Correlation -> covariance
        fullsigma = fullsigma * sd.unsqueeze(-1) * sd.unsqueeze(-2)

        if self.keep_background:
            # Background = class whose mean vector has the smallest norm;
            # it keeps its fitted statistics.
            norms = old_mu.square().sum(-1).sqrt()
            bg = torch.min(norms, dim=0).indices
            mu[bg] = old_mu[bg]
            fullsigma[bg] = old_sigma[bg]

        return mu, fullsigma
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class ContrastLookupFinalTransform(FinalTransform):
    """Binwise intensity shift.

    Every value falling inside bin `k` is shifted by the difference
    between the bin's new mean and the bin's center.
    """

    def __init__(self, edges: Tensor, mu: Tensor, **kwargs) -> None:
        """
        Parameters
        ----------
        edges : (K+1,) tensor
            The limits of each input bin.
        mu : (K,) tensor
            The new mean value for each bin.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.edges = edges
        self.mu = mu

    def _xform(self, x: Tensor) -> Tensor:
        """Shift the values of `x` bin by bin.

        Bins are half-open, `edges[k] <= x < edges[k+1]`, except the last
        one, which also includes its upper edge. Without this, values
        equal to the last edge would belong to no bin and be left
        unshifted. Values outside `[edges[0], edges[-1]]` are returned
        unchanged.
        """
        edges, mu = self.edges.to(x), self.mu.to(x)
        # The current mean of a bin is taken to be its center
        mu0 = (edges[:-1] + edges[1:]) / 2
        shift = mu - mu0
        nk = len(mu)

        new_x = x.clone()
        for k in range(nk):
            # Masks are computed from the input `x`, not from `new_x`,
            # so a value is never shifted by more than one bin.
            mask = edges[k] <= x
            if k + 1 < nk:
                mask &= x < edges[k+1]
            else:
                # last bin is closed on the right
                mask &= x <= edges[k+1]
            new_x[mask] += shift[k]
        return new_x
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class ContrastLookupTransform(NonFinalTransform):
    """
    Segment intensities into equidistant bins and change their mean value.
    """

    Final = Next = ContrastLookupFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        nk: int = 16,
        keep_background: bool = True,
        *,
        shared: bool = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        nk : int
            Number of classes
        keep_background : bool
            Do not change the background mean.
            The background class is the class with minimum mean value,
            i.e., the lowest-intensity bin.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.keep_background = keep_background
        self.nk = nk

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Sample new bin means for `x` and build the final transform
        (unless `max_depth` is exhausted)."""
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # Unroll this transform once per channel (each on a
            # single-channel slice) and wrap the results.
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)

        # Equidistant bins spanning the intensity range of the input
        vmin, vmax = x.min(), x.max()
        edges = torch.linspace(vmin, vmax, self.nk+1)
        # New means are uniform within the intensity range
        new_mu = torch.rand(self.nk).to(x) * (vmax - vmin) + vmin

        if self.keep_background and len(new_mu) > 0:
            # Pin the lowest bin's new mean to its center so that the
            # final transform shifts it by exactly zero. The center is
            # computed from `edges.to(x)`, the same way the final
            # transform computes it.
            cast_edges = edges.to(x)
            new_mu[0] = (cast_edges[0] + cast_edges[1]) / 2

        return self.Next(
            edges, new_mu, **self.get_prm()
        ).unroll(x, max_depth-1)
|
cornucopia/ctx.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""This module contains context managers that modify the behavior of a transform."""
|
|
2
|
+
__all__ = [
|
|
3
|
+
'include',
|
|
4
|
+
'exclude',
|
|
5
|
+
'consume',
|
|
6
|
+
'batch',
|
|
7
|
+
'shared',
|
|
8
|
+
'returns',
|
|
9
|
+
'maybe',
|
|
10
|
+
'switch',
|
|
11
|
+
'map',
|
|
12
|
+
'randomize',
|
|
13
|
+
]
|
|
14
|
+
from .special import (
|
|
15
|
+
IncludeKeysTransform as include,
|
|
16
|
+
ExcludeKeysTransform as exclude,
|
|
17
|
+
ConsumeKeysTransform as consume,
|
|
18
|
+
SharedTransform as shared,
|
|
19
|
+
ReturningTransform as returns,
|
|
20
|
+
MaybeTransform as maybe,
|
|
21
|
+
SwitchTransform as switch,
|
|
22
|
+
MappedTransform as map,
|
|
23
|
+
RandomizedTransform as randomize,
|
|
24
|
+
BatchedTransform as batch,
|
|
25
|
+
)
|