pytme-0.2.3-cp311-cp311-macosx_14_0_arm64.whl → pytme-0.2.5-cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/match_template.py +8 -8
  2. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocess.py +22 -6
  3. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocessor_gui.py +9 -14
  4. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/METADATA +1 -1
  5. pytme-0.2.5.dist-info/RECORD +119 -0
  6. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/WHEEL +1 -1
  7. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/top_level.txt +1 -0
  8. scripts/match_template.py +8 -8
  9. scripts/preprocess.py +22 -6
  10. scripts/preprocessor_gui.py +9 -14
  11. tests/__init__.py +0 -0
  12. tests/data/.DS_Store +0 -0
  13. tests/data/Blurring/.DS_Store +0 -0
  14. tests/data/Blurring/blob_width18.npy +0 -0
  15. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  16. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  17. tests/data/Blurring/hamming_width6.npy +0 -0
  18. tests/data/Blurring/kaiserb_width18.npy +0 -0
  19. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  20. tests/data/Blurring/mean_size5.npy +0 -0
  21. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  22. tests/data/Blurring/rank_rank3.npy +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Maps/emd_8621.mrc.gz +0 -0
  25. tests/data/README.md +2 -0
  26. tests/data/Raw/.DS_Store +0 -0
  27. tests/data/Raw/em_map.map +0 -0
  28. tests/data/Structures/.DS_Store +0 -0
  29. tests/data/Structures/1pdj.cif +3339 -0
  30. tests/data/Structures/1pdj.pdb +1429 -0
  31. tests/data/Structures/5khe.cif +3685 -0
  32. tests/data/Structures/5khe.ent +2210 -0
  33. tests/data/Structures/5khe.pdb +2210 -0
  34. tests/data/Structures/5uz4.cif +70548 -0
  35. tests/preprocessing/__init__.py +0 -0
  36. tests/preprocessing/test_compose.py +76 -0
  37. tests/preprocessing/test_frequency_filters.py +178 -0
  38. tests/preprocessing/test_preprocessor.py +136 -0
  39. tests/preprocessing/test_utils.py +79 -0
  40. tests/test_analyzer.py +310 -0
  41. tests/test_backends.py +375 -0
  42. tests/test_density.py +508 -0
  43. tests/test_extensions.py +130 -0
  44. tests/test_matching_cli.py +283 -0
  45. tests/test_matching_data.py +162 -0
  46. tests/test_matching_exhaustive.py +162 -0
  47. tests/test_matching_memory.py +30 -0
  48. tests/test_matching_optimization.py +226 -0
  49. tests/test_matching_utils.py +326 -0
  50. tests/test_orientations.py +173 -0
  51. tests/test_packaging.py +95 -0
  52. tests/test_parser.py +33 -0
  53. tests/test_structure.py +243 -0
  54. tme/__init__.py +0 -1
  55. tme/__version__.py +1 -1
  56. tme/backends/jax_backend.py +3 -9
  57. tme/data/scattering_factors.pickle +0 -0
  58. tme/density.py +14 -10
  59. tme/external/bindings.cpp +332 -0
  60. tme/matching_data.py +14 -12
  61. tme/matching_exhaustive.py +17 -15
  62. tme/matching_optimization.py +215 -208
  63. tme/matching_utils.py +1 -0
  64. tme/preprocessing/_utils.py +14 -14
  65. tme/preprocessing/composable_filter.py +0 -2
  66. tme/preprocessing/compose.py +4 -4
  67. tme/preprocessing/frequency_filters.py +32 -35
  68. tme/preprocessing/tilt_series.py +198 -117
  69. tme/preprocessor.py +24 -246
  70. tme/structure.py +22 -22
  71. pytme-0.2.3.dist-info/RECORD +0 -75
  72. tme/matching_memory.py +0 -383
  73. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/estimate_ram_usage.py +0 -0
  74. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/postprocess.py +0 -0
  75. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/LICENSE +0 -0
  76. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,226 @@
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from tme.matching_utils import euler_from_rotationmatrix
5
+ from tme.matching_optimization import (
6
+ MATCHING_OPTIMIZATION_REGISTER,
7
+ register_matching_optimization,
8
+ _MatchCoordinatesToDensity,
9
+ _MatchCoordinatesToCoordinates,
10
+ optimize_match,
11
+ create_score_object,
12
+ )
13
+
14
# Score names used to parametrize the density-to-density tests.
density_to_density = ["FLC"]


def _registered_names(base_class) -> list:
    """Return the register keys whose score class derives from ``base_class``.

    Keys are returned in the iteration order of ``MATCHING_OPTIMIZATION_REGISTER``.
    """
    return [
        name
        for name, score_class in MATCHING_OPTIMIZATION_REGISTER.items()
        if issubclass(score_class, base_class)
    ]


# Every registered score matching template coordinates against a density.
coordinate_to_density = _registered_names(_MatchCoordinatesToDensity)

# Every registered score matching template coordinates against target coordinates.
coordinate_to_coordinate = _registered_names(_MatchCoordinatesToCoordinates)
27
+
28
+
29
class TestMatchDensityToDensity:
    """Smoke tests for density-to-density scores built via ``create_score_object``."""

    def setup_method(self):
        # A single rectangular block of ones; the template is an exact copy,
        # and the template mask covers the whole grid.
        target = np.zeros((50, 50, 50))
        target[20:30, 30:40, 12:17] = 1
        self.target = target
        self.template = target.copy()
        self.template_mask = np.ones_like(target)

    def teardown_method(self):
        self.target = None
        self.template = None
        self.template_mask = None

    def _create_instance(self, method: str):
        """Build the score object named ``method`` from the fixture arrays."""
        return create_score_object(
            score=method,
            target=self.target,
            template=self.template,
            template_mask=self.template_mask,
        )

    @pytest.mark.parametrize("method", density_to_density)
    def test_initialization(self, method: str, notest: bool = False):
        instance = self._create_instance(method)
        # ``notest`` is kept for backward compatibility with callers that used
        # this test as a factory; prefer ``_create_instance`` instead.
        if notest:
            return instance
        # Score objects are invoked without arguments to produce a score.
        assert callable(instance)

    @pytest.mark.parametrize("method", density_to_density)
    def test_call(self, method):
        instance = self._create_instance(method)
        score = instance()
        assert isinstance(score, float)
58
+
59
+
60
class TestMatchDensityToCoordinates:
    """Smoke tests for every registered coordinate-to-density score."""

    def setup_method(self):
        data = np.zeros((50, 50, 50))
        data[20:30, 30:40, 12:17] = 1
        self.target = data
        self.target_mask_density = data > 0

        # Template coordinates are the nonzero voxels of the target, with shape
        # (ndim, n_points), weighted by their target values.
        self.coordinates = np.array(np.where(self.target > 0))
        self.coordinates_weights = self.target[tuple(self.coordinates)]

        # Use a local generator instead of np.random.seed so this fixture does
        # not reset NumPy's global RNG for every test that runs afterwards.
        # Sampling without replacement keeps the mask a proper subset of the
        # template coordinates (no duplicated points).
        rng = np.random.default_rng(42)
        n_points = self.coordinates.shape[1]
        random_pixels = rng.choice(n_points, size=n_points // 2, replace=False)
        self.coordinates_mask = self.coordinates[:, random_pixels]

        self.origin = np.zeros(self.coordinates.shape[0])
        self.sampling_rate = np.ones(self.coordinates.shape[0])

    def teardown_method(self):
        self.target = None
        self.target_mask_density = None
        self.coordinates = None
        self.coordinates_weights = None
        self.coordinates_mask = None
        self.origin = None
        self.sampling_rate = None

    def _create_instance(self, method: str):
        """Build the score object named ``method`` from the fixture arrays."""
        return create_score_object(
            score=method,
            target=self.target,
            target_mask=self.target_mask_density,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
            template_mask_coordinates=self.coordinates_mask,
        )

    @pytest.mark.parametrize("method", coordinate_to_density)
    def test_initialization(self, method: str, notest: bool = False):
        instance = self._create_instance(method)
        # ``notest`` is kept for backward compatibility; prefer ``_create_instance``.
        if notest:
            return instance
        assert callable(instance)

    @pytest.mark.parametrize("method", coordinate_to_density)
    def test_call(self, method):
        instance = self._create_instance(method)
        score = instance()
        assert isinstance(score, float)
103
+
104
+
105
class TestMatchCoordinateToCoordinates:
    """Smoke tests for every registered coordinate-to-coordinate score."""

    def setup_method(self):
        data = np.zeros((50, 50, 50))
        data[20:30, 30:40, 12:17] = 1

        # Target and template are the same point set: the nonzero voxels of
        # ``data`` with shape (ndim, n_points), weighted by their values.
        self.target_coordinates = np.array(np.where(data > 0))
        self.target_weights = data[tuple(self.target_coordinates)]

        self.coordinates = np.array(np.where(data > 0))
        self.coordinates_weights = data[tuple(self.coordinates)]

        self.origin = np.zeros(self.coordinates.shape[0])
        self.sampling_rate = np.ones(self.coordinates.shape[0])

    def teardown_method(self):
        self.target_coordinates = None
        self.target_weights = None
        self.coordinates = None
        self.coordinates_weights = None
        self.origin = None
        self.sampling_rate = None

    def _create_instance(self, method: str):
        """Build the score object named ``method`` from the fixture point sets."""
        return create_score_object(
            score=method,
            target_coordinates=self.target_coordinates,
            target_weights=self.target_weights,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
        )

    @pytest.mark.parametrize("method", coordinate_to_coordinate)
    def test_initialization(self, method: str, notest: bool = False):
        instance = self._create_instance(method)
        # ``notest`` is kept for backward compatibility; prefer ``_create_instance``.
        if notest:
            return instance
        assert callable(instance)

    @pytest.mark.parametrize("method", coordinate_to_coordinate)
    def test_call(self, method):
        instance = self._create_instance(method)
        score = instance()
        assert isinstance(score, float)
141
+
142
+
143
class TestOptimizeMatch:
    """Tests for ``optimize_match`` driven by a CrossCorrelation score object."""

    def setup_method(self):
        data = np.zeros((50, 50, 50))
        data[20:30, 30:40, 12:17] = 1
        self.target = data
        self.coordinates = np.array(np.where(self.target > 0))
        self.coordinates_weights = self.target[tuple(self.coordinates)]

        self.origin = np.zeros(self.coordinates.shape[0])
        self.sampling_rate = np.ones(self.coordinates.shape[0])

        score_class = MATCHING_OPTIMIZATION_REGISTER["CrossCorrelation"]
        self.score_object = score_class(
            target=self.target,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
        )

    def teardown_method(self):
        self.target = None
        self.coordinates = None
        self.coordinates_weights = None
        self.score_object = None

    @pytest.mark.parametrize(
        "method", ("differential_evolution", "basinhopping", "minimize")
    )
    @pytest.mark.parametrize("bound_translation", (True, False))
    @pytest.mark.parametrize("bound_rotation", (True, False))
    def test_call(self, method, bound_translation, bound_rotation):
        ndim = self.target.ndim

        # Translate the boolean parametrization into per-axis (low, high) bounds,
        # keeping the parametrized flags and the bounds in separate names.
        rotation_bounds = (
            tuple((-90, 90) for _ in range(ndim)) if bound_rotation else None
        )
        translation_bounds = (
            tuple((-5, 5) for _ in range(ndim)) if bound_translation else None
        )

        translation, rotation, score = optimize_match(
            score_object=self.score_object,
            optimization_method=method,
            bounds_rotation=rotation_bounds,
            bounds_translation=translation_bounds,
            maxiter=10,
        )
        assert translation.size == ndim
        assert rotation.shape[0] == ndim
        assert rotation.shape[1] == ndim
        assert isinstance(score, float)

        if translation_bounds is not None:
            lower_bound, upper_bound = np.array(translation_bounds).T
            assert np.all(
                np.logical_and(translation >= lower_bound, translation <= upper_bound)
            )

        if rotation_bounds is not None:
            # NOTE(review): assumes euler_from_rotationmatrix returns angles in
            # the same convention and range that optimize_match applies
            # bounds_rotation to -- confirm, otherwise this check can be flaky.
            angles = euler_from_rotationmatrix(rotation)
            lower_bound, upper_bound = np.array(rotation_bounds).T
            assert np.all(np.logical_and(angles >= lower_bound, angles <= upper_bound))

    def test_call_error(self):
        # Unknown optimization methods must be rejected.
        with pytest.raises(ValueError):
            optimize_match(
                score_object=self.score_object,
                optimization_method="RAISERROR",
                maxiter=10,
            )
214
+
215
+
216
class TestUtils:
    """Tests for registering custom scores in ``MATCHING_OPTIMIZATION_REGISTER``."""

    def test_register_matching_optimization(self):
        # Re-register an existing score class under a new name.
        existing_name = next(iter(MATCHING_OPTIMIZATION_REGISTER))
        existing_class = MATCHING_OPTIMIZATION_REGISTER[existing_name]
        try:
            register_matching_optimization(
                match_name="new_score",
                match_class=existing_class,
            )
            assert MATCHING_OPTIMIZATION_REGISTER["new_score"] is existing_class
        finally:
            # The register is module-global; remove the test entry so it does
            # not leak into other tests in the same session.
            MATCHING_OPTIMIZATION_REGISTER.pop("new_score", None)

    def test_register_matching_optimization_error(self):
        with pytest.raises(ValueError):
            register_matching_optimization(match_name="new_score", match_class=None)
@@ -0,0 +1,326 @@
1
+ import sys
2
+ from tempfile import mkstemp
3
+ from importlib_resources import files
4
+ from itertools import combinations, chain, product
5
+
6
+ import pytest
7
+ import numpy as np
8
+ from scipy.signal import correlate
9
+ from scipy.spatial.transform import Rotation
10
+
11
+ from tme import Density
12
+ from tme.backends import backend as be
13
+ from tme.memory import MATCHING_MEMORY_REGISTRY
14
+ from tme.matching_exhaustive import _handle_traceback
15
+ from tme.matching_utils import (
16
+ compute_parallelization_schedule,
17
+ elliptical_mask,
18
+ box_mask,
19
+ tube_mask,
20
+ create_mask,
21
+ scramble_phases,
22
+ split_shape,
23
+ compute_full_convolution_index,
24
+ apply_convolution_mode,
25
+ get_rotation_matrices,
26
+ write_pickle,
27
+ load_pickle,
28
+ euler_from_rotationmatrix,
29
+ euler_to_rotationmatrix,
30
+ get_rotations_around_vector,
31
+ rotation_aligning_vectors,
32
+ _normalize_template_overflow_safe,
33
+ )
34
+
35
# Resource handle for the ``tests.data`` package; fixture files (maps,
# structures) are resolved from it via ``joinpath`` and converted with ``str``.
BASEPATH = files("tests.data")
36
+
37
+
38
class TestMatchingUtils:
    """Tests for the mask, convolution, rotation and I/O helpers in ``tme.matching_utils``."""

    def setup_method(self):
        # Experimental map plus a simulated density of 5khe on the same grid.
        self.density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
        self.structure_density = Density.from_structure(
            filename_or_structure=str(BASEPATH.joinpath("Structures/5khe.cif")),
            origin=self.density.origin,
            shape=self.density.shape,
            sampling_rate=self.density.sampling_rate,
        )

    @pytest.mark.parametrize("matching_method", list(MATCHING_MEMORY_REGISTRY.keys()))
    @pytest.mark.parametrize("max_cores", range(1, 10, 3))
    @pytest.mark.parametrize("max_ram", [1e5, 1e7, 1e9])
    def test_compute_parallelization_schedule(
        self, matching_method, max_cores, max_ram
    ):
        max_cores, max_ram = int(max_cores), int(max_ram)
        compute_parallelization_schedule(
            self.density.shape,
            self.structure_density.shape,
            matching_method=matching_method,
            max_cores=max_cores,
            max_ram=max_ram,
            max_splits=256,
        )

    def test_create_mask(self):
        create_mask(
            mask_type="ellipse",
            shape=self.density.shape,
            radius=5,
            center=np.divide(self.density.shape, 2),
        )

    def test_create_mask_error(self):
        with pytest.raises(ValueError):
            create_mask(mask_type=None)

    def test_elliptical_mask(self):
        elliptical_mask(
            shape=self.density.shape,
            radius=5,
            center=np.divide(self.density.shape, 2),
        )

    def test_box_mask(self):
        box_mask(
            shape=self.density.shape,
            height=[5, 10, 20],
            center=np.divide(self.density.shape, 2),
        )

    def test_tube_mask(self):
        tube_mask(
            shape=self.density.shape,
            outer_radius=10,
            inner_radius=5,
            height=5,
            base_center=np.divide(self.density.shape, 2),
            symmetry_axis=1,
        )

    def test_tube_mask_error(self):
        # NOTE(review): every case below passes inner_radius > outer_radius, so
        # each may raise for that reason alone. The oversized ``height`` and the
        # out-of-range ``symmetry_axis`` look intended to be checked on their
        # own (with valid radii) -- confirm the validation order in tube_mask
        # before changing the radii.
        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=5,
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=1,
            )

        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=10 * np.max(self.density.shape),
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=1,
            )

        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=10 * np.max(self.density.shape),
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=len(self.density.shape) + 1,
            )

    def test_scramble_phases(self):
        scramble_phases(arr=self.density.data, noise_proportion=0.5)

    # range(1, 3, 5) only ever yielded dim=1; exercise 2D and 3D rotations too.
    @pytest.mark.parametrize("dim", [1, 2, 3])
    @pytest.mark.parametrize("angular_sampling", [10, 15, 20])
    def test_get_rotation_matrices(self, dim, angular_sampling):
        rotation_matrices = get_rotation_matrices(
            angular_sampling=angular_sampling, dim=dim
        )
        # Rotation matrices are orthogonal: R @ R.T == I.
        assert np.allclose(rotation_matrices[0] @ rotation_matrices[0].T, np.eye(dim))

    def test_split_correlation(self):
        arr1 = elliptical_mask(shape=(50, 51), center=(20, 30), radius=5)
        arr2 = elliptical_mask(shape=(41, 36), center=(25, 20), radius=5)

        def axis_subsets(ndim):
            # Every subset of axes, including the empty one (no split at all).
            return chain.from_iterable(
                combinations(range(ndim), r) for r in range(ndim + 1)
            )

        # Each selected axis is split into two chunks.
        outer_splits = [dict.fromkeys(axes, 2) for axes in axis_subsets(arr1.ndim)]
        inner_splits = [dict.fromkeys(axes, 2) for axes in axis_subsets(arr2.ndim)]

        # The reference correlation does not depend on the split; compute it once.
        full = correlate(arr1, arr2, method="direct", mode="full")

        for outer_split, inner_split in product(outer_splits, inner_splits):
            splits1 = split_shape(
                shape=arr1.shape, splits=outer_split, equal_shape=False
            )
            splits2 = split_shape(
                shape=arr2.shape, splits=inner_split, equal_shape=False
            )

            # Summing the correlations of all sub-array pairs, each placed at
            # its full-convolution index, must reproduce the full correlation.
            temp = np.zeros_like(full)
            for arr1_split, arr2_split in product(splits1, splits2):
                correlation = correlate(
                    arr1[arr1_split], arr2[arr2_split], method="direct", mode="full"
                )
                score_slice = compute_full_convolution_index(
                    arr1.shape, arr2.shape, arr1_split, arr2_split
                )
                temp[score_slice] += correlation

            assert np.allclose(temp, full)

    @pytest.mark.parametrize("convolution_mode", ["full", "valid", "same"])
    def test_apply_convolution_mode(self, convolution_mode):
        correlation = correlate(
            self.density.data, self.structure_density.data, method="direct", mode="full"
        )
        ret = apply_convolution_mode(
            arr=correlation,
            convolution_mode=convolution_mode,
            s1=self.density.shape,
            s2=self.structure_density.shape,
        )
        if convolution_mode == "full":
            expected_size = correlation.shape
        elif convolution_mode == "same":
            expected_size = self.density.shape
        else:
            expected_size = np.subtract(
                self.density.shape, self.structure_density.shape
            )
            expected_size += np.mod(self.structure_density.shape, 2)
        assert np.allclose(ret.shape, expected_size)

    def test_apply_convolution_mode_error(self):
        correlation = correlate(
            self.density.data, self.structure_density.data, method="direct", mode="full"
        )
        with pytest.raises(ValueError):
            _ = apply_convolution_mode(
                arr=correlation,
                convolution_mode=None,
                s1=self.density.shape,
                s2=self.structure_density.shape,
            )

    def test_handle_traceback(self):
        try:
            raise ValueError("Test error")
        except Exception:
            type_, value_, traceback_ = sys.exc_info()
            with pytest.raises(Exception, match="Test error"):
                _handle_traceback(type_, value_, traceback_)

    def test_pickle_io(self, tmp_path):
        # tmp_path is managed by pytest: no leaked file descriptors (as with
        # mkstemp) and no stray files left behind in the system temp directory.
        filename = str(tmp_path / "data.pickle")

        data = ["Hello", 123, np.array([1, 2, 3])]
        write_pickle(data=data, filename=filename)
        loaded_data = load_pickle(filename)
        # zip() would silently truncate, so compare lengths explicitly.
        assert len(loaded_data) == len(data)
        assert all(np.array_equal(a, b) for a, b in zip(data, loaded_data))

        data = 42
        write_pickle(data=data, filename=filename)
        loaded_data = load_pickle(filename)
        assert loaded_data == data

        memmap_filename = str(tmp_path / "data.mmap")
        data = np.memmap(memmap_filename, dtype="float32", mode="w+", shape=(3,))
        data[:] = [1.1, 2.2, 3.3]
        data.flush()
        data = np.memmap(memmap_filename, dtype="float32", mode="r+", shape=(3,))
        filename = str(tmp_path / "memmap.pickle")
        write_pickle(data=data, filename=filename)
        loaded_data = load_pickle(filename)
        assert np.array_equal(loaded_data, data)

    @pytest.mark.parametrize(
        "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry, convention",
        [
            (30, 5, 360, None, (1, 0, 0), 1, None),
            (45, 10, 180, 15, (0, 1, 0), 2, "zyx"),
            (60, 15, 90, 30, (0, 0, 1), 4, "xyz"),
        ],
    )
    def test_get_rotations_around_vector(
        self,
        cone_angle,
        cone_sampling,
        axis_angle,
        axis_sampling,
        vector,
        n_symmetry,
        convention,
    ):
        result = get_rotations_around_vector(
            cone_angle,
            cone_sampling,
            axis_angle,
            axis_sampling,
            vector,
            n_symmetry,
            convention,
        )

        assert isinstance(result, np.ndarray)
        if convention is None:
            assert result.shape[1:] == (3, 3)
        else:
            assert result.shape[1] == 3

    @pytest.mark.parametrize(
        "initial_vector, target_vector, convention",
        [
            ([1, 0, 0], [0, 1, 0], None),
            ([0, 1, 0], [0, 0, 1], "zyx"),
            ([1, 1, 1], [1, 0, 0], "xyz"),
        ],
    )
    def test_rotation_aligning_vectors(self, initial_vector, target_vector, convention):
        result = rotation_aligning_vectors(initial_vector, target_vector, convention)

        assert isinstance(result, np.ndarray)
        if convention is None:
            assert result.shape == (3, 3)
            assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
        else:
            assert len(result) == 3
            result = Rotation.from_euler(convention, result, degrees=True).as_matrix()
            assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)

        rotated = np.dot(Rotation.from_matrix(result).as_matrix(), initial_vector)
        assert np.allclose(
            rotated / np.linalg.norm(rotated),
            target_vector / np.linalg.norm(target_vector),
            atol=1e-6,
        )

    def test_normalize_template_overflow_safe(self):
        template = be.random.random((10, 10)).astype(be.float32)
        mask = be.ones_like(template)
        n_observations = 100.0

        result = _normalize_template_overflow_safe(template, mask, n_observations)
        assert result.shape == template.shape
        assert result.dtype == template.dtype
        assert np.allclose(result.mean(), 0, atol=0.1)
        assert np.allclose(result.std(), 1, atol=0.1)

    def test_euler_conversion(self):
        rotation_matrix_initial = np.array(
            [
                [0.35355339, 0.61237244, -0.70710678],
                [-0.8660254, 0.5, -0.0],
                [0.35355339, 0.61237244, 0.70710678],
            ]
        )
        euler_angles = euler_from_rotationmatrix(rotation_matrix_initial)
        rotation_matrix_converted = euler_to_rotationmatrix(euler_angles)
        assert np.allclose(
            rotation_matrix_initial, rotation_matrix_converted, atol=1e-6
        )