cornucopia 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cornucopia/__init__.py +73 -0
  2. cornucopia/base.py +1915 -0
  3. cornucopia/baseutils.py +575 -0
  4. cornucopia/contrast.py +260 -0
  5. cornucopia/ctx.py +25 -0
  6. cornucopia/fov.py +707 -0
  7. cornucopia/geometric.py +2068 -0
  8. cornucopia/intensity.py +1358 -0
  9. cornucopia/io.py +161 -0
  10. cornucopia/kspace.py +505 -0
  11. cornucopia/labels.py +1872 -0
  12. cornucopia/noise.py +508 -0
  13. cornucopia/psf.py +463 -0
  14. cornucopia/qmri.py +1288 -0
  15. cornucopia/random.py +1480 -0
  16. cornucopia/special.py +159 -0
  17. cornucopia/synth.py +708 -0
  18. cornucopia/tests/__init__.py +0 -0
  19. cornucopia/tests/test_backward_geometric.py +173 -0
  20. cornucopia/tests/test_backward_intensity.py +243 -0
  21. cornucopia/tests/test_backward_kspace.py +115 -0
  22. cornucopia/tests/test_backward_noise.py +169 -0
  23. cornucopia/tests/test_backward_psf.py +142 -0
  24. cornucopia/tests/test_backward_qmri.py +249 -0
  25. cornucopia/tests/test_backward_random.py +44 -0
  26. cornucopia/tests/test_backward_synth.py +72 -0
  27. cornucopia/tests/test_base.py +401 -0
  28. cornucopia/tests/test_geometric.py +26 -0
  29. cornucopia/tests/test_intensity.py +9 -0
  30. cornucopia/tests/test_random.py +722 -0
  31. cornucopia/tests/test_run_contrast.py +28 -0
  32. cornucopia/tests/test_run_fov.py +132 -0
  33. cornucopia/tests/test_run_geometric.py +157 -0
  34. cornucopia/tests/test_run_intensity.py +192 -0
  35. cornucopia/tests/test_run_kspace.py +70 -0
  36. cornucopia/tests/test_run_labels.py +224 -0
  37. cornucopia/tests/test_run_noise.py +127 -0
  38. cornucopia/tests/test_run_psf.py +115 -0
  39. cornucopia/tests/test_run_qmri.py +114 -0
  40. cornucopia/tests/test_run_synth.py +67 -0
  41. cornucopia/typing.py +97 -0
  42. cornucopia/utils/__init__.py +0 -0
  43. cornucopia/utils/b0.py +745 -0
  44. cornucopia/utils/bounds.py +412 -0
  45. cornucopia/utils/compat.py +47 -0
  46. cornucopia/utils/conv.py +305 -0
  47. cornucopia/utils/gmm.py +169 -0
  48. cornucopia/utils/indexing.py +911 -0
  49. cornucopia/utils/io.py +258 -0
  50. cornucopia/utils/jit.py +128 -0
  51. cornucopia/utils/kernels.py +288 -0
  52. cornucopia/utils/morpho.py +234 -0
  53. cornucopia/utils/mrf.py +574 -0
  54. cornucopia/utils/padding.py +173 -0
  55. cornucopia/utils/patch.py +302 -0
  56. cornucopia/utils/pool.py +282 -0
  57. cornucopia/utils/py.py +348 -0
  58. cornucopia/utils/smart_inplace.py +163 -0
  59. cornucopia/utils/version.py +57 -0
  60. cornucopia/utils/warps.py +606 -0
  61. cornucopia-0.0.0.dist-info/METADATA +92 -0
  62. cornucopia-0.0.0.dist-info/RECORD +65 -0
  63. cornucopia-0.0.0.dist-info/WHEEL +5 -0
  64. cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
  65. cornucopia-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,163 @@
1
+ """
2
+ These are "smart" inplace operators that only operate inplace if doing
3
+ so does not break the computational graph with respect to variables
4
+ that require grad.
5
+
6
+ Note that they should still be used carefully, as the overwritten tensors
7
+ may be needed when backpropagating through other operations. For example,
8
+ the following code would break the computational graph:
9
+
10
+ ```python
11
+ x = torch.randn([])
12
+
13
+ y = torch.randn([])
14
+ y.requires_grad = True
15
+
16
+ a = x.mul(y) # mul requires x to be saved to backpropagate through y
17
+ b = x.add_(1) # but we overwrite x here
18
+ c = a + b
19
+
20
+ c.backward()
21
+ ```
22
+
23
+ whereas this would work
24
+
25
+ ```python
26
+ x = torch.randn([])
27
+
28
+ y = torch.randn([])
29
+ y.requires_grad = True
30
+
31
+ a = x.add(y) # add does not need anything to backpropagate
32
+ b = x.add_(1) # so we can overwrite x here
33
+ c = a + b
34
+
35
+ c.backward()
36
+ ```
37
+
38
+ """
39
+ import math
40
+ import torch
41
+
42
+
43
def add_(x, y, **kwargs):
    """Compute ``x + alpha * y``, in-place into ``x`` when that is safe.

    Gradients of ``x + a*y``:
        d/dx = 1, d/dy = a, d/da = y
    None of them needs the value of ``x``, so ``x`` can be overwritten.
    The exception is a leaf tensor that requires grad while grad mode is
    enabled: autograd refuses in-place updates of such tensors, so the
    sum is computed out-of-place instead.

    NOTE(review): a *view* of such a leaf is also rejected by autograd for
    in-place updates and is not detected here.

    Parameters
    ----------
    x : tensor or number
    y : tensor or number
    **kwargs
        Forwarded to ``Tensor.add_`` / ``Tensor.add`` (e.g. ``alpha``).
        When ``x`` is not a tensor, only ``alpha`` is used.

    Returns
    -------
    ``x + alpha * y`` -- ``x`` itself when the update was done in-place.
    """
    if not torch.is_tensor(x):
        return x + y * kwargs.get('alpha', 1)
    if x.requires_grad and x.is_leaf and torch.is_grad_enabled():
        # in-place ops on a grad-requiring leaf raise a RuntimeError
        return x.add(y, **kwargs)
    return x.add_(y, **kwargs)
51
+
52
+
53
def sub_(x, y, **kwargs):
    """Compute ``x - alpha * y``, in-place into ``x`` when that is safe.

    Gradients of ``x - a*y``:
        d/dx = 1, d/dy = -a, d/da = -y
    None of them needs the value of ``x``, so ``x`` can be overwritten.
    The exception is a leaf tensor that requires grad while grad mode is
    enabled: autograd refuses in-place updates of such tensors, so the
    difference is computed out-of-place instead.

    NOTE(review): a *view* of such a leaf is also rejected by autograd for
    in-place updates and is not detected here.

    Parameters
    ----------
    x : tensor or number
    y : tensor or number
    **kwargs
        Forwarded to ``Tensor.sub_`` / ``Tensor.sub`` (e.g. ``alpha``).
        When ``x`` is not a tensor, only ``alpha`` is used.

    Returns
    -------
    ``x - alpha * y`` -- ``x`` itself when the update was done in-place.
    """
    if not torch.is_tensor(x):
        return x - y * kwargs.get('alpha', 1)
    if x.requires_grad and x.is_leaf and torch.is_grad_enabled():
        # in-place ops on a grad-requiring leaf raise a RuntimeError
        return x.sub(y, **kwargs)
    return x.sub_(y, **kwargs)
61
+
62
+
63
def mul_(x, y, **kwargs):
    """Multiply ``x`` by ``y``, in-place into ``x`` when that is safe.

    Gradients of ``x*y``:
        d(x*y)/dx = y
        d(x*y)/dy = x
    so ``x`` may only be overwritten when no gradient flows back to ``y``.
    In addition, autograd refuses in-place updates of a leaf tensor that
    requires grad (while grad mode is enabled), so such an ``x`` is also
    handled out-of-place.

    Parameters
    ----------
    x : tensor or number
    y : tensor or number
    **kwargs
        Forwarded to ``Tensor.mul`` / ``Tensor.mul_``; ignored when ``x``
        is not a tensor.

    Returns
    -------
    ``x * y`` -- ``x`` itself when the update was done in-place.
    """
    if not torch.is_tensor(x):
        return x * y
    out_of_place = (
        getattr(y, 'requires_grad', False)
        # in-place ops on a grad-requiring leaf raise a RuntimeError
        or (x.requires_grad and x.is_leaf and torch.is_grad_enabled())
    )
    return x.mul(y, **kwargs) if out_of_place else x.mul_(y, **kwargs)
73
+
74
+
75
def div_(x, y, **kwargs):
    """Divide ``x`` by ``y``, in-place into ``x`` when that is safe.

    Gradients of ``x/y``:
        d(x/y)/dx = 1/y
        d(x/y)/dy = -x/y**2
    so ``x`` may only be overwritten when no gradient flows back to ``y``.
    In addition, autograd refuses in-place updates of a leaf tensor that
    requires grad (while grad mode is enabled), so such an ``x`` is also
    handled out-of-place.

    Parameters
    ----------
    x : tensor or number
    y : tensor or number
    **kwargs
        Forwarded to ``Tensor.div`` / ``Tensor.div_``. When ``x`` is not a
        tensor, ``rounding_mode`` (None, 'trunc' or 'floor') is emulated
        with Python operators; other keywords are ignored.

    Returns
    -------
    ``x / y`` (rounded according to ``rounding_mode``) -- ``x`` itself
    when the update was done in-place.

    Raises
    ------
    ValueError
        If ``x`` is not a tensor and ``rounding_mode`` is not recognized.
    """
    if not torch.is_tensor(x):
        # Python's `/` has no rounding_mode: emulate torch.div's semantics
        rounding_mode = kwargs.get('rounding_mode', None)
        if rounding_mode is None:
            return x / y
        if rounding_mode == 'floor':
            return x // y
        if rounding_mode == 'trunc':
            quotient = x / y
            if torch.is_tensor(quotient):
                return quotient.trunc()
            return math.trunc(quotient)
        raise ValueError(
            "div expected rounding_mode to be one of None, 'trunc', "
            f"or 'floor' but found {rounding_mode!r}"
        )
    out_of_place = (
        getattr(y, 'requires_grad', False)
        # in-place ops on a grad-requiring leaf raise a RuntimeError
        or (x.requires_grad and x.is_leaf and torch.is_grad_enabled())
    )
    return x.div(y, **kwargs) if out_of_place else x.div_(y, **kwargs)
85
+
86
+
87
def pow_(x, y, **kwargs):
    """Raise ``x`` to the power ``y``, in-place into ``x`` when safe.

    Parameters
    ----------
    x : tensor or number
    y : tensor or number
    **kwargs
        Forwarded to ``Tensor.pow`` / ``Tensor.pow_``; ignored when ``x``
        is not a tensor.

    Returns
    -------
    ``x ** y`` -- ``x`` itself when computed in-place.
    """
    # d(x**y)/dx = y * x**(y-1)
    # d(x**y)/dy = x**y * log(x)
    # -> both gradients involve the original x: overwrite x only if we do
    #    not backprop through x or y
    if not torch.is_tensor(x):
        return x ** y
    inplace = not (x.requires_grad or getattr(y, 'requires_grad', False))
    return x.pow_(y, **kwargs) if inplace else x.pow(y, **kwargs)
95
+
96
+
97
def square_(x, **kwargs):
    """Square ``x``, overwriting it only when no gradient flows through it.

    The derivative ``d(x**2)/dx = 2*x`` needs ``x`` itself, so the in-place
    variant is only used when ``x`` does not require grad. Non-tensor
    inputs are squared with plain Python multiplication.
    """
    if torch.is_tensor(x):
        if x.requires_grad:
            # keep x intact for the backward pass
            return x.square(**kwargs)
        return x.square_(**kwargs)
    return x * x
103
+
104
+
105
def sqrt_(x, **kwargs):
    """Square root of ``x``, in-place when no gradient flows through ``x``.

    Non-tensor inputs use Python's ``x ** 0.5``.

    Returns
    -------
    ``x ** 0.5`` -- ``x`` itself when computed in-place.
    """
    # d(x**0.5)/dx = 0.5 * x**(-0.5) = 0.5 / sqrt(x)
    # -> we can overwrite x if we do not backprop through x
    if not torch.is_tensor(x):
        return x ** 0.5
    return x.sqrt(**kwargs) if x.requires_grad else x.sqrt_(**kwargs)
111
+
112
+
113
def atan2_(x, y, **kwargs):
    """Element-wise ``atan2(x, y)``, in-place into ``x`` when that is safe.

    The argument order matches ``torch.atan2(input, other)`` and
    ``math.atan2``: ``x`` is the first argument (numerator of ``x / y``).

    The partial derivatives of ``atan2(x, y)`` depend on both ``x`` and
    ``y``, so ``x`` is only overwritten when neither requires grad. The
    in-place path is also restricted to floating-point tensors that the
    caller passed in. A tensor built here from a Python scalar is 0-dim
    and cannot hold a broadcast result. An integer tensor cannot store
    the floating-point output.

    A Python scalar is converted to a tensor on the other argument's
    device. It takes the other argument's dtype only when that dtype is
    floating point, so e.g. ``0.5`` is not truncated to ``0`` when paired
    with an integer tensor.

    Returns
    -------
    A Python float when both inputs are numbers, otherwise a tensor
    (``x`` itself when computed in-place).
    """
    if not torch.is_tensor(x) and not torch.is_tensor(y):
        return math.atan2(x, y)
    x_is_caller_tensor = torch.is_tensor(x)
    if not x_is_caller_tensor:
        x = torch.as_tensor(
            x, device=y.device,
            dtype=y.dtype if y.is_floating_point() else None,
        )
    if not torch.is_tensor(y):
        y = torch.as_tensor(
            y, device=x.device,
            dtype=x.dtype if x.is_floating_point() else None,
        )
    inplace = (
        x_is_caller_tensor
        and x.is_floating_point()
        and not (x.requires_grad or y.requires_grad)
    )
    return x.atan2_(y, **kwargs) if inplace else x.atan2(y, **kwargs)
122
+
123
+
124
def neg_(x, **kwargs):
    """Negate ``x``, in-place when that is safe.

    ``d(-x)/dx = -1`` does not need ``x``, so ``x`` can be overwritten.
    The exception is a leaf tensor that requires grad while grad mode is
    enabled: autograd refuses in-place updates of such tensors, so the
    negation is computed out-of-place instead.

    Returns
    -------
    ``-x`` -- ``x`` itself when computed in-place.
    """
    if not torch.is_tensor(x):
        return -x
    if x.requires_grad and x.is_leaf and torch.is_grad_enabled():
        # in-place ops on a grad-requiring leaf raise a RuntimeError
        return x.neg(**kwargs)
    return x.neg_(**kwargs)
128
+
129
+
130
def reciprocal_(x, **kwargs):
    """Return ``1 / x``, overwriting ``x`` only if it does not require grad.

    Non-tensor inputs are inverted with Python's ``1/x``.
    """
    if torch.is_tensor(x):
        if not x.requires_grad:
            return x.reciprocal_(**kwargs)
        # x requires grad: leave it untouched and compute out-of-place
        return x.reciprocal(**kwargs)
    return 1/x
137
+
138
+
139
def abs_(x, **kwargs):
    """Absolute value of ``x``, in-place for real tensors without grad.

    Complex tensors and tensors that require grad are handled
    out-of-place. Non-tensor inputs use the builtin ``abs``.
    """
    if torch.is_tensor(x):
        # abs_ is not supported for complex tensors
        if torch.is_complex(x) or x.requires_grad:
            return x.abs(**kwargs)
        return x.abs_(**kwargs)
    return abs(x)
146
+
147
+
148
def exp_(x, **kwargs):
    """Exponential of ``x``, in-place when ``x`` does not require grad.

    Non-tensor inputs go through ``math.exp``.
    """
    if torch.is_tensor(x):
        use_inplace = not x.requires_grad
        return x.exp_(**kwargs) if use_inplace else x.exp(**kwargs)
    return math.exp(x)
152
+
153
+
154
def log_(x, **kwargs):
    """Natural logarithm of ``x``, in-place when ``x`` does not require grad.

    Non-tensor inputs go through ``math.log``.
    """
    if torch.is_tensor(x):
        use_inplace = not x.requires_grad
        return x.log_(**kwargs) if use_inplace else x.log(**kwargs)
    return math.log(x)
158
+
159
+
160
def atan_(x, **kwargs):
    """Arctangent of ``x``, in-place when ``x`` does not require grad.

    Non-tensor inputs go through ``math.atan``.
    """
    if torch.is_tensor(x):
        use_inplace = not x.requires_grad
        return x.atan_(**kwargs) if use_inplace else x.atan(**kwargs)
    return math.atan(x)
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import numbers
3
+
4
+
5
+ def _compare_versions(version1, mode, version2):
6
+ for v1, v2 in zip(version1, version2):
7
+ if mode in ('gt', '>'):
8
+ if v1 > v2:
9
+ return True
10
+ elif v1 < v2:
11
+ return False
12
+ elif mode in ('ge', '>='):
13
+ if v1 > v2:
14
+ return True
15
+ elif v1 < v2:
16
+ return False
17
+ elif mode in ('lt', '<'):
18
+ if v1 < v2:
19
+ return True
20
+ elif v1 > v2:
21
+ return False
22
+ elif mode in ('le', '<='):
23
+ if v1 < v2:
24
+ return True
25
+ elif v1 > v2:
26
+ return False
27
+ if mode in ('gt', 'lt', '>', '<'):
28
+ return False
29
+ else:
30
+ return True
31
+
32
+
33
def torch_version(mode, version):
    """Check torch version

    Parameters
    ----------
    mode : {'<', '<=', '>', '>='}
        Comparison operator; the aliases 'lt', 'le', 'gt', 'ge' are also
        accepted.
    version : int or tuple[int]
        Reference version. A single number is treated as a one-component
        version. Only as many components as given are compared.

    Returns
    -------
    True if "torch.version <mode> version"

    """
    import re
    # drop the local build suffix (e.g. '2.1.0+cu118' -> '2.1.0')
    release = torch.__version__.split('+')[0]
    # keep the leading digits of each of the first three components so
    # pre-release/dev tags ('0a0', '0rc1', 'dev20221014') are ignored;
    # missing or non-numeric components count as 0
    components = []
    for part in release.split('.')[:3]:
        match = re.match(r'\d+', part)
        components.append(int(match.group()) if match else 0)
    components += [0] * (3 - len(components))
    current_version = tuple(components)

    if isinstance(version, numbers.Number):
        version = [version]
    version = list(version)
    return _compare_versions(current_version, mode, version)