pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tests/test_backends.py ADDED
@@ -0,0 +1,446 @@
1
+ import warnings
2
+
3
+ import pytest
4
+ import numpy as np
5
+
6
+ from multiprocessing.managers import SharedMemoryManager
7
+
8
+ from tme.backends import MatchingBackend, NumpyFFTWBackend, BackendManager, backend
9
+
10
def _instantiate_available_backends():
    """Build one CPU instance of every backend in the global registry.

    Backends whose construction raises ImportError (presumably a missing
    optional dependency) are reported on stdout and left out of the result.
    """
    instances = []
    for candidate in backend._BACKEND_REGISTRY.values():
        try:
            instance = candidate(device="cpu")
        except ImportError:
            print(f"Couldn't import {candidate}. Skipping...")
            continue
        instances.append(instance)
    return instances


# Backend instances that the parametrized tests compare against NumPy/FFTW.
BACKENDS_TO_TEST = _instantiate_available_backends()

# Every abstract method a MatchingBackend subclass has to provide.
METHODS_TO_TEST = MatchingBackend.__abstractmethods__
18
+
19
+
20
class TestBackendManager:
    """Tests for BackendManager: repr, exposed interface, registration, switching."""

    def setup_method(self):
        self.manager = BackendManager()

    def test_initialization(self):
        manager = BackendManager()
        backend_name = manager._backend_name
        assert f"<BackendManager: using {backend_name}>" == str(manager)

    def test_dir(self):
        # dir() must not fail, and every abstract MatchingBackend method has to be
        # reachable as an attribute of the manager.
        _ = dir(self.manager)
        for method in METHODS_TO_TEST:
            assert hasattr(self.manager, method)

    def test_add_backend(self):
        try:
            self.manager.add_backend(backend_name="test", backend_class=NumpyFFTWBackend)
            # Registration must make the backend discoverable under its name.
            assert "test" in self.manager._BACKEND_REGISTRY
        finally:
            # The registry may be shared with other managers; undo the registration
            # so later tests do not observe a stray "test" backend.
            self.manager._BACKEND_REGISTRY.pop("test", None)

    def test_add_backend_error(self):
        class _Bar:
            def __init__(self):
                pass

        # A class that is not a MatchingBackend must be rejected.
        with pytest.raises(ValueError):
            self.manager.add_backend(backend_name="test", backend_class=_Bar)

    def test_change_backend_error(self):
        with pytest.raises(NotImplementedError):
            self.manager.change_backend(backend_name=None)

    def test_available_backends(self):
        available = self.manager.available_backends()
        assert isinstance(available, list)
        for be in available:
            assert be in self.manager._BACKEND_REGISTRY
54
+
55
+
56
class TestBackends:
    """Cross-check every available backend against the NumPy/FFTW reference.

    ``self.backend`` is a fresh ``NumpyFFTWBackend`` used as the reference; each
    parametrized ``backend`` comes from ``BACKENDS_TO_TEST`` and its results are
    compared with the reference after ``backend.to_numpy_array``.
    """

    def setup_method(self):
        self.backend = NumpyFFTWBackend()
        self.x1 = np.random.rand(30, 30).astype(np.float32)
        self.x2 = np.random.rand(30, 30).astype(np.float32)

    def teardown_method(self):
        self.backend = None

    def test_initialization_errors(self):
        # MatchingBackend is abstract and must not be instantiable.
        with pytest.raises(TypeError):
            _ = MatchingBackend()

    @pytest.mark.parametrize("backend", [type(x) for x in BACKENDS_TO_TEST])
    def test_initialization(self, backend):
        _ = backend()

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name",
        ("add", "subtract", "multiply", "divide", "minimum", "maximum", "mod"),
    )
    def test_arithmetic_operations(self, method_name, backend):
        base = getattr(self.backend, method_name)(self.x1, self.x2)
        x1 = backend.to_backend_array(self.x1)
        x2 = backend.to_backend_array(self.x2)
        other = getattr(backend, method_name)(x1, x2)

        assert np.allclose(base, backend.to_numpy_array(other))

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name", ("sum", "mean", "std", "max", "min", "unique")
    )
    @pytest.mark.parametrize("axis", (0, 1))
    def test_reduction_operations(self, method_name, backend, axis):
        base = getattr(self.backend, method_name)(self.x1, axis=axis)
        other = getattr(backend, method_name)(
            backend.to_backend_array(self.x1), axis=axis
        )
        # Loose tolerance for std to account for Bessel's correction in PyTorch.
        rtol = 0.01 if method_name != "std" else 0.5
        assert np.allclose(base, backend.to_numpy_array(other), rtol=rtol)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize(
        "method_name",
        ("sqrt", "square", "abs", "transpose", "tobytes", "size"),
    )
    def test_array_manipulation(self, method_name, backend):
        base = getattr(self.backend, method_name)(self.x1)
        other = getattr(backend, method_name)(backend.to_backend_array(self.x1))

        if isinstance(base, np.ndarray):
            assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)
        else:
            # Non-array results (bytes from tobytes, int from size) compare directly.
            assert base == other

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("dtype", ("_float_dtype", "_complex_dtype", "_int_dtype"))
    def test_zeros(self, shape, backend, dtype):
        dtype_base = getattr(self.backend, dtype)
        dtype_backend = getattr(backend, dtype)
        base = self.backend.zeros(shape, dtype=dtype_base)
        other = backend.zeros(shape, dtype=dtype_backend)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("fill_value", (-1, 0, 1))
    def test_full(self, shape, backend, fill_value):
        base = self.backend.full(shape, fill_value=fill_value)
        other = backend.full(shape, fill_value=fill_value)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("power", (0.5, 1, 2))
    def test_power(self, backend, power):
        base = self.backend.power(self.x1, power)
        other = backend.power(backend.to_backend_array(self.x1), power)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shift", (-5, 0, 10))
    @pytest.mark.parametrize("axis", (0, 1))
    def test_roll(self, backend, shift, axis):
        base = self.backend.roll(self.x1, (shift,), (axis,))
        other = backend.roll(backend.to_backend_array(self.x1), (shift,), (axis,))
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15), (10, 15, 20)))
    @pytest.mark.parametrize("fill_value", (-1, 0, 1))
    def test_fill(self, shape, backend, fill_value):
        base = self.backend.full(shape, fill_value=fill_value)
        other = backend.zeros(shape)
        other = backend.fill(other, fill_value)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("min_distance", (1, 5, 10))
    def test_max_filter_coordinates(self, backend, min_distance):
        coordinates = backend.max_filter_coordinates(
            backend.to_backend_array(self.x1), min_distance=min_distance
        )
        # No peaks is acceptable; if any are found, each row holds one index per axis.
        if len(coordinates):
            assert coordinates.shape[1] == self.x1.ndim

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("dtype", ("_float_dtype", "_complex_dtype", "_int_dtype"))
    @pytest.mark.parametrize(
        "dtype_target", ("_int_dtype", "_complex_dtype", "_float_dtype")
    )
    def test_astype(self, dtype, backend, dtype_target):
        # Complex -> real/int casts can warn (e.g. NumPy's ComplexWarning).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            source_dtype = getattr(backend, dtype)
            target_dtype = getattr(backend, dtype_target)

            base = backend.zeros((20, 20, 20), dtype=source_dtype)
            arr = backend.astype(base, target_dtype)

        assert arr.dtype == target_dtype

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("N", (0, 15, 30))
    def test_arange(self, backend, N):
        base = self.backend.arange(N)
        other = backend.arange(N)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("return_inverse", (False, True))
    @pytest.mark.parametrize("return_counts", (False, True))
    @pytest.mark.parametrize("return_index", (False, True))
    def test_unique(self, backend, return_inverse, return_counts, return_index):
        base = self.backend.unique(
            self.x1,
            return_inverse=return_inverse,
            return_counts=return_counts,
            return_index=return_index,
        )
        other = backend.unique(
            backend.to_backend_array(self.x1),
            return_inverse=return_inverse,
            return_counts=return_counts,
            return_index=return_index,
        )
        # Without any return_* flag a single array is returned; wrap it so both
        # cases are compared output-by-output.
        if isinstance(base, tuple):
            other = tuple(other)
        else:
            base, other = (base,), (other,)

        for k, expected in enumerate(base):
            received = backend.to_numpy_array(other[k])
            assert np.allclose(expected.ravel(), received.ravel(), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("k", (0, 15, 30))
    def test_repeat(self, backend, k):
        base = self.backend.repeat(self.x1, k)
        other = backend.repeat(backend.to_backend_array(self.x1), k)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("dim", (1, 3))
    @pytest.mark.parametrize("k", (0, 15, 30))
    def test_topk_indices(self, backend, k: int, dim: int):
        data = np.random.rand(*(50 for _ in range(dim)))
        base = self.backend.topk_indices(data, k)
        other = backend.topk_indices(backend.to_backend_array(data), k)

        # One index array per axis; stack into (k, dim) coordinate rows.
        expected = np.stack([np.asarray(x) for x in base], axis=-1).astype(np.int64)
        received = np.stack(
            [backend.to_numpy_array(other[i]) for i in range(len(base))], axis=-1
        ).astype(np.int64)

        # The order of the top-k elements is not guaranteed to agree across
        # backends (e.g. near-ties after float32 conversion), so compare the
        # selected coordinates as sets by sorting rows lexicographically.
        expected = expected[np.lexsort(expected.T)]
        received = received[np.lexsort(received.T)]
        assert np.array_equal(expected, received)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_indices(self, backend):
        base = self.backend.indices(self.x1.shape)
        other = backend.indices(backend.to_backend_array(self.x1).shape)
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.1)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_get_available_memory(self, backend):
        mem = backend.get_available_memory()
        assert isinstance(mem, int)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_shared_memory(self, backend):
        shared_memory_handler = None
        base = backend.to_backend_array(self.x1)
        shared = backend.to_sharedarr(
            arr=base, shared_memory_handler=shared_memory_handler
        )
        arr = backend.from_sharedarr(shared)
        assert np.allclose(backend.to_numpy_array(arr), backend.to_numpy_array(base))

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_shared_memory_managed(self, backend):
        with SharedMemoryManager() as shared_memory_handler:
            base = backend.to_backend_array(self.x1)
            shared = backend.to_sharedarr(
                arr=base, shared_memory_handler=shared_memory_handler
            )
            arr = backend.from_sharedarr(shared)
            assert np.allclose(
                backend.to_numpy_array(arr), backend.to_numpy_array(base)
            )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("shape", ((10, 15, 100), (10, 15, 20)))
    @pytest.mark.parametrize("padval", (-1, 0, 1))
    def test_topleft_pad(self, backend, shape, padval):
        arr = np.random.rand(30, 30, 30)
        base = self.backend.topleft_pad(arr, shape=shape, padval=padval)
        other = backend.topleft_pad(
            backend.to_backend_array(arr), shape=shape, padval=padval
        )
        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
    def test_build_fft(self, backend, fast_shape):
        _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
            fast_shape, (1 for _ in range(len(fast_shape)))
        )
        rfftn, irfftn = backend.build_fft(
            fwd_shape=fast_shape,
            inv_shape=fast_ft_shape,
            real_dtype=backend._float_dtype,
            cmpl_dtype=backend._complex_dtype,
        )
        arr = np.random.rand(*fast_shape)
        out = np.zeros(fast_ft_shape)

        # Preallocated output buffers for the planned transforms.
        real_arr = backend.astype(backend.to_backend_array(arr), backend._float_dtype)
        complex_arr = backend.astype(
            backend.to_backend_array(out), backend._complex_dtype
        )

        # Forward from a fresh copy of the input, then invert into real_arr.
        rfftn(
            backend.astype(backend.to_backend_array(arr), backend._float_dtype),
            complex_arr,
        )
        irfftn(complex_arr, real_arr)
        assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
    def test_fftn(self, backend, fast_shape):
        _, fast_shape, _ = backend.compute_convolution_shapes(
            fast_shape, (1 for _ in range(len(fast_shape)))
        )
        arr = np.random.rand(*fast_shape)

        # Round trip through the allocating FFT API must reproduce the input.
        complex_arr = backend.rfftn(
            backend.astype(backend.to_backend_array(arr), backend._float_dtype),
        )
        real_arr = backend.irfftn(complex_arr)
        assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_extract_center(self, backend):
        new_shape = np.divide(self.x1.shape, 2).astype(int)
        base = self.backend.extract_center(arr=self.x1, newshape=new_shape)
        other = backend.extract_center(
            arr=backend.to_backend_array(self.x1), newshape=new_shape
        )

        assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_compute_convolution_shapes(self, backend):
        base = self.backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)
        other = backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)

        assert base == other

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("create_mask", (False, True))
    def test_rigid_transform(self, backend, dim, create_mask):
        shape = tuple(50 for _ in range(dim))
        arr = np.zeros(shape)
        if dim == 2:
            arr[20:25, 21:26] = 1
        elif dim == 3:
            arr[20:25, 21:26, 26:31] = 1

        # Flip the sign of the first axis.
        rotation_matrix = np.eye(dim)
        rotation_matrix[0, 0] = -1

        out = np.zeros_like(arr)

        arr_mask, out_mask = None, None
        if create_mask:
            arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
            out_mask = np.zeros_like(arr_mask)
            arr_mask = backend.to_backend_array(arr_mask)
            out_mask = backend.to_backend_array(out_mask)

        arr = backend.to_backend_array(arr)
        # `out` must be its own buffer: building it from `arr` would let the
        # sum check below compare the input with itself.
        out = backend.to_backend_array(out)

        rotation_matrix = backend.to_backend_array(rotation_matrix)

        # Use the returned array instead of relying on in-place writes into `out`;
        # assumes rigid_transform returns (out, out_mask) on every backend.
        out, _ = backend.rigid_transform(
            arr=arr,
            arr_mask=arr_mask,
            rotation_matrix=rotation_matrix,
            out=out,
            out_mask=out_mask,
        )

        # The flipped box stays inside the volume, so total intensity is kept.
        arr_sum = float(backend.to_numpy_array(arr).sum())
        out_sum = float(backend.to_numpy_array(out).sum())
        assert np.isclose(arr_sum, out_sum, atol=1e-3)

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    @pytest.mark.parametrize("create_mask", (False, True))
    def test_rigid_transform_batch(self, backend, dim, create_mask):
        shape = tuple(30 for _ in range(dim))
        arr = np.random.rand(5, *shape)
        out = np.zeros_like(arr)

        # Untouched host copy of the input and a buffer for per-slice results.
        arr_b = arr.copy()
        out_b = np.zeros_like(arr)

        rotation_matrix = np.eye(dim)
        rotation_matrix[0, 0] = -1

        arr_mask, out_mask = None, None
        if create_mask:
            arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
            out_mask = np.zeros_like(arr_mask)
            arr_mask = backend.to_backend_array(arr_mask)
            out_mask = backend.to_backend_array(out_mask)

        arr = backend.to_backend_array(arr)
        out = backend.to_backend_array(out)

        rotation_matrix = backend.to_backend_array(rotation_matrix)

        # Transform the whole stack in one call (leading axis = batch).
        out, _ = backend.rigid_transform(
            arr=arr,
            arr_mask=arr_mask,
            rotation_matrix=rotation_matrix,
            out=out,
            out_mask=out_mask,
        )

        # Transform each slice on its own for comparison.
        out_b = backend.to_backend_array(out_b)
        per_slice = []
        for i in range(arr_b.shape[0]):
            sliced, _ = backend.rigid_transform(
                arr=arr[i],
                arr_mask=arr_mask if arr_mask is None else arr_mask[i],
                rotation_matrix=rotation_matrix,
                out=out_b[i],
                out_mask=out_mask if out_mask is None else out_mask[i],
            )
            per_slice.append(backend.to_numpy_array(sliced))

        # The input must not be modified by the transform ...
        assert np.allclose(backend.to_numpy_array(arr), arr_b)
        # ... and batched and per-slice transforms must agree.
        assert np.allclose(
            backend.to_numpy_array(out), np.stack(per_slice), atol=1e-4
        )

    @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
    def test_datatype_bytes(self, backend):
        assert isinstance(backend.datatype_bytes(backend._float_dtype), int)
        assert isinstance(backend.datatype_bytes(backend._complex_dtype), int)
        assert isinstance(backend.datatype_bytes(backend._int_dtype), int)