cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
These are "smart" inplace operators that only operate inplace if doing
|
|
3
|
+
so does not break the computational graph with respect to variables
|
|
4
|
+
that require grad.
|
|
5
|
+
|
|
6
|
+
Note that they should still be used carefully, as the overwritten tensors
|
|
7
|
+
may be needed when backpropagating through other operations. For example,
|
|
8
|
+
the following code would break the computational graph:
|
|
9
|
+
|
|
10
|
+
```python
|
|
11
|
+
x = torch.randn([])
|
|
12
|
+
|
|
13
|
+
y = torch.randn([])
|
|
14
|
+
y.requires_grad = True
|
|
15
|
+
|
|
16
|
+
a = x.mul(y)   # mul requires x to be saved to backpropagate through y
|
|
17
|
+
b = x.add_(1) # but we overwrite x here
|
|
18
|
+
c = a + b
|
|
19
|
+
|
|
20
|
+
c.backward()
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
whereas this would work
|
|
24
|
+
|
|
25
|
+
```python
|
|
26
|
+
x = torch.randn([])
|
|
27
|
+
|
|
28
|
+
y = torch.randn([])
|
|
29
|
+
y.requires_grad = True
|
|
30
|
+
|
|
31
|
+
a = x.add(y)   # add does not need anything to backpropagate
|
|
32
|
+
b = x.add_(1) # so we can overwrite x here
|
|
33
|
+
c = a + b
|
|
34
|
+
|
|
35
|
+
c.backward()
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
"""
|
|
39
|
+
import math
|
|
40
|
+
import torch
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def add_(x, y, **kwargs):
    """Compute ``x + alpha * y``, overwriting `x` when that is allowed.

    Parameters
    ----------
    x : tensor or number
        Left operand. Overwritten if it is a tensor that autograd allows
        to be modified in place.
    y : tensor or number
        Right operand.
    **kwargs
        Forwarded to ``Tensor.add`` / ``Tensor.add_`` (e.g. ``alpha``).
        When `x` is not a tensor, only ``alpha`` is honoured.

    Returns
    -------
    The sum; this is `x` itself when the operation ran in place.
    """
    # d(x+a*y)/dx = 1
    # d(x+a*y)/dy = a
    # d(x+a*y)/da = y
    # -> the backward pass never reads x, so x can be overwritten
    if not torch.is_tensor(x):
        return x + y * kwargs.get('alpha', 1)
    if x.is_leaf and x.requires_grad and torch.is_grad_enabled():
        # autograd rejects in-place ops on leaves that require grad,
        # so fall back to the out-of-place version instead of raising
        return x.add(y, **kwargs)
    return x.add_(y, **kwargs)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def sub_(x, y, **kwargs):
    """Compute ``x - alpha * y``, overwriting `x` when that is allowed.

    Parameters
    ----------
    x : tensor or number
        Left operand. Overwritten if it is a tensor that autograd allows
        to be modified in place.
    y : tensor or number
        Right operand.
    **kwargs
        Forwarded to ``Tensor.sub`` / ``Tensor.sub_`` (e.g. ``alpha``).
        When `x` is not a tensor, only ``alpha`` is honoured.

    Returns
    -------
    The difference; this is `x` itself when the operation ran in place.
    """
    # d(x-a*y)/dx = 1
    # d(x-a*y)/dy = -a
    # d(x-a*y)/da = -y
    # -> the backward pass never reads x, so x can be overwritten
    if not torch.is_tensor(x):
        return x - y * kwargs.get('alpha', 1)
    if x.is_leaf and x.requires_grad and torch.is_grad_enabled():
        # autograd rejects in-place ops on leaves that require grad,
        # so fall back to the out-of-place version instead of raising
        return x.sub(y, **kwargs)
    return x.sub_(y, **kwargs)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def mul_(x, y, **kwargs):
    """Compute ``x * y``, overwriting `x` when that is allowed.

    Parameters
    ----------
    x : tensor or number
        Left operand. Overwritten only if it is a tensor, `y` does not
        require grad, and autograd allows `x` to be modified in place.
    y : tensor or number
        Right operand.
    **kwargs
        Forwarded to ``Tensor.mul`` / ``Tensor.mul_``. Ignored when `x`
        is not a tensor.

    Returns
    -------
    The product; this is `x` itself when the operation ran in place.
    """
    # d(x*y)/dx = y
    # d(x*y)/dy = x
    # -> x is only needed to backpropagate through y
    if not torch.is_tensor(x):
        return x * y
    # autograd rejects in-place ops on leaves that require grad
    protected_leaf = (
        x.is_leaf and x.requires_grad and torch.is_grad_enabled()
    )
    if protected_leaf or getattr(y, 'requires_grad', False):
        return x.mul(y, **kwargs)
    return x.mul_(y, **kwargs)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def div_(x, y, **kwargs):
    """Compute ``x / y``, overwriting `x` when that is allowed.

    Parameters
    ----------
    x : tensor or number
        Numerator. Overwritten only if it is a tensor, `y` does not
        require grad, and autograd allows `x` to be modified in place.
    y : tensor or number
        Denominator.
    **kwargs
        Forwarded to ``Tensor.div`` / ``Tensor.div_``
        (e.g. ``rounding_mode``). Ignored when `x` is not a tensor.

    Returns
    -------
    The quotient; this is `x` itself when the operation ran in place.
    """
    # d(x/y)/dx = 1/y
    # d(x/y)/dy = -x/y**2
    # -> x is only needed to backpropagate through y
    if not torch.is_tensor(x):
        return x / y
    # autograd rejects in-place ops on leaves that require grad
    protected_leaf = (
        x.is_leaf and x.requires_grad and torch.is_grad_enabled()
    )
    if protected_leaf or getattr(y, 'requires_grad', False):
        return x.div(y, **kwargs)
    return x.div_(y, **kwargs)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def pow_(x, y, **kwargs):
    """Compute ``x ** y``, overwriting `x` when no gradient is tracked.

    Parameters
    ----------
    x : tensor or number
        Base. Overwritten only if it is a tensor and neither `x` nor `y`
        requires grad.
    y : tensor or number
        Exponent.
    **kwargs
        Forwarded to ``Tensor.pow`` / ``Tensor.pow_``. Ignored when `x`
        is not a tensor.

    Returns
    -------
    The power; this is `x` itself when the operation ran in place.
    """
    if torch.is_tensor(x):
        # Both partial derivatives of x**y depend on the original x,
        # so x must survive whenever either operand is differentiated.
        tracked = x.requires_grad or getattr(y, 'requires_grad', False)
        op = x.pow if tracked else x.pow_
        return op(y, **kwargs)
    return x ** y
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def square_(x, **kwargs):
    """Square `x`, overwriting it when it does not require grad.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.square`` / ``Tensor.square_``. Ignored
        when `x` is not a tensor.

    Returns
    -------
    ``x * x``; this is `x` itself when the operation ran in place.
    """
    if torch.is_tensor(x):
        # d(x**2)/dx = 2*x: the gradient reads x, so only overwrite it
        # when nothing is backpropagated through it.
        op = x.square if x.requires_grad else x.square_
        return op(**kwargs)
    return x * x
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def sqrt_(x, **kwargs):
    """Take the square root of `x`, overwriting it when grad-free.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.sqrt`` / ``Tensor.sqrt_``. Ignored when
        `x` is not a tensor.

    Returns
    -------
    ``x ** 0.5``; this is `x` itself when the operation ran in place.
    """
    if torch.is_tensor(x):
        # d(x**0.5)/dx = 0.5 * x**(-0.5)
        # -> only overwrite x when nothing is backpropagated through it
        op = x.sqrt if x.requires_grad else x.sqrt_
        return op(**kwargs)
    return x ** 0.5
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def atan2_(x, y, **kwargs):
    """Compute ``atan2(x, y)``, overwriting `x` when that is safe.

    Parameters
    ----------
    x : tensor or number
        First argument. Overwritten only if the caller passed a tensor
        and neither `x` nor `y` requires grad.
    y : tensor or number
        Second argument.
    **kwargs
        Forwarded to ``Tensor.atan2`` / ``Tensor.atan2_``. Ignored when
        both inputs are plain numbers.

    Returns
    -------
    The element-wise arc tangent; this is `x` itself when the operation
    ran in place.
    """
    x_is_tensor = torch.is_tensor(x)
    if not x_is_tensor and not torch.is_tensor(y):
        return math.atan2(x, y)
    if not x_is_tensor:
        x = torch.as_tensor(x, dtype=y.dtype, device=y.device)
    if not torch.is_tensor(y):
        y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
    # A scalar x was just promoted to a 0-d temporary: writing into it
    # gains nothing and fails as soon as y has more than one element,
    # because the broadcast result does not fit in a 0-d tensor.
    inplace = x_is_tensor and not (x.requires_grad or y.requires_grad)
    return x.atan2_(y, **kwargs) if inplace else x.atan2(y, **kwargs)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def neg_(x, **kwargs):
    """Negate `x`, overwriting it when that is allowed.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.neg`` / ``Tensor.neg_``. Ignored when `x`
        is not a tensor.

    Returns
    -------
    ``-x``; this is `x` itself when the operation ran in place.
    """
    # d(-x)/dx = -1 -> the backward pass never reads x
    if not torch.is_tensor(x):
        return -x
    if x.is_leaf and x.requires_grad and torch.is_grad_enabled():
        # autograd rejects in-place ops on leaves that require grad,
        # so fall back to the out-of-place version instead of raising
        return x.neg(**kwargs)
    return x.neg_(**kwargs)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def reciprocal_(x, **kwargs):
    """Compute ``1 / x``, overwriting `x` when it does not require grad.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.reciprocal`` / ``Tensor.reciprocal_``.
        Ignored when `x` is not a tensor.

    Returns
    -------
    The reciprocal; this is `x` itself when the operation ran in place.
    """
    if torch.is_tensor(x):
        # stay out-of-place whenever a gradient is tracked through x
        op = x.reciprocal if x.requires_grad else x.reciprocal_
        return op(**kwargs)
    return 1/x
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def abs_(x, **kwargs):
    """Take the absolute value of `x`, overwriting it when possible.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.abs`` / ``Tensor.abs_``. Ignored when `x`
        is not a tensor.

    Returns
    -------
    ``|x|``; this is `x` itself when the operation ran in place.
    """
    if not torch.is_tensor(x):
        return abs(x)
    # abs_ is not supported for complex tensors, and tensors that
    # require grad are never overwritten.
    out_of_place = torch.is_complex(x) or x.requires_grad
    op = x.abs if out_of_place else x.abs_
    return op(**kwargs)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def exp_(x, **kwargs):
    """Exponentiate `x`, overwriting it when it does not require grad.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.exp`` / ``Tensor.exp_``. Ignored when `x`
        is not a tensor.

    Returns
    -------
    ``exp(x)``; this is `x` itself when the operation ran in place.
    """
    if torch.is_tensor(x):
        # stay out-of-place whenever a gradient is tracked through x
        op = x.exp if x.requires_grad else x.exp_
        return op(**kwargs)
    return math.exp(x)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def log_(x, **kwargs):
    """Take the natural log of `x`, overwriting it when grad-free.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.log`` / ``Tensor.log_``. Ignored when `x`
        is not a tensor.

    Returns
    -------
    ``log(x)``; this is `x` itself when the operation ran in place.
    """
    if torch.is_tensor(x):
        # stay out-of-place whenever a gradient is tracked through x
        op = x.log if x.requires_grad else x.log_
        return op(**kwargs)
    return math.log(x)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def atan_(x, **kwargs):
    """Take the arc tangent of `x`, overwriting it when grad-free.

    Parameters
    ----------
    x : tensor or number
    **kwargs
        Forwarded to ``Tensor.atan`` / ``Tensor.atan_``. Ignored when
        `x` is not a tensor.

    Returns
    -------
    ``atan(x)``; this is `x` itself when the operation ran in place.
    """
    if torch.is_tensor(x):
        # stay out-of-place whenever a gradient is tracked through x
        op = x.atan if x.requires_grad else x.atan_
        return op(**kwargs)
    return math.atan(x)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numbers
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _compare_versions(version1, mode, version2):
|
|
6
|
+
for v1, v2 in zip(version1, version2):
|
|
7
|
+
if mode in ('gt', '>'):
|
|
8
|
+
if v1 > v2:
|
|
9
|
+
return True
|
|
10
|
+
elif v1 < v2:
|
|
11
|
+
return False
|
|
12
|
+
elif mode in ('ge', '>='):
|
|
13
|
+
if v1 > v2:
|
|
14
|
+
return True
|
|
15
|
+
elif v1 < v2:
|
|
16
|
+
return False
|
|
17
|
+
elif mode in ('lt', '<'):
|
|
18
|
+
if v1 < v2:
|
|
19
|
+
return True
|
|
20
|
+
elif v1 > v2:
|
|
21
|
+
return False
|
|
22
|
+
elif mode in ('le', '<='):
|
|
23
|
+
if v1 < v2:
|
|
24
|
+
return True
|
|
25
|
+
elif v1 > v2:
|
|
26
|
+
return False
|
|
27
|
+
if mode in ('gt', 'lt', '>', '<'):
|
|
28
|
+
return False
|
|
29
|
+
else:
|
|
30
|
+
return True
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def torch_version(mode, version):
    """Check torch version

    Parameters
    ----------
    mode : {'<', '<=', '>', '>='}
    version : tuple[int]
        A single number is treated as a one-component version.

    Returns
    -------
    True if "torch.version <mode> version"

    """
    def _leading_int(field):
        # Keep only the leading decimal digits so that pre-release tags
        # such as "0a0", "0rc1" or "0.dev2023" do not break int();
        # a field without any leading digit counts as 0.
        digits = ''
        for char in field:
            if char not in '0123456789':
                break
            digits += char
        return int(digits) if digits else 0

    # drop the local build tag (e.g. "+cu118")
    release = torch.__version__.split('+')[0]
    # keep major/minor/patch, padding with zeros if the string is shorter
    fields = release.split('.')[:3]
    fields += ['0'] * (3 - len(fields))
    current_version = tuple(_leading_int(field) for field in fields)

    if isinstance(version, numbers.Number):
        version = [version]
    version = list(version)
    return _compare_versions(current_version, mode, version)
|