pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/match_template.py +8 -8
  2. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocess.py +22 -6
  3. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocessor_gui.py +9 -14
  4. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/METADATA +1 -1
  5. pytme-0.2.5.dist-info/RECORD +119 -0
  6. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/WHEEL +1 -1
  7. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/top_level.txt +1 -0
  8. scripts/match_template.py +8 -8
  9. scripts/preprocess.py +22 -6
  10. scripts/preprocessor_gui.py +9 -14
  11. tests/__init__.py +0 -0
  12. tests/data/.DS_Store +0 -0
  13. tests/data/Blurring/.DS_Store +0 -0
  14. tests/data/Blurring/blob_width18.npy +0 -0
  15. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  16. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  17. tests/data/Blurring/hamming_width6.npy +0 -0
  18. tests/data/Blurring/kaiserb_width18.npy +0 -0
  19. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  20. tests/data/Blurring/mean_size5.npy +0 -0
  21. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  22. tests/data/Blurring/rank_rank3.npy +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Maps/emd_8621.mrc.gz +0 -0
  25. tests/data/README.md +2 -0
  26. tests/data/Raw/.DS_Store +0 -0
  27. tests/data/Raw/em_map.map +0 -0
  28. tests/data/Structures/.DS_Store +0 -0
  29. tests/data/Structures/1pdj.cif +3339 -0
  30. tests/data/Structures/1pdj.pdb +1429 -0
  31. tests/data/Structures/5khe.cif +3685 -0
  32. tests/data/Structures/5khe.ent +2210 -0
  33. tests/data/Structures/5khe.pdb +2210 -0
  34. tests/data/Structures/5uz4.cif +70548 -0
  35. tests/preprocessing/__init__.py +0 -0
  36. tests/preprocessing/test_compose.py +76 -0
  37. tests/preprocessing/test_frequency_filters.py +178 -0
  38. tests/preprocessing/test_preprocessor.py +136 -0
  39. tests/preprocessing/test_utils.py +79 -0
  40. tests/test_analyzer.py +310 -0
  41. tests/test_backends.py +375 -0
  42. tests/test_density.py +508 -0
  43. tests/test_extensions.py +130 -0
  44. tests/test_matching_cli.py +283 -0
  45. tests/test_matching_data.py +162 -0
  46. tests/test_matching_exhaustive.py +162 -0
  47. tests/test_matching_memory.py +30 -0
  48. tests/test_matching_optimization.py +226 -0
  49. tests/test_matching_utils.py +326 -0
  50. tests/test_orientations.py +173 -0
  51. tests/test_packaging.py +95 -0
  52. tests/test_parser.py +33 -0
  53. tests/test_structure.py +243 -0
  54. tme/__init__.py +0 -1
  55. tme/__version__.py +1 -1
  56. tme/backends/jax_backend.py +3 -9
  57. tme/data/scattering_factors.pickle +0 -0
  58. tme/density.py +14 -10
  59. tme/external/bindings.cpp +332 -0
  60. tme/matching_data.py +14 -12
  61. tme/matching_exhaustive.py +17 -15
  62. tme/matching_optimization.py +215 -208
  63. tme/matching_utils.py +1 -0
  64. tme/preprocessing/_utils.py +14 -14
  65. tme/preprocessing/composable_filter.py +0 -2
  66. tme/preprocessing/compose.py +4 -4
  67. tme/preprocessing/frequency_filters.py +32 -35
  68. tme/preprocessing/tilt_series.py +198 -117
  69. tme/preprocessor.py +24 -246
  70. tme/structure.py +22 -22
  71. pytme-0.2.3.dist-info/RECORD +0 -75
  72. tme/matching_memory.py +0 -383
  73. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/estimate_ram_usage.py +0 -0
  74. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/postprocess.py +0 -0
  75. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/LICENSE +0 -0
  76. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/entry_points.txt +0 -0
tests/test_analyzer.py ADDED
@@ -0,0 +1,310 @@
1
+ from tempfile import mkstemp
2
+
3
+ import pytest
4
+ import numpy as np
5
+
6
+ from tme.backends import backend as be
7
+ from tme.analyzer import (
8
+ MaxScoreOverRotations,
9
+ PeakCaller,
10
+ PeakCallerSort,
11
+ PeakCallerMaximumFilter,
12
+ PeakCallerFast,
13
+ PeakCallerRecursiveMasking,
14
+ PeakCallerScipy,
15
+ PeakClustering,
16
+ MemmapHandler,
17
+ )
18
+
19
+
20
# Concrete PeakCaller implementations. The parametrized tests below run the
# same checks against each of them.
PEAK_CALLER_CHILDREN = [
    PeakCallerSort,
    PeakCallerMaximumFilter,
    PeakCallerFast,
    PeakCallerRecursiveMasking,
    PeakCallerScipy,
    PeakClustering,
]
28
+
29
+
30
class TestPeakCallers:
    """Shared behaviour tests for every concrete ``PeakCaller`` subclass."""

    def setup_method(self):
        self.number_of_peaks = 100
        self.min_distance = 5
        self.data = np.random.rand(100, 100, 100)
        self.rotation_matrix = np.eye(3)

    @pytest.mark.parametrize("peak_caller", PEAK_CALLER_CHILDREN)
    def test_initialization(self, peak_caller):
        _ = peak_caller(number_of_peaks=100, min_distance=5)

    def test_initialization_error(self):
        # The base class itself must not be instantiable.
        with pytest.raises(TypeError):
            _ = PeakCaller(number_of_peaks=100, min_distance=5)

    @pytest.mark.parametrize("peak_caller", PEAK_CALLER_CHILDREN)
    def test_initialization_error_parameter(self, peak_caller):
        # Non-positive peak counts (and negative distances) are rejected.
        with pytest.raises(ValueError):
            _ = peak_caller(number_of_peaks=0, min_distance=5)
        with pytest.raises(ValueError):
            _ = peak_caller(number_of_peaks=-1, min_distance=5)
        with pytest.raises(ValueError):
            _ = peak_caller(number_of_peaks=-1, min_distance=-1)

    @pytest.mark.parametrize("peak_caller", PEAK_CALLER_CHILDREN)
    @pytest.mark.parametrize("number_of_peaks", (1, 100))
    @pytest.mark.parametrize("minimum_score", (None, 0.5))
    def test__call__(self, peak_caller, number_of_peaks, minimum_score):
        peak_caller = peak_caller(
            number_of_peaks=number_of_peaks,
            min_distance=self.min_distance,
            minimum_score=minimum_score,
        )
        peak_caller(
            self.data.copy(),
            rotation_matrix=self.rotation_matrix,
        )
        candidates = tuple(peak_caller)
        if minimum_score is None:
            # The number of reported peak positions must not exceed the budget.
            assert len(candidates[0]) <= number_of_peaks
        else:
            # Every reported peak must meet the requested minimum score.
            peaks = candidates[0].astype(int)
            assert np.all(self.data[tuple(peaks.T)] >= minimum_score)

    @pytest.mark.parametrize("peak_caller", PEAK_CALLER_CHILDREN)
    @pytest.mark.parametrize("number_of_peaks", (1, 100))
    def test_merge(self, peak_caller, number_of_peaks):
        peak_caller1 = peak_caller(
            number_of_peaks=number_of_peaks, min_distance=self.min_distance
        )
        peak_caller1(self.data, rotation_matrix=self.rotation_matrix)

        peak_caller2 = peak_caller(
            number_of_peaks=number_of_peaks, min_distance=self.min_distance
        )
        peak_caller2(self.data, rotation_matrix=self.rotation_matrix)

        parameters = [tuple(peak_caller1), tuple(peak_caller2)]

        result = tuple(
            peak_caller.merge(
                candidates=parameters,
                number_of_peaks=number_of_peaks,
                min_distance=self.min_distance,
            )
        )
        # Merging must produce output, and the merged peak positions (first
        # element, as in test__call__) must respect the requested peak budget.
        assert len(result) > 0
        assert len(result[0]) <= number_of_peaks
98
+
99
+
100
class TestRecursiveMasking:
    """Tests for ``PeakCallerRecursiveMasking`` with masks and rotation maps."""

    def setup_method(self):
        self.number_of_peaks = 100
        self.min_distance = 5
        self.data = np.random.rand(100, 100, 100)
        self.rotation_matrix = np.eye(3)
        self.mask = np.random.rand(20, 20, 20)
        self.rotation_space = np.zeros_like(self.data)
        self.rotation_mapping = {0: (0, 0, 0)}

    @pytest.mark.parametrize("number_of_peaks", (1, 100))
    @pytest.mark.parametrize("compute_rotation", (True, False))
    @pytest.mark.parametrize("minimum_score", (None, 0.5))
    def test__call__(self, number_of_peaks, compute_rotation, minimum_score):
        # minimum_score must reach the peak caller, otherwise the score
        # assertion below is not actually testing the threshold.
        peak_caller = PeakCallerRecursiveMasking(
            number_of_peaks=number_of_peaks,
            min_distance=self.min_distance,
            minimum_score=minimum_score,
        )
        rotation_space, rotation_mapping = None, None
        if compute_rotation:
            rotation_space = self.rotation_space
            rotation_mapping = self.rotation_mapping

        peak_caller(
            self.data.copy(),
            rotation_matrix=self.rotation_matrix,
            mask=self.mask,
            rotation_space=rotation_space,
            rotation_mapping=rotation_mapping,
        )

        candidates = tuple(peak_caller)
        if minimum_score is None:
            # The number of reported peak positions must not exceed the budget.
            assert len(candidates[0]) <= number_of_peaks
        else:
            peaks = candidates[0].astype(int)
            assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
136
+
137
+
138
class TestMaxScoreOverRotations:
    """Tests for the running-maximum score analyzer ``MaxScoreOverRotations``."""

    def setup_method(self):
        self.number_of_peaks = 100
        self.min_distance = 5
        self.data = np.random.rand(100, 100, 100)
        self.rotation_matrix = np.eye(3)

    def _zero_offset(self):
        """Return a fresh all-zero integer translation offset for ``self.data``."""
        return np.zeros(self.data.ndim, dtype=int)

    def _build_analyzer(self, use_memmap, score_threshold):
        """Create an analyzer over ``self.data.shape`` with a zero offset."""
        return MaxScoreOverRotations(
            shape=self.data.shape,
            score_threshold=score_threshold,
            translation_offset=self._zero_offset(),
            use_memmap=use_memmap,
        )

    def _expected_scores(self, doubled, score_threshold):
        """Element-wise max of both score maps, clipped from below by the threshold."""
        return np.maximum(np.maximum(self.data, doubled), score_threshold)

    def test_initialization(self):
        _ = MaxScoreOverRotations(
            shape=self.data.shape, translation_offset=self._zero_offset()
        )
        _ = MaxScoreOverRotations(
            scores=self.data,
            rotations=self.data,
            translation_offset=self._zero_offset(),
        )

    @pytest.mark.parametrize("use_memmap", [False, True])
    def test__iter__(self, use_memmap: bool):
        analyzer = MaxScoreOverRotations(shape=self.data.shape, use_memmap=use_memmap)
        analyzer(self.data, rotation_matrix=self.rotation_matrix)

        result = tuple(analyzer)
        scores, offset, rotations = result[0], result[1], result[2]
        assert np.allclose(scores.shape, self.data.shape)
        assert scores.dtype == be._float_dtype
        assert offset.size == self.data.ndim
        assert np.allclose(rotations.shape, self.data.shape)
        assert len(result) == 4

    @pytest.mark.parametrize("use_memmap", [False, True])
    @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
    def test__call__(self, use_memmap: bool, score_threshold: float):
        analyzer = self._build_analyzer(use_memmap, score_threshold)
        analyzer(self.data, rotation_matrix=self.rotation_matrix)

        doubled = self.data * 2
        analyzer(doubled, rotation_matrix=self.rotation_matrix)

        scores, _, _, _ = tuple(analyzer)
        assert np.all(scores >= score_threshold)
        assert np.allclose(scores, self._expected_scores(doubled, score_threshold))

    @pytest.mark.parametrize("use_memmap", [False, True])
    @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
    def test_merge(self, use_memmap: bool, score_threshold: float):
        first = self._build_analyzer(use_memmap, score_threshold)
        first(self.data, rotation_matrix=self.rotation_matrix)

        doubled = self.data * 2
        second = self._build_analyzer(use_memmap, score_threshold)
        second(doubled, rotation_matrix=self.rotation_matrix)

        merged = MaxScoreOverRotations.merge(
            [tuple(first), tuple(second)],
            use_memmap=use_memmap,
            score_threshold=score_threshold,
        )
        scores, _, _, _ = merged
        assert np.all(scores >= score_threshold)
        assert np.allclose(scores, self._expected_scores(doubled, score_threshold))
219
+
220
+
221
class TestMemmapHandler:
    """Tests for ``MemmapHandler``, which writes score maps to per-rotation files."""

    @staticmethod
    def _make_tempfile() -> str:
        """Create an empty temporary file and return its path.

        ``mkstemp`` returns an already-open OS-level file descriptor alongside
        the path; it is closed right away so each setup does not leak one.
        """
        import os

        fd, path = mkstemp()
        os.close(fd)
        return path

    def setup_method(self):
        self.number_of_peaks = 100
        self.min_distance = 5
        self.data = np.random.rand(100, 100, 100)
        self.indices = tuple(np.indices(self.data.shape))

        self.rotation_matrix = np.eye(3)
        rotation_matrix2 = np.eye(3)
        rotation_matrix2[0, 0] = -1

        # Keys are the flattened rotation matrices joined with "_".
        rotation_matrix = "_".join(self.rotation_matrix.ravel().astype(str))
        rotation_matrix2 = "_".join(rotation_matrix2.ravel().astype(str))

        self.path_translation = {
            rotation_matrix: self._make_tempfile(),
            rotation_matrix2: self._make_tempfile(),
        }

    def teardown_method(self):
        """Remove the temporary files created in ``setup_method``."""
        import os

        for path in self.path_translation.values():
            try:
                os.remove(path)
            except OSError:
                # Best-effort cleanup: the file may already be gone or still
                # be held open by a memmap on some platforms.
                pass

    def test_initialization(self):
        _ = MemmapHandler(
            path_translation=self.path_translation,
            shape=self.data.shape,
            dtype=self.data.dtype,
            indices=self.indices,
        )

    def test__call__(self):
        score_analyzer = MemmapHandler(
            path_translation=self.path_translation,
            shape=self.data.shape,
            dtype=self.data.dtype,
            indices=self.indices,
        )
        score_analyzer(self.data, rotation_matrix=self.rotation_matrix)
        rotation_filepath = score_analyzer._rotation_matrix_to_filepath(
            rotation_matrix=self.rotation_matrix
        )
        array = np.memmap(
            rotation_filepath,
            mode="r+",
            shape=score_analyzer.shape,
            dtype=score_analyzer.dtype,
        )
        assert np.allclose(array, self.data)

    def test__iter__(self):
        score_analyzer = MemmapHandler(
            path_translation=self.path_translation,
            shape=self.data.shape,
            dtype=self.data.dtype,
            indices=self.indices,
        )
        res = tuple(score_analyzer)
        assert res == (None,)

    def test_merge(self):
        score_analyzer = MemmapHandler(
            path_translation=self.path_translation,
            shape=self.data.shape,
            dtype=self.data.dtype,
            indices=self.indices,
        )
        res = MemmapHandler.merge(score_analyzer)
        assert res is None

    def test_update_indices(self):
        score_analyzer = MemmapHandler(
            path_translation=self.path_translation,
            shape=self.data.shape,
            dtype=self.data.dtype,
            indices=self.indices,
        )
        new_indices = np.random.rand(3)
        score_analyzer.update_indices(new_indices)
        assert np.allclose(score_analyzer._indices, new_indices)

    def test__rotation_matrix_to_filepath(self):
        score_analyzer = MemmapHandler(
            path_translation=self.path_translation,
            shape=self.data.shape,
            dtype=self.data.dtype,
            indices=self.indices,
        )

        # The first key corresponds to self.rotation_matrix (insertion order).
        rotation_matrix = list(self.path_translation.keys())[0]
        rotation_filepath = score_analyzer._rotation_matrix_to_filepath(
            rotation_matrix=self.rotation_matrix
        )
        assert rotation_filepath == self.path_translation.get(rotation_matrix)
tests/test_backends.py ADDED
@@ -0,0 +1,375 @@
1
+ import pytest
2
+ import numpy as np
3
+
4
+ from multiprocessing.managers import SharedMemoryManager
5
+
6
+ from tme.backends import MatchingBackend, NumpyFFTWBackend, BackendManager, backend
7
+
8
# Instantiate every registered backend on the CPU. Backends that raise
# ImportError (e.g. because an optional dependency is missing) are skipped.
BACKENDS_TO_TEST = []
for backend_class in backend._BACKEND_REGISTRY.values():
    try:
        instance = backend_class(device="cpu")
    except ImportError:
        print(f"Couldn't import {backend_class}. Skipping...")
        continue
    BACKENDS_TO_TEST.append(instance)

# Names of the abstract methods declared on MatchingBackend.
METHODS_TO_TEST = MatchingBackend.__abstractmethods__
16
+
17
+
18
class TestBackendManager:
    """Tests for backend registration and lookup on ``BackendManager``."""

    def setup_method(self):
        self.manager = BackendManager()

    def test_initialization(self):
        fresh_manager = BackendManager()
        name = fresh_manager._backend_name
        assert str(fresh_manager) == f"<BackendManager: using {name}>"

    def test_dir(self):
        dir(self.manager)  # must not raise
        missing = [m for m in METHODS_TO_TEST if not hasattr(self.manager, m)]
        assert not missing, missing

    def test_add_backend(self):
        self.manager.add_backend(backend_name="test", backend_class=NumpyFFTWBackend)

    def test_add_backend_error(self):
        class _Bar:
            """Stand-in class that is not a MatchingBackend."""

            def __init__(self):
                pass

        with pytest.raises(ValueError):
            self.manager.add_backend(backend_name="test", backend_class=_Bar)

    def test_change_backend_error(self):
        with pytest.raises(NotImplementedError):
            self.manager.change_backend(backend_name=None)

    def test_available_backends(self):
        available = self.manager.available_backends()
        assert isinstance(available, list)
        assert all(name in self.manager._BACKEND_REGISTRY for name in available)
52
+
53
+
54
class TestBackends:
    """Compare each available backend against the ``NumpyFFTWBackend`` reference.

    ``self.backend`` is the reference implementation; the parametrized
    ``backend`` argument is the backend under test.
    """

    def setup_method(self):
        self.backend = NumpyFFTWBackend()
        self.x1 = np.random.rand(30, 30).astype(np.float32)
        self.x2 = np.random.rand(30, 30).astype(np.float32)

    def teardown_method(self):
        self.backend = None

    def test_initialization_errors(self):
        # The abstract base class must not be instantiable.
        with pytest.raises(TypeError):
            _ = MatchingBackend()

    @pytest.mark.parametrize("backend", [type(x) for x in BACKENDS_TO_TEST])
    def test_initialization(self, backend):
        _ = backend()

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name",
        ("add", "subtract", "multiply", "divide", "minimum", "maximum", "mod"),
    )
    def test_arithmetic_operations(self, method_name, backend):
        base = getattr(self.backend, method_name)(self.x1, self.x2)
        x1 = backend.to_backend_array(self.x1)
        x2 = backend.to_backend_array(self.x2)
        other = getattr(backend, method_name)(x1, x2)

        assert np.allclose(base, backend.to_numpy_array(other))

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name", ("sum", "mean", "std", "max", "min", "unique")
    )
    @pytest.mark.parametrize("axis", (0, 1))
    def test_reduction_operations(self, method_name, backend, axis):
        base = getattr(self.backend, method_name)(self.x1, axis=axis)
        other = getattr(backend, method_name)(
            backend.to_backend_array(self.x1), axis=axis
        )
        # std implementations may differ in their default Bessel's correction
        # (e.g. pytorch), hence the looser tolerance for std.
        rtol = 0.01 if method_name != "std" else 0.5
        assert np.allclose(base, backend.to_numpy_array(other), rtol=rtol)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name",
        ("sqrt", "square", "abs", "transpose", "tobytes", "size"),
    )
    def test_array_manipulation(self, method_name, backend):
        base = getattr(self.backend, method_name)(self.x1)
        other = getattr(backend, method_name)(backend.to_backend_array(self.x1))

        if isinstance(base, np.ndarray):
            assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)
        else:
            # Non-array results such as those of ``tobytes`` and ``size``.
            assert base == other

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("dtype", ("_float_dtype", "_complex_dtype", "_int_dtype"))
    def test_zeros(self, shape, backend, dtype):
        dtype_base = getattr(self.backend, dtype)
        dtype_backend = getattr(backend, dtype)
        base = self.backend.zeros(shape, dtype=dtype_base)
        other = backend.zeros(shape, dtype=dtype_backend)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("fill_value", (-1, 0, 1))
    def test_full(self, shape, backend, fill_value):
        base = self.backend.full(shape, fill_value=fill_value)
        other = backend.full(shape, fill_value=fill_value)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("power", (0.5, 1, 2))
    def test_power(self, backend, power):
        base = self.backend.power(self.x1, power)
        other = backend.power(backend.to_backend_array(self.x1), power)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shift", (-5, 0, 10))
    @pytest.mark.parametrize("axis", (0, 1))
    def test_roll(self, backend, shift, axis):
        base = self.backend.roll(self.x1, (shift,), (axis,))
        other = backend.roll(backend.to_backend_array(self.x1), (shift,), (axis,))
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("fill_value", (-1, 0, 1))
    def test_fill(self, shape, backend, fill_value):
        base = self.backend.full(shape, fill_value=fill_value)
        other = backend.zeros(shape)
        other = backend.fill(other, fill_value)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("min_distance", (1, 5, 10))
    def test_max_filter_coordinates(self, backend, min_distance):
        coordinates = backend.max_filter_coordinates(
            backend.to_backend_array(self.x1), min_distance=min_distance
        )
        # When maxima are found, each row holds one coordinate per input axis.
        if len(coordinates):
            assert coordinates.shape[1] == self.x1.ndim

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("dtype", ("_float_dtype", "_complex_dtype", "_int_dtype"))
    @pytest.mark.parametrize(
        "dtype_target", ("_int_dtype", "_complex_dtype", "_float_dtype")
    )
    def test_astype(self, dtype, backend, dtype_target):
        dtype_base = getattr(backend, dtype)
        dtype_target = getattr(backend, dtype_target)

        base = backend.zeros((20, 20, 20), dtype=dtype_base)
        arr = backend.astype(base, dtype_target)

        assert arr.dtype == dtype_target

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("N", (0, 15, 30))
    def test_arange(self, backend, N):
        base = self.backend.arange(N)
        other = backend.arange(N)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("return_inverse", (False, True))
    @pytest.mark.parametrize("return_counts", (False, True))
    @pytest.mark.parametrize("return_index", (False, True))
    def test_unique(self, backend, return_inverse, return_counts, return_index):
        base = self.backend.unique(
            self.x1,
            return_inverse=return_inverse,
            return_counts=return_counts,
            return_index=return_index,
        )
        other = backend.unique(
            backend.to_backend_array(self.x1),
            return_inverse=return_inverse,
            return_counts=return_counts,
            return_index=return_index,
        )
        if isinstance(base, tuple):
            base, other = tuple(base), tuple(other)
            # Both backends must return the same number of components.
            assert len(base) == len(other)
            for expected, observed in zip(base, other):
                assert np.allclose(
                    expected.ravel(),
                    backend.to_numpy_array(observed).ravel(),
                    rtol=0.1,
                )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("k", (0, 15, 30))
    def test_repeat(self, backend, k):
        base = self.backend.repeat(self.x1, k)
        other = backend.repeat(backend.to_backend_array(self.x1), k)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("dim", (1, 3))
    @pytest.mark.parametrize("k", (0, 15, 30))
    def test_topk_indices(self, backend, k: int, dim: int):
        data = np.random.rand(*(50 for _ in range(dim)))
        base = self.backend.topk_indices(data, k)
        other = backend.topk_indices(backend.to_backend_array(data), k)

        assert len(base) == len(other)
        for expected, observed in zip(base, other):
            observed = backend.to_numpy_array(backend.to_backend_array(observed))
            # Backends are not required to report the top-k entries in the same
            # order, so compare the sorted components.
            assert np.allclose(
                np.sort(np.ravel(expected)),
                np.sort(np.ravel(observed)),
                rtol=0.1,
            )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_indices(self, backend):
        base = self.backend.indices(self.x1.shape)
        other = backend.indices(backend.to_backend_array(self.x1).shape)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_get_available_memory(self, backend):
        mem = backend.get_available_memory()
        assert isinstance(mem, int)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_shared_memory(self, backend):
        shared_memory_handler = None
        base = backend.to_backend_array(self.x1)
        shared = backend.to_sharedarr(
            arr=base, shared_memory_handler=shared_memory_handler
        )
        arr = backend.from_sharedarr(shared)
        assert np.allclose(backend.to_numpy_array(arr), backend.to_numpy_array(base))

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_shared_memory_managed(self, backend):
        with SharedMemoryManager() as shared_memory_handler:
            base = backend.to_backend_array(self.x1)
            shared = backend.to_sharedarr(
                arr=base, shared_memory_handler=shared_memory_handler
            )
            arr = backend.from_sharedarr(shared)
            assert np.allclose(
                backend.to_numpy_array(arr), backend.to_numpy_array(base)
            )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15, 100), (10, 15, 20)))
    @pytest.mark.parametrize("padval", (-1, 0, 1))
    def test_topleft_pad(self, backend, shape, padval):
        arr = np.random.rand(30, 30, 30)
        base = self.backend.topleft_pad(arr, shape=shape, padval=padval)
        other = backend.topleft_pad(
            backend.to_backend_array(arr), shape=shape, padval=padval
        )
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
    def test_fft(self, backend, fast_shape):
        _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
            fast_shape, (1,) * len(fast_shape)
        )
        rfftn, irfftn = backend.build_fft(
            fast_shape=fast_shape,
            fast_ft_shape=fast_ft_shape,
            real_dtype=backend._float_dtype,
            complex_dtype=backend._complex_dtype,
        )
        arr = np.random.rand(*fast_shape)
        out = np.zeros(fast_ft_shape)

        real_arr = backend.astype(backend.to_backend_array(arr), backend._float_dtype)
        complex_arr = backend.astype(
            backend.to_backend_array(out), backend._complex_dtype
        )

        # Forward transform of a separate copy, then invert into real_arr.
        rfftn(
            backend.astype(backend.to_backend_array(arr), backend._float_dtype),
            complex_arr,
        )
        irfftn(complex_arr, real_arr)
        assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_extract_center(self, backend):
        new_shape = np.divide(self.x1.shape, 2).astype(int)
        base = self.backend.extract_center(arr=self.x1, newshape=new_shape)
        other = backend.extract_center(
            arr=backend.to_backend_array(self.x1), newshape=new_shape
        )

        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_compute_convolution_shapes(self, backend):
        base = self.backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)
        other = backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)

        assert base == other

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("create_mask", (False, True))
    def test_rigid_transform(self, backend, dim, create_mask):
        shape = tuple(50 for _ in range(dim))
        arr = np.zeros(shape)
        if dim == 2:
            arr[20:25, 21:26] = 1
        elif dim == 3:
            arr[20:25, 21:26, 26:31] = 1

        # Flip the first axis; the cube stays inside the box, so mass is kept.
        rotation_matrix = np.eye(dim)
        rotation_matrix[0, 0] = -1

        out = np.zeros_like(arr)

        arr_mask, out_mask = None, None
        if create_mask:
            arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
            out_mask = np.zeros_like(arr_mask)
            arr_mask = backend.to_backend_array(arr_mask)
            out_mask = backend.to_backend_array(out_mask)

        arr = backend.to_backend_array(arr)
        # ``out`` must be its own zero-initialised buffer; aliasing it to ``arr``
        # would make the sum comparison below trivially true.
        out = backend.to_backend_array(out)

        rotation_matrix = backend.to_backend_array(rotation_matrix)

        transformed = backend.rigid_transform(
            arr=arr,
            arr_mask=arr_mask,
            rotation_matrix=rotation_matrix,
            out=out,
            out_mask=out_mask,
        )
        # Backends with immutable arrays may return the result instead of
        # writing into ``out``; prefer a returned array when one is given.
        if (
            isinstance(transformed, (tuple, list))
            and len(transformed)
            and transformed[0] is not None
        ):
            out = transformed[0]

        assert np.round(backend.to_numpy_array(arr).sum(), 3) == np.round(
            backend.to_numpy_array(out).sum(), 3
        )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_datatype_bytes(self, backend):
        assert isinstance(backend.datatype_bytes(backend._float_dtype), int)
        assert isinstance(backend.datatype_bytes(backend._complex_dtype), int)
        assert isinstance(backend.datatype_bytes(backend._int_dtype), int)