cornucopia 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. cornucopia/__init__.py +73 -0
  2. cornucopia/base.py +1915 -0
  3. cornucopia/baseutils.py +575 -0
  4. cornucopia/contrast.py +260 -0
  5. cornucopia/ctx.py +25 -0
  6. cornucopia/fov.py +707 -0
  7. cornucopia/geometric.py +2068 -0
  8. cornucopia/intensity.py +1358 -0
  9. cornucopia/io.py +161 -0
  10. cornucopia/kspace.py +505 -0
  11. cornucopia/labels.py +1872 -0
  12. cornucopia/noise.py +508 -0
  13. cornucopia/psf.py +463 -0
  14. cornucopia/qmri.py +1288 -0
  15. cornucopia/random.py +1480 -0
  16. cornucopia/special.py +159 -0
  17. cornucopia/synth.py +708 -0
  18. cornucopia/tests/__init__.py +0 -0
  19. cornucopia/tests/test_backward_geometric.py +173 -0
  20. cornucopia/tests/test_backward_intensity.py +243 -0
  21. cornucopia/tests/test_backward_kspace.py +115 -0
  22. cornucopia/tests/test_backward_noise.py +169 -0
  23. cornucopia/tests/test_backward_psf.py +142 -0
  24. cornucopia/tests/test_backward_qmri.py +249 -0
  25. cornucopia/tests/test_backward_random.py +44 -0
  26. cornucopia/tests/test_backward_synth.py +72 -0
  27. cornucopia/tests/test_base.py +401 -0
  28. cornucopia/tests/test_geometric.py +26 -0
  29. cornucopia/tests/test_intensity.py +9 -0
  30. cornucopia/tests/test_random.py +722 -0
  31. cornucopia/tests/test_run_contrast.py +28 -0
  32. cornucopia/tests/test_run_fov.py +132 -0
  33. cornucopia/tests/test_run_geometric.py +157 -0
  34. cornucopia/tests/test_run_intensity.py +192 -0
  35. cornucopia/tests/test_run_kspace.py +70 -0
  36. cornucopia/tests/test_run_labels.py +224 -0
  37. cornucopia/tests/test_run_noise.py +127 -0
  38. cornucopia/tests/test_run_psf.py +115 -0
  39. cornucopia/tests/test_run_qmri.py +114 -0
  40. cornucopia/tests/test_run_synth.py +67 -0
  41. cornucopia/typing.py +97 -0
  42. cornucopia/utils/__init__.py +0 -0
  43. cornucopia/utils/b0.py +745 -0
  44. cornucopia/utils/bounds.py +412 -0
  45. cornucopia/utils/compat.py +47 -0
  46. cornucopia/utils/conv.py +305 -0
  47. cornucopia/utils/gmm.py +169 -0
  48. cornucopia/utils/indexing.py +911 -0
  49. cornucopia/utils/io.py +258 -0
  50. cornucopia/utils/jit.py +128 -0
  51. cornucopia/utils/kernels.py +288 -0
  52. cornucopia/utils/morpho.py +234 -0
  53. cornucopia/utils/mrf.py +574 -0
  54. cornucopia/utils/padding.py +173 -0
  55. cornucopia/utils/patch.py +302 -0
  56. cornucopia/utils/pool.py +282 -0
  57. cornucopia/utils/py.py +348 -0
  58. cornucopia/utils/smart_inplace.py +163 -0
  59. cornucopia/utils/version.py +57 -0
  60. cornucopia/utils/warps.py +606 -0
  61. cornucopia-0.0.0.dist-info/METADATA +92 -0
  62. cornucopia-0.0.0.dist-info/RECORD +65 -0
  63. cornucopia-0.0.0.dist-info/WHEEL +5 -0
  64. cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
  65. cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/contrast.py ADDED
@@ -0,0 +1,260 @@
1
+ """This module contains transforms that operate on image contrasts."""
2
# Public API of this module: the contrast-mixture and contrast-lookup
# transforms (each with its final, already-sampled counterpart).
__all__ = [
    "ContrastMixtureTransform",
    "ContrastMixtureFinalTransform",
    "ContrastLookupTransform",
    "ContrastLookupFinalTransform",
]
8
+ # stdlib
9
+ from math import inf
10
+
11
+ # dependencies
12
+ import torch
13
+ from torch import Tensor
14
+
15
+ # internals
16
+ from .base import NonFinalTransform, FinalTransform, Transform
17
+ from .special import PerChannelTransform
18
+ from .utils.gmm import fit_gmm
19
+
20
+
21
class ContrastMixtureFinalTransform(FinalTransform):
    """Classwise shift and rescaling."""

    def __init__(
        self, z: Tensor,
        mu0: Tensor, sigma0: Tensor, mu: Tensor, sigma: Tensor,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        z : (K, *spatial) tensor
            Probability that each voxel belongs to a given class.
        mu0 : (K, C) tensor
            Original means for each class.
        sigma0 : (K, C, C) tensor
            Original covariances for each class.
        mu : (K, C) tensor
            New means for each class.
        sigma : (K, C, C) tensor
            New covariances for each class.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.z = z
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.mu = mu
        self.sigma = sigma

    def _xform(self, x: Tensor) -> Tensor:
        """
        Apply the classwise affine maps and blend them with `z`.

        Each class ``k`` maps a voxel value ``x`` to
        ``L_k @ inv(L0_k) @ (x - mu0_k) + mu_k``, where ``L0_k`` and
        ``L_k`` are the Cholesky factors of ``sigma0[k]`` and ``sigma[k]``
        (whiten with the original Gaussian, recolor with the new one).
        The output is the ``z``-weighted sum of the K candidates.

        Parameters
        ----------
        x : (C, *spatial) tensor
            Input image.

        Returns
        -------
        x : (C, *spatial) tensor
            Transformed image.
        """
        z = self.z.to(x)
        mu0 = self.mu0.to(x)
        sigma0 = self.sigma0.to(x)
        mu = self.mu.to(x)
        sigma = self.sigma.to(x)

        # Fold whitening (inv(L0)) and coloring (L) into a single matrix per
        # class, so only one batched matmul is applied per voxel.
        chol0 = torch.linalg.cholesky(sigma0)              # [nk, nc, nc]
        chol = torch.linalg.cholesky(sigma)                # [nk, nc, nc]
        affine = torch.matmul(chol, chol0.inverse())       # [nk, nc, nc]

        x = x.movedim(0, -1)                               # [*spatial, nc]
        x = x[..., None, :] - mu0                          # [*spatial, nk, nc]
        x = torch.matmul(affine, x[..., None])[..., 0]     # [*spatial, nk, nc]
        x = x + mu                                         # [*spatial, nk, nc]

        # Weight each class candidate by its posterior and sum over classes
        z = z.movedim(0, -1)                               # [*spatial, nk]
        x = torch.matmul(z[..., None, :], x)[..., 0, :]    # [*spatial, nc]
        return x.movedim(-1, 0)                            # [nc, *spatial]
81
+
82
+
83
class ContrastMixtureTransform(NonFinalTransform):
    """
    Find intensity modes using a GMM and change their means and covariances.

    ??? reference
        Meyer, M.I., de la Rosa, E., Pedrosa de Barros, N., Paolella, R.,
        Van Leemput, K. and Sima, D.M., 2021.
        [**A contrast augmentation approach to improve multi-scanner
        generalization in MRI.**](https://www.frontiersin.org/articles/10.3389/fnins.2021.708196)
        Frontiers in Neuroscience, 15, p.708196.

            @article{meyer2021,
              title = {A contrast augmentation approach to improve multi-scanner generalization in MRI},
              author = {Meyer, Maria Ines and de la Rosa, Ezequiel and Pedrosa de Barros, Nuno and Paolella, Roberto and Van Leemput, Koen and Sima, Diana M},
              journal = {Frontiers in Neuroscience},
              volume = {15},
              pages = {708196},
              year = {2021},
              publisher = {Frontiers Media SA},
              url = {https://www.frontiersin.org/articles/10.3389/fnins.2021.708196}
            }

    """

    Final = Next = ContrastMixtureFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        nk: int = 16,
        keep_background: bool = True,
        *,
        shared: bool = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        nk : int
            Number of classes
        keep_background : bool
            Do not change background mean/cov.
            The background class is the class whose mean vector has the
            smallest Euclidean norm.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.keep_background = keep_background
        self.nk = nk

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        # Fit a `nk`-class GMM to the image, sample new class parameters,
        # and hand everything to the final transform.
        if max_depth == 0:
            return self
        z, mu0, sigma0, _ = fit_gmm(x, self.nk)
        mu, sigma = self._make_parameters(mu0, sigma0)
        return self.Next(
            z, mu0, sigma0, mu, sigma, **self.get_prm()
        ).unroll(x, max_depth-1)

    def _make_parameters(self, old_mu, old_sigma):
        """
        Sample new class means and covariances.

        Parameters
        ----------
        old_mu : (K, C) tensor
            Fitted means for each class.
        old_sigma : (K, C, C) tensor
            Fitted covariances for each class.

        Returns
        -------
        mu : (K, C) tensor
            New means, drawn uniformly (per channel) between the smallest
            and largest fitted means.
        sigma : (K, C, C) tensor
            New covariances. Standard deviations are drawn uniformly (per
            channel) between the smallest and largest fitted standard
            deviations; correlations are nonnegative and small enough that
            the matrix is positive definite.
        """
        backend = dict(dtype=old_mu.dtype, device=old_mu.device)
        nk, nc = old_mu.shape
        old_mu_min = old_mu.min(0).values
        old_mu_max = old_mu.max(0).values
        old_sigma_diag = old_sigma.diagonal(0, -1, -2)
        old_sigma_min = old_sigma_diag.min(0).values.sqrt()
        old_sigma_max = old_sigma_diag.max(0).values.sqrt()

        mu = torch.rand_like(
            old_mu).mul_(old_mu_max - old_mu_min).add_(old_mu_min)
        sigma = torch.rand_like(
            old_sigma_diag
        ).mul_(old_sigma_max - old_sigma_min).add_(old_sigma_min)

        # Correlations are drawn in [0, max_corr). With max_corr = 0.5 the
        # correlation matrix is only guaranteed positive definite for
        # nc <= 3; beyond that, cap it at 1/(nc-1) so that every row's
        # off-diagonal sum stays < 1 (strict diagonal dominance => PD), which
        # the Cholesky factorization in the final transform requires.
        max_corr = min(0.5, 1 / max(nc - 1, 1))
        corr = torch.rand([nk, nc*(nc-1)//2], **backend).mul_(max_corr)

        fullsigma = torch.eye(nc, **backend).expand([nk, nc, nc]).clone()
        cnt = 0
        for i in range(nc):
            for j in range(i+1, nc):
                fullsigma[:, i, j] = fullsigma[:, j, i] = corr[:, cnt]
                cnt += 1
        # Correlation -> covariance: scale rows and columns by the std devs
        fullsigma = fullsigma * sigma[:, :, None] * sigma[:, None, :]

        if self.keep_background:
            # Background = class whose mean vector has the smallest norm
            idx = old_mu.square().sum(-1).sqrt().min(0).indices
            mu[idx] = old_mu[idx]
            fullsigma[idx] = old_sigma[idx]

        return mu, fullsigma
179
+
180
+
181
class ContrastLookupFinalTransform(FinalTransform):
    """Binwise intensity shift."""

    def __init__(self, edges: Tensor, mu: Tensor, **kwargs) -> None:
        """
        Parameters
        ----------
        edges : (K+1,) tensor
            The limits of each input bin.
        mu : (K,) tensor
            The new mean value for each bin.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.edges = edges
        self.mu = mu

    def _xform(self, x: Tensor) -> Tensor:
        """
        Shift each voxel by (new bin mean - bin center).

        Bins are half-open ``[edges[k], edges[k+1])`` except the last one,
        which also includes its upper edge. Voxels outside
        ``[edges[0], edges[-1]]`` (and NaNs) are left unchanged.
        """
        edges, mu = self.edges.to(x), self.mu.to(x)
        centers = (edges[:-1] + edges[1:]) / 2
        shifts = mu - centers
        nk = len(mu)

        new_x = x.clone()
        for k in range(nk):
            # Masks are computed on the input `x` (not `new_x`) so a voxel
            # is shifted at most once even if its shift moves it into
            # another bin.
            inside = edges[k] <= x
            if k == nk - 1:
                # Close the last bin so voxels equal to the top edge (the
                # image maximum when edges span the image range) are
                # shifted too.
                inside &= x <= edges[k+1]
            else:
                inside &= x < edges[k+1]
            new_x[inside] += shifts[k]
        return new_x
212
+
213
+
214
class ContrastLookupTransform(NonFinalTransform):
    """
    Segment intensities into equidistant bins and change their mean value.
    """

    Final = Next = ContrastLookupFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(self, nk: int = 16, keep_background: bool = True,
                 *, shared=False, **kwargs) -> None:
        """

        Parameters
        ----------
        nk : int
            Number of equidistant intensity bins.
        keep_background : bool
            Do not change the mean value of the background bin.
            The background bin is the bin containing the lowest
            intensities.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.keep_background = keep_background
        self.nk = nk

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)

        # Build the bins in the image's own floating-point precision and on
        # its device, so the outer edges equal the image min/max exactly
        # once the final transform casts them back to the image type.
        if x.dtype.is_floating_point:
            dtype = x.dtype
        else:
            dtype = torch.get_default_dtype()
        vmin, vmax = x.min().item(), x.max().item()
        edges = torch.linspace(vmin, vmax, self.nk+1,
                               dtype=dtype, device=x.device)
        # Draw on the default generator/dtype (as before), then cast.
        new_mu = torch.rand(self.nk).to(edges) * (vmax - vmin) + vmin
        if self.keep_background and self.nk > 0:
            # A bin whose new mean equals its center is not shifted.
            new_mu[0] = (edges[0] + edges[1]) / 2
        return self.Next(
            edges, new_mu, **self.get_prm()
        ).unroll(x, max_depth-1)
cornucopia/ctx.py ADDED
@@ -0,0 +1,25 @@
1
+ """This module contains context managers that modify the behavior of a transform."""
2
# Short lowercase names, each an alias of a transform class from `.special`
# (see the import that follows for the mapping).
__all__ = [
    "include",
    "exclude",
    "consume",
    "batch",
    "shared",
    "returns",
    "maybe",
    "switch",
    "map",
    "randomize",
]
14
+ from .special import (
15
+ IncludeKeysTransform as include,
16
+ ExcludeKeysTransform as exclude,
17
+ ConsumeKeysTransform as consume,
18
+ SharedTransform as shared,
19
+ ReturningTransform as returns,
20
+ MaybeTransform as maybe,
21
+ SwitchTransform as switch,
22
+ MappedTransform as map,
23
+ RandomizedTransform as randomize,
24
+ BatchedTransform as batch,
25
+ )