pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
tests/test_backends.py
ADDED
@@ -0,0 +1,446 @@
|
|
1
|
+
import warnings
|
2
|
+
|
3
|
+
import pytest
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
from multiprocessing.managers import SharedMemoryManager
|
7
|
+
|
8
|
+
from tme.backends import MatchingBackend, NumpyFFTWBackend, BackendManager, backend
|
9
|
+
|
10
|
+
# Build one CPU instance of every backend class known to the global backend
# manager. Constructors that raise ImportError are reported and left out.
BACKENDS_TO_TEST = []
for backend_class in backend._BACKEND_REGISTRY.values():
    try:
        candidate = backend_class(device="cpu")
    except ImportError:
        print(f"Couldn't import {backend_class}. Skipping...")
    else:
        BACKENDS_TO_TEST.append(candidate)

# Names of the abstract methods declared on MatchingBackend (standard abc
# attribute); used to check the manager exposes the full backend interface.
METHODS_TO_TEST = MatchingBackend.__abstractmethods__
|
18
|
+
|
19
|
+
|
20
|
+
class TestBackendManager:
    """Tests for the ``BackendManager`` front-end object."""

    def setup_method(self):
        self.manager = BackendManager()

    def test_initialization(self):
        """``str()`` of a manager reports the backend it currently uses."""
        fresh_manager = BackendManager()
        expected_repr = f"<BackendManager: using {fresh_manager._backend_name}>"
        assert expected_repr == str(fresh_manager)

    def test_dir(self):
        """``dir()`` works and every abstract backend method is reachable."""
        _ = dir(self.manager)
        for method in METHODS_TO_TEST:
            assert hasattr(self.manager, method)

    def test_add_backend(self):
        """A MatchingBackend subclass can be registered under a new name."""
        self.manager.add_backend(
            backend_name="test", backend_class=NumpyFFTWBackend
        )

    def test_add_backend_error(self):
        """Registering a class that is not a backend raises ValueError."""

        class _NotABackend:
            def __init__(self):
                pass

        with pytest.raises(ValueError):
            self.manager.add_backend(backend_name="test", backend_class=_NotABackend)

    def test_change_backend_error(self):
        """Switching to an unknown backend name raises NotImplementedError."""
        with pytest.raises(NotImplementedError):
            self.manager.change_backend(backend_name=None)

    def test_available_backends(self):
        """Available backends form a list whose entries are all registered."""
        names = self.manager.available_backends()
        assert isinstance(names, list)
        for name in names:
            assert name in self.manager._BACKEND_REGISTRY
|
54
|
+
|
55
|
+
|
56
|
+
class TestBackends:
    """Compare every backend in ``BACKENDS_TO_TEST`` against a reference.

    ``self.backend`` is a fresh ``NumpyFFTWBackend`` used as the reference
    implementation. Each parametrized ``backend`` is run on the same inputs
    and its result is converted back to NumPy for comparison.
    """

    def setup_method(self):
        self.backend = NumpyFFTWBackend()
        self.x1 = np.random.rand(30, 30).astype(np.float32)
        self.x2 = np.random.rand(30, 30).astype(np.float32)

    def teardown_method(self):
        self.backend = None

    def test_initialization_errors(self):
        """The abstract base class cannot be instantiated directly."""
        with pytest.raises(TypeError):
            _ = MatchingBackend()

    @pytest.mark.parametrize("backend", [type(x) for x in BACKENDS_TO_TEST])
    def test_initialization(self, backend):
        """Each available backend class can be constructed with defaults."""
        _ = backend()

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name",
        ("add", "subtract", "multiply", "divide", "minimum", "maximum", "mod"),
    )
    def test_arithmetic_operations(self, method_name, backend):
        """Element-wise binary operations agree with the reference backend."""
        base = getattr(self.backend, method_name)(self.x1, self.x2)
        x1 = backend.to_backend_array(self.x1)
        x2 = backend.to_backend_array(self.x2)
        other = getattr(backend, method_name)(x1, x2)

        assert np.allclose(base, backend.to_numpy_array(other))

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name", ("sum", "mean", "std", "max", "min", "unique")
    )
    @pytest.mark.parametrize("axis", (0, 1))
    def test_reduction_operations(self, method_name, backend, axis):
        """Reductions along a single axis agree with the reference backend."""
        base = getattr(self.backend, method_name)(self.x1, axis=axis)
        other = getattr(backend, method_name)(
            backend.to_backend_array(self.x1), axis=axis
        )
        # std gets a loose tolerance: some backends (e.g. PyTorch) apply
        # Bessel's correction (ddof=1) by default, unlike NumPy (ddof=0).
        rtol = 0.01 if method_name != "std" else 0.5
        assert np.allclose(base, backend.to_numpy_array(other), rtol=rtol)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name",
        ("sqrt", "square", "abs", "transpose", "tobytes", "size"),
    )
    def test_array_manipulation(self, method_name, backend):
        """Unary array methods agree; non-array results compare with ``==``."""
        base = getattr(self.backend, method_name)(self.x1)
        other = getattr(backend, method_name)(backend.to_backend_array(self.x1))

        if isinstance(base, np.ndarray):
            assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)
        else:
            assert base == other

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("dtype", ("_float_dtype", "_complex_dtype", "_int_dtype"))
    def test_zeros(self, shape, backend, dtype):
        """``zeros`` yields all-zero arrays for each default dtype."""
        dtype_base = getattr(self.backend, dtype)
        dtype_backend = getattr(backend, dtype)
        base = self.backend.zeros(shape, dtype=dtype_base)
        other = backend.zeros(shape, dtype=dtype_backend)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("fill_value", (-1, 0, 1))
    def test_full(self, shape, backend, fill_value):
        """``full`` agrees with the reference backend."""
        base = self.backend.full(shape, fill_value=fill_value)
        other = backend.full(shape, fill_value=fill_value)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("power", (0.5, 1, 2))
    def test_power(self, backend, power):
        """Element-wise ``power`` agrees with the reference backend."""
        base = self.backend.power(self.x1, power)
        other = backend.power(backend.to_backend_array(self.x1), power)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shift", (-5, 0, 10))
    @pytest.mark.parametrize("axis", (0, 1))
    def test_roll(self, backend, shift, axis):
        """``roll`` with tuple shift/axis agrees with the reference backend."""
        base = self.backend.roll(self.x1, (shift,), (axis,))
        other = backend.roll(backend.to_backend_array(self.x1), (shift,), (axis,))
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("fill_value", (-1, 0, 1))
    def test_fill(self, shape, backend, fill_value):
        """Filling a zero array matches ``full`` on the reference backend."""
        base = self.backend.full(shape, fill_value=fill_value)
        other = backend.zeros(shape)
        other = backend.fill(other, fill_value)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("min_distance", (1, 5, 10))
    def test_max_filter_coordinates(self, backend, min_distance):
        """Returned peak coordinates have one column per input dimension."""
        coordinates = backend.max_filter_coordinates(
            backend.to_backend_array(self.x1), min_distance=min_distance
        )
        if len(coordinates):
            assert coordinates.shape[1] == self.x1.ndim

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("dtype", ("_float_dtype", "_complex_dtype", "_int_dtype"))
    @pytest.mark.parametrize(
        "dtype_target", ("_int_dtype", "_complex_dtype", "_float_dtype")
    )
    def test_astype(self, dtype, backend, dtype_target):
        """``astype`` converts between all pairs of default dtypes."""
        with warnings.catch_warnings():
            # Lossy casts (e.g. complex -> real) may warn; irrelevant here.
            warnings.simplefilter("ignore")
            source_dtype = getattr(backend, dtype)
            target_dtype = getattr(backend, dtype_target)

            base = backend.zeros((20, 20, 20), dtype=source_dtype)
            arr = backend.astype(base, target_dtype)

            assert arr.dtype == target_dtype

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("N", (0, 15, 30))
    def test_arange(self, backend, N):
        """``arange`` agrees with the reference backend, including N=0."""
        base = self.backend.arange(N)
        other = backend.arange(N)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("return_inverse", (False, True))
    @pytest.mark.parametrize("return_counts", (False, True))
    @pytest.mark.parametrize("return_index", (False, True))
    def test_unique(self, backend, return_inverse, return_counts, return_index):
        """``unique`` agrees for every combination of return flags."""
        base = self.backend.unique(
            self.x1,
            return_inverse=return_inverse,
            return_counts=return_counts,
            return_index=return_index,
        )
        other = backend.unique(
            backend.to_backend_array(self.x1),
            return_inverse=return_inverse,
            return_counts=return_counts,
            return_index=return_index,
        )
        # Normalise to tuples so a single-array result (all flags False) and
        # multi-array results are compared uniformly, whole array at a time.
        if not isinstance(base, tuple):
            base, other = (base,), (other,)
        base, other = tuple(base), tuple(other)

        for k, expected in enumerate(base):
            # Ravel: backends/NumPy versions differ in the inverse's shape.
            assert np.allclose(
                expected.ravel(), backend.to_numpy_array(other[k]).ravel(), rtol=0.1
            )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("k", (0, 15, 30))
    def test_repeat(self, backend, k):
        """``repeat`` agrees with the reference backend, including k=0."""
        base = self.backend.repeat(self.x1, k)
        other = backend.repeat(backend.to_backend_array(self.x1), k)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("dim", (1, 3))
    @pytest.mark.parametrize("k", (0, 15, 30))
    def test_topk_indices(self, backend, k: int, dim: int):
        """Both backends select the same top-k coordinates."""
        data = np.random.rand(*(50 for _ in range(dim)))
        base = self.backend.topk_indices(data, k)
        other = backend.topk_indices(backend.to_backend_array(data), k)

        # base[i] / other[i] hold the indices along dimension i. Compare the
        # selected coordinates as sets, since backends need not agree on the
        # order in which the top-k elements are reported.
        other = [backend.to_numpy_array(other[i]) for i in range(len(base))]
        base_coords = set(zip(*(np.asarray(x).ravel().tolist() for x in base)))
        other_coords = set(zip(*(np.asarray(x).ravel().tolist() for x in other)))
        assert base_coords == other_coords

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_indices(self, backend):
        """``indices`` agrees with the reference backend."""
        base = self.backend.indices(self.x1.shape)
        other = backend.indices(backend.to_backend_array(self.x1).shape)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_get_available_memory(self, backend):
        """Available memory is reported as an int."""
        mem = backend.get_available_memory()
        assert isinstance(mem, int)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_shared_memory(self, backend):
        """Round trip through ``to_sharedarr``/``from_sharedarr`` without a manager."""
        shared_memory_handler = None
        base = backend.to_backend_array(self.x1)
        shared = backend.to_sharedarr(
            arr=base, shared_memory_handler=shared_memory_handler
        )
        arr = backend.from_sharedarr(shared)
        assert np.allclose(backend.to_numpy_array(arr), backend.to_numpy_array(base))

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_shared_memory_managed(self, backend):
        """Round trip through shared memory owned by a SharedMemoryManager."""
        with SharedMemoryManager() as shared_memory_handler:
            base = backend.to_backend_array(self.x1)
            shared = backend.to_sharedarr(
                arr=base, shared_memory_handler=shared_memory_handler
            )
            arr = backend.from_sharedarr(shared)
            # Compare inside the context, before the manager releases memory.
            assert np.allclose(
                backend.to_numpy_array(arr), backend.to_numpy_array(base)
            )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15, 100), (10, 15, 20)))
    @pytest.mark.parametrize("padval", (-1, 0, 1))
    def test_topleft_pad(self, backend, shape, padval):
        """``topleft_pad`` agrees for shapes both larger and smaller than input."""
        arr = np.random.rand(30, 30, 30)
        base = self.backend.topleft_pad(arr, shape=shape, padval=padval)
        other = backend.topleft_pad(
            backend.to_backend_array(arr), shape=shape, padval=padval
        )
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
    def test_build_fft(self, backend, fast_shape):
        """Planned rfftn followed by irfftn approximately recovers the input."""
        _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
            fast_shape, (1 for _ in range(len(fast_shape)))
        )
        rfftn, irfftn = backend.build_fft(
            fwd_shape=fast_shape,
            inv_shape=fast_ft_shape,
            real_dtype=backend._float_dtype,
            cmpl_dtype=backend._complex_dtype,
        )
        arr = np.random.rand(*fast_shape)
        out = np.zeros(fast_ft_shape)

        real_arr = backend.astype(backend.to_backend_array(arr), backend._float_dtype)
        complex_arr = backend.astype(
            backend.to_backend_array(out), backend._complex_dtype
        )

        rfftn(
            backend.astype(backend.to_backend_array(arr), backend._float_dtype),
            complex_arr,
        )
        irfftn(complex_arr, real_arr)
        assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
    def test_fftn(self, backend, fast_shape):
        """Unplanned rfftn followed by irfftn approximately recovers the input."""
        _, fast_shape, _ = backend.compute_convolution_shapes(
            fast_shape, (1 for _ in range(len(fast_shape)))
        )
        arr = np.random.rand(*fast_shape)

        complex_arr = backend.rfftn(
            backend.astype(backend.to_backend_array(arr), backend._float_dtype),
        )
        real_arr = backend.irfftn(complex_arr)
        assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_extract_center(self, backend):
        """Extracting a half-size center crop agrees with the reference."""
        new_shape = np.divide(self.x1.shape, 2).astype(int)
        base = self.backend.extract_center(arr=self.x1, newshape=new_shape)
        other = backend.extract_center(
            arr=backend.to_backend_array(self.x1), newshape=new_shape
        )

        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_compute_convolution_shapes(self, backend):
        """Convolution shape computation is backend independent."""
        base = self.backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)
        other = backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)

        assert base == other

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("create_mask", (False, True))
    def test_rigid_transform(self, backend, dim, create_mask):
        """Mirroring a box along axis 0 preserves the total intensity."""
        shape = tuple(50 for _ in range(dim))
        arr = np.zeros(shape)
        if dim == 2:
            arr[20:25, 21:26] = 1
        elif dim == 3:
            arr[20:25, 21:26, 26:31] = 1

        # Mirror along the first axis.
        rotation_matrix = np.eye(dim)
        rotation_matrix[0, 0] = -1

        out = np.zeros_like(arr)

        arr_mask, out_mask = None, None
        if create_mask:
            arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
            out_mask = np.zeros_like(arr_mask)
            arr_mask = backend.to_backend_array(arr_mask)
            out_mask = backend.to_backend_array(out_mask)

        arr = backend.to_backend_array(arr)
        # Convert the output buffer itself; converting ``arr`` here would make
        # ``out`` alias the input and the final assertion trivially true.
        out = backend.to_backend_array(out)

        rotation_matrix = backend.to_backend_array(rotation_matrix)

        # Use the returned arrays so backends that do not write into ``out``
        # in place are checked as well.
        out, out_mask = backend.rigid_transform(
            arr=arr,
            arr_mask=arr_mask,
            rotation_matrix=rotation_matrix,
            out=out,
            out_mask=out_mask,
        )

        expected = np.round(float(backend.to_numpy_array(arr).sum()), 3)
        received = np.round(float(backend.to_numpy_array(out).sum()), 3)
        assert expected == received

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("create_mask", (False, True))
    def test_rigid_transform_batch(self, backend, dim, create_mask):
        """A batched transform matches transforming each member separately."""
        shape = tuple(30 for _ in range(dim))
        arr = np.random.rand(5, *shape)
        out = np.zeros_like(arr)

        arr_b = arr.copy()
        out_b = np.zeros_like(arr)

        rotation_matrix = np.eye(dim)
        rotation_matrix[0, 0] = -1

        arr_mask, out_mask, out_mask_b = None, None, None
        if create_mask:
            arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
            out_mask = np.zeros_like(arr_mask)
            out_mask_b = np.zeros_like(arr_mask)
            arr_mask = backend.to_backend_array(arr_mask)
            out_mask = backend.to_backend_array(out_mask)
            out_mask_b = backend.to_backend_array(out_mask_b)

        arr = backend.to_backend_array(arr)
        out = backend.to_backend_array(out)

        rotation_matrix = backend.to_backend_array(rotation_matrix)
        # The rotation matrix has one dimension fewer than ``arr``; the
        # leading axis is presumably treated as a batch axis.
        out, _ = backend.rigid_transform(
            arr=arr,
            arr_mask=arr_mask,
            rotation_matrix=rotation_matrix,
            out=out,
            out_mask=out_mask,
        )

        arr_b = backend.to_backend_array(arr_b)
        out_b = backend.to_backend_array(out_b)

        # Reference: transform each batch member independently from the copy.
        reference = []
        for i in range(arr_b.shape[0]):
            transformed, _ = backend.rigid_transform(
                arr=arr_b[i],
                arr_mask=arr_mask if arr_mask is None else arr_mask[i],
                rotation_matrix=rotation_matrix,
                out=out_b[i],
                out_mask=out_mask_b if out_mask_b is None else out_mask_b[i],
            )
            reference.append(backend.to_numpy_array(transformed))

        assert np.allclose(backend.to_numpy_array(out), np.stack(reference))

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_datatype_bytes(self, backend):
        """Byte sizes of the default dtypes are reported as ints."""
        assert isinstance(backend.datatype_bytes(backend._float_dtype), int)
        assert isinstance(backend.datatype_bytes(backend._complex_dtype), int)
        assert isinstance(backend.datatype_bytes(backend._int_dtype), int)
|