pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
  2. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
  3. pytme-0.2.4.data/scripts/preprocess.py +148 -0
  4. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
  5. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
  6. pytme-0.2.4.dist-info/RECORD +119 -0
  7. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
  8. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
  9. pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
  10. scripts/match_template.py +97 -148
  11. scripts/postprocess.py +20 -29
  12. scripts/preprocess.py +116 -61
  13. scripts/preprocessor_gui.py +15 -23
  14. tests/__init__.py +0 -0
  15. tests/data/.DS_Store +0 -0
  16. tests/data/Blurring/.DS_Store +0 -0
  17. tests/data/Blurring/blob_width18.npy +0 -0
  18. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  19. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  20. tests/data/Blurring/hamming_width6.npy +0 -0
  21. tests/data/Blurring/kaiserb_width18.npy +0 -0
  22. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  23. tests/data/Blurring/mean_size5.npy +0 -0
  24. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  25. tests/data/Blurring/rank_rank3.npy +0 -0
  26. tests/data/Maps/.DS_Store +0 -0
  27. tests/data/Maps/emd_8621.mrc.gz +0 -0
  28. tests/data/README.md +2 -0
  29. tests/data/Raw/.DS_Store +0 -0
  30. tests/data/Raw/em_map.map +0 -0
  31. tests/data/Structures/.DS_Store +0 -0
  32. tests/data/Structures/1pdj.cif +3339 -0
  33. tests/data/Structures/1pdj.pdb +1429 -0
  34. tests/data/Structures/5khe.cif +3685 -0
  35. tests/data/Structures/5khe.ent +2210 -0
  36. tests/data/Structures/5khe.pdb +2210 -0
  37. tests/data/Structures/5uz4.cif +70548 -0
  38. tests/preprocessing/__init__.py +0 -0
  39. tests/preprocessing/test_compose.py +76 -0
  40. tests/preprocessing/test_frequency_filters.py +178 -0
  41. tests/preprocessing/test_preprocessor.py +136 -0
  42. tests/preprocessing/test_utils.py +79 -0
  43. tests/test_analyzer.py +310 -0
  44. tests/test_backends.py +375 -0
  45. tests/test_density.py +508 -0
  46. tests/test_extensions.py +130 -0
  47. tests/test_matching_cli.py +283 -0
  48. tests/test_matching_data.py +162 -0
  49. tests/test_matching_exhaustive.py +162 -0
  50. tests/test_matching_memory.py +30 -0
  51. tests/test_matching_optimization.py +276 -0
  52. tests/test_matching_utils.py +326 -0
  53. tests/test_orientations.py +173 -0
  54. tests/test_packaging.py +95 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_structure.py +243 -0
  57. tme/__init__.py +0 -1
  58. tme/__version__.py +1 -1
  59. tme/analyzer.py +9 -6
  60. tme/backends/__init__.py +1 -1
  61. tme/backends/_jax_utils.py +10 -8
  62. tme/backends/cupy_backend.py +2 -7
  63. tme/backends/jax_backend.py +35 -20
  64. tme/backends/npfftw_backend.py +3 -2
  65. tme/backends/pytorch_backend.py +10 -7
  66. tme/data/scattering_factors.pickle +0 -0
  67. tme/density.py +26 -12
  68. tme/extensions.cpython-311-darwin.so +0 -0
  69. tme/external/bindings.cpp +332 -0
  70. tme/matching_data.py +33 -24
  71. tme/matching_exhaustive.py +39 -20
  72. tme/matching_scores.py +5 -2
  73. tme/matching_utils.py +8 -2
  74. tme/orientations.py +26 -9
  75. tme/preprocessing/_utils.py +14 -14
  76. tme/preprocessing/composable_filter.py +5 -4
  77. tme/preprocessing/compose.py +4 -4
  78. tme/preprocessing/frequency_filters.py +32 -35
  79. tme/preprocessing/tilt_series.py +210 -148
  80. tme/preprocessor.py +24 -246
  81. tme/structure.py +14 -14
  82. pytme-0.2.2.dist-info/RECORD +0 -74
  83. tme/matching_memory.py +0 -383
  84. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
  85. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
  86. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
tests/test_matching_optimization.py
@@ -0,0 +1,276 @@
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from tme.matching_utils import euler_from_rotationmatrix
5
+ from tme.matching_optimization import (
6
+ MATCHING_OPTIMIZATION_REGISTER,
7
+ register_matching_optimization,
8
+ _MatchCoordinatesToDensity,
9
+ _MatchCoordinatesToCoordinates,
10
+ optimize_match,
11
+ create_score_object,
12
+ )
13
+
14
def _registered_scores(base_class: type) -> list:
    """Return the registered optimization score names that subclass ``base_class``.

    Parameters
    ----------
    base_class : type
        Base class used to filter ``MATCHING_OPTIMIZATION_REGISTER``.

    Returns
    -------
    list
        Register keys whose class is a subclass of ``base_class``, in
        register iteration order.
    """
    return [
        name
        for name, score_class in MATCHING_OPTIMIZATION_REGISTER.items()
        if issubclass(score_class, base_class)
    ]


# Density-to-density scores are listed by hand rather than derived from the
# register.
density_to_density = ["FLC"]

# Score names are collected at import time so they can drive the
# ``pytest.mark.parametrize`` decorators of the test classes.
coordinate_to_density = _registered_scores(_MatchCoordinatesToDensity)
coordinate_to_coordinate = _registered_scores(_MatchCoordinatesToCoordinates)
27
+
28
+
29
class TestMatchDensityToDensity:
    """Tests for density-to-density optimization scores."""

    def setup_method(self):
        # A single rectangular block of ones inside an otherwise empty volume.
        target = np.zeros((50, 50, 50))
        target[20:30, 30:40, 12:17] = 1
        self.target = target
        # The template is an identical copy of the target.
        self.template = target.copy()
        self.template_mask = np.ones_like(target)

    def teardown_method(self):
        self.target = None
        self.template = None
        self.template_mask = None

    def _create_instance(self, method: str):
        """Build the score object named ``method`` from this class's fixtures."""
        return create_score_object(
            score=method,
            target=self.target,
            template=self.template,
            template_mask=self.template_mask,
        )

    @pytest.mark.parametrize("method", density_to_density)
    def test_initialization(self, method: str, notest: bool = False):
        instance = self._create_instance(method)
        if notest:
            # Kept for callers that used this test as a factory; pytest itself
            # never sets ``notest``.
            return instance
        # Score objects are invoked without arguments to obtain a score.
        assert callable(instance)

    @pytest.mark.parametrize("method", density_to_density)
    def test_call(self, method):
        # Build the instance through the helper rather than by calling another
        # test, so the two tests stay independent.
        instance = self._create_instance(method)
        score = instance()
        assert isinstance(score, float)
58
+
59
+
60
class TestMatchDensityToCoordinates:
    """Tests for scores matching template coordinates against a target density."""

    def setup_method(self):
        data = np.zeros((50, 50, 50))
        data[20:30, 30:40, 12:17] = 1
        self.target = data
        self.target_mask_density = data > 0
        # Template coordinates are the (ndim, n_points) indices of the block,
        # weighted by the density value at each index.
        self.coordinates = np.array(np.where(self.target > 0))
        self.coordinates_weights = self.target[tuple(self.coordinates)]

        # Use a local generator instead of ``np.random.seed`` so the global
        # NumPy random state seen by other tests is left untouched. Indices
        # are drawn with replacement, matching ``choice``'s default.
        rng = np.random.default_rng(42)
        n_points = self.coordinates.shape[1]
        random_pixels = rng.choice(n_points, size=n_points // 2)
        self.coordinates_mask = self.coordinates[:, random_pixels]

        self.origin = np.zeros(self.coordinates.shape[0])
        self.sampling_rate = np.ones(self.coordinates.shape[0])

    def teardown_method(self):
        self.target = None
        self.target_mask_density = None
        self.coordinates = None
        self.coordinates_weights = None
        self.coordinates_mask = None
        self.origin = None
        self.sampling_rate = None

    def _create_instance(self, method: str):
        """Build the score object named ``method`` from this class's fixtures."""
        return create_score_object(
            score=method,
            target=self.target,
            target_mask=self.target_mask_density,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
            template_mask_coordinates=self.coordinates_mask,
        )

    @pytest.mark.parametrize("method", coordinate_to_density)
    def test_initialization(self, method: str, notest: bool = False):
        instance = self._create_instance(method)
        if notest:
            # Kept for callers that used this test as a factory; pytest itself
            # never sets ``notest``.
            return instance
        # Score objects are invoked without arguments to obtain a score.
        assert callable(instance)

    @pytest.mark.parametrize("method", coordinate_to_density)
    def test_call(self, method):
        instance = self._create_instance(method)
        score = instance()
        assert isinstance(score, float)

    def test_map_coordinates_to_array(self):
        ret = _MatchCoordinatesToDensity.map_coordinates_to_array(
            coordinates=self.coordinates.astype(np.float32),
            array_shape=self.target.shape,
            array_origin=np.zeros(self.target.ndim),
            sampling_rate=np.ones(self.target.ndim),
        )
        assert len(ret) == 2

        in_vol, in_vol_mask = ret

        # Without ``coordinates_mask`` no mask result is produced.
        assert in_vol_mask is None
        assert np.allclose(in_vol.shape, self.coordinates.shape[1])

    def test_map_coordinates_to_array_mask(self):
        # The mask coordinates are the template coordinates themselves, so
        # both outputs are expected to agree.
        ret = _MatchCoordinatesToDensity.map_coordinates_to_array(
            coordinates=self.coordinates.astype(np.float32),
            array_shape=self.target.shape,
            array_origin=self.origin,
            sampling_rate=self.sampling_rate,
            coordinates_mask=self.coordinates.astype(np.float32),
        )
        assert len(ret) == 2

        in_vol, in_vol_mask = ret
        assert np.allclose(in_vol, in_vol_mask)

    def test_array_from_coordinates(self):
        ret = _MatchCoordinatesToDensity.array_from_coordinates(
            coordinates=self.coordinates,
            weights=self.coordinates_weights,
            sampling_rate=self.sampling_rate,
        )
        assert len(ret) == 3
        arr, positions, origin = ret
        assert arr.ndim == self.coordinates.shape[0]
        assert positions.shape == self.coordinates.shape
        assert origin.shape == (self.coordinates.shape[0],)

        # Without an explicit origin, the minimum coordinate is used.
        assert np.allclose(origin, self.coordinates.min(axis=1))

        # An explicit origin is passed through unchanged.
        _, _, explicit_origin = _MatchCoordinatesToDensity.array_from_coordinates(
            coordinates=self.coordinates,
            weights=self.coordinates_weights,
            sampling_rate=self.sampling_rate,
            origin=self.origin,
        )
        assert np.allclose(explicit_origin, self.origin)
153
+
154
+
155
class TestMatchCoordinateToCoordinates:
    """Tests for scores matching template coordinates against target coordinates."""

    def setup_method(self):
        data = np.zeros((50, 50, 50))
        data[20:30, 30:40, 12:17] = 1
        # Target point cloud: (ndim, n_points) indices of the block and the
        # density value at each of them.
        self.target_coordinates = np.array(np.where(data > 0))
        self.target_weights = data[tuple(self.target_coordinates)]

        # The template is an independent copy of the target point cloud.
        self.coordinates = self.target_coordinates.copy()
        self.coordinates_weights = self.target_weights.copy()

        self.origin = np.zeros(self.coordinates.shape[0])
        self.sampling_rate = np.ones(self.coordinates.shape[0])

    def teardown_method(self):
        self.target_coordinates = None
        self.target_weights = None
        self.coordinates = None
        self.coordinates_weights = None
        self.origin = None
        self.sampling_rate = None

    def _create_instance(self, method: str):
        """Build the score object named ``method`` from this class's fixtures."""
        return create_score_object(
            score=method,
            target_coordinates=self.target_coordinates,
            target_weights=self.target_weights,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
        )

    @pytest.mark.parametrize("method", coordinate_to_coordinate)
    def test_initialization(self, method: str, notest: bool = False):
        instance = self._create_instance(method)
        if notest:
            # Kept for callers that used this test as a factory; pytest itself
            # never sets ``notest``.
            return instance
        # Score objects are invoked without arguments to obtain a score.
        assert callable(instance)

    @pytest.mark.parametrize("method", coordinate_to_coordinate)
    def test_call(self, method):
        instance = self._create_instance(method)
        score = instance()
        assert isinstance(score, float)
191
+
192
+
193
class TestOptimizeMatch:
    """Tests for ``optimize_match`` driven by a cross-correlation score."""

    def setup_method(self):
        volume = np.zeros((50, 50, 50))
        volume[20:30, 30:40, 12:17] = 1
        self.target = volume

        self.coordinates = np.array(np.where(volume > 0))
        self.coordinates_weights = volume[tuple(self.coordinates)]

        n_dim = self.coordinates.shape[0]
        self.origin = np.zeros(n_dim)
        self.sampling_rate = np.ones(n_dim)

        score_class = MATCHING_OPTIMIZATION_REGISTER["CrossCorrelation"]
        self.score_object = score_class(
            target=self.target,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
        )

    def teardown_method(self):
        self.target = None
        self.coordinates = None
        self.coordinates_weights = None

    @pytest.mark.parametrize(
        "method", ("differential_evolution", "basinhopping", "minimize")
    )
    @pytest.mark.parametrize("bound_translation", (True, False))
    @pytest.mark.parametrize("bound_rotation", (True, False))
    def test_call(self, method, bound_translation, bound_rotation):
        ndim = self.target.ndim

        # Turn the boolean switches into per-axis (lower, upper) bounds.
        rotation_bounds = ((-90, 90),) * ndim if bound_rotation else None
        translation_bounds = ((-5, 5),) * ndim if bound_translation else None

        translation, rotation, score = optimize_match(
            score_object=self.score_object,
            optimization_method=method,
            bounds_rotation=rotation_bounds,
            bounds_translation=translation_bounds,
            maxiter=10,
        )
        assert translation.size == ndim
        assert rotation.shape[0] == ndim
        assert rotation.shape[1] == ndim
        assert isinstance(score, float)

        # Results must respect whichever bounds were requested.
        if translation_bounds is not None:
            lower, upper = np.array(translation_bounds).T
            assert np.all((translation >= lower) & (translation <= upper))

        if rotation_bounds is not None:
            angles = euler_from_rotationmatrix(rotation)
            lower, upper = np.array(rotation_bounds).T
            assert np.all((angles >= lower) & (angles <= upper))

    def test_call_error(self):
        # An unknown optimization method is rejected.
        with pytest.raises(ValueError):
            optimize_match(
                score_object=self.score_object,
                optimization_method="RAISERROR",
                maxiter=10,
            )
264
+
265
+
266
class TestUtils:
    """Tests for registering custom optimization scores."""

    def test_register_matching_optimization(self):
        # Re-register an already registered score class under a new name.
        existing_name = next(iter(MATCHING_OPTIMIZATION_REGISTER))
        match_class = MATCHING_OPTIMIZATION_REGISTER[existing_name]

        # The register is module-level state: remember any prior entry so it
        # can be restored, rather than leaking ``new_score`` into other tests.
        missing = object()
        previous = MATCHING_OPTIMIZATION_REGISTER.get("new_score", missing)
        try:
            register_matching_optimization(
                match_name="new_score",
                match_class=match_class,
            )
            assert "new_score" in MATCHING_OPTIMIZATION_REGISTER
        finally:
            if previous is missing:
                MATCHING_OPTIMIZATION_REGISTER.pop("new_score", None)
            else:
                MATCHING_OPTIMIZATION_REGISTER["new_score"] = previous

    def test_register_matching_optimization_error(self):
        with pytest.raises(ValueError):
            register_matching_optimization(match_name="new_score", match_class=None)
tests/test_matching_utils.py
@@ -0,0 +1,326 @@
1
+ import sys
2
+ from tempfile import mkstemp
3
+ from importlib_resources import files
4
+ from itertools import combinations, chain, product
5
+
6
+ import pytest
7
+ import numpy as np
8
+ from scipy.signal import correlate
9
+ from scipy.spatial.transform import Rotation
10
+
11
+ from tme import Density
12
+ from tme.backends import backend as be
13
+ from tme.memory import MATCHING_MEMORY_REGISTRY
14
+ from tme.matching_exhaustive import _handle_traceback
15
+ from tme.matching_utils import (
16
+ compute_parallelization_schedule,
17
+ elliptical_mask,
18
+ box_mask,
19
+ tube_mask,
20
+ create_mask,
21
+ scramble_phases,
22
+ split_shape,
23
+ compute_full_convolution_index,
24
+ apply_convolution_mode,
25
+ get_rotation_matrices,
26
+ write_pickle,
27
+ load_pickle,
28
+ euler_from_rotationmatrix,
29
+ euler_to_rotationmatrix,
30
+ get_rotations_around_vector,
31
+ rotation_aligning_vectors,
32
+ _normalize_template_overflow_safe,
33
+ )
34
+
35
# Root of the packaged test fixtures; individual files are resolved with
# ``joinpath`` in the tests below.
BASEPATH = files("tests.data")


class TestMatchingUtils:
    """Tests for the helper functions in ``tme.matching_utils``."""

    def setup_method(self):
        self.density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
        # Build a density from the 5khe structure on the same grid (origin,
        # shape, sampling rate) as the map loaded above.
        self.structure_density = Density.from_structure(
            filename_or_structure=str(BASEPATH.joinpath("Structures/5khe.cif")),
            origin=self.density.origin,
            shape=self.density.shape,
            sampling_rate=self.density.sampling_rate,
        )

    @pytest.mark.parametrize("matching_method", list(MATCHING_MEMORY_REGISTRY.keys()))
    @pytest.mark.parametrize("max_cores", range(1, 10, 3))
    @pytest.mark.parametrize("max_ram", [1e5, 1e7, 1e9])
    def test_compute_parallelization_schedule(
        self, matching_method, max_cores, max_ram
    ):
        max_cores, max_ram = int(max_cores), int(max_ram)
        compute_parallelization_schedule(
            self.density.shape,
            self.structure_density.shape,
            matching_method=matching_method,
            max_cores=max_cores,
            max_ram=max_ram,
            max_splits=256,
        )

    def test_create_mask(self):
        create_mask(
            mask_type="ellipse",
            shape=self.density.shape,
            radius=5,
            center=np.divide(self.density.shape, 2),
        )

    def test_create_mask_error(self):
        with pytest.raises(ValueError):
            create_mask(mask_type=None)

    def test_elliptical_mask(self):
        elliptical_mask(
            shape=self.density.shape,
            radius=5,
            center=np.divide(self.density.shape, 2),
        )

    def test_box_mask(self):
        box_mask(
            shape=self.density.shape,
            height=[5, 10, 20],
            center=np.divide(self.density.shape, 2),
        )

    def test_tube_mask(self):
        tube_mask(
            shape=self.density.shape,
            outer_radius=10,
            inner_radius=5,
            height=5,
            base_center=np.divide(self.density.shape, 2),
            symmetry_axis=1,
        )

    def test_tube_mask_error(self):
        # NOTE(review): all three cases pass inner_radius > outer_radius, so
        # the height and symmetry_axis cases may raise for that reason rather
        # than the condition they appear to target — confirm against
        # tube_mask's validation order before separating them.
        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=5,
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=1,
            )

        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=10 * np.max(self.density.shape),
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=1,
            )

        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=10 * np.max(self.density.shape),
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=len(self.density.shape) + 1,
            )

    def test_scramble_phases(self):
        scramble_phases(arr=self.density.data, noise_proportion=0.5)

    # NOTE(review): range(1, 3, 5) yields only dim=1; confirm whether other
    # dimensionalities were meant to be covered.
    @pytest.mark.parametrize("dim", range(1, 3, 5))
    @pytest.mark.parametrize("angular_sampling", [10, 15, 20])
    def test_get_rotation_matrices(self, dim, angular_sampling):
        rotation_matrices = get_rotation_matrices(
            angular_sampling=angular_sampling, dim=dim
        )
        assert np.allclose(rotation_matrices[0] @ rotation_matrices[0].T, np.eye(dim))

    def test_split_correlation(self):
        arr1 = elliptical_mask(shape=(50, 51), center=(20, 30), radius=5)
        arr2 = elliptical_mask(shape=(41, 36), center=(25, 20), radius=5)

        def _axis_subsets(ndim):
            """Yield every subset of ``range(ndim)``, including the empty one."""
            axes = range(ndim)
            return chain.from_iterable(combinations(axes, r) for r in range(ndim + 1))

        # Each split configuration halves every axis in one subset of axes.
        outer_splits = [dict.fromkeys(axes, 2) for axes in _axis_subsets(arr1.ndim)]
        inner_splits = [dict.fromkeys(axes, 2) for axes in _axis_subsets(arr2.ndim)]

        # The reference correlation does not depend on the split, so it is
        # computed once instead of once per split configuration.
        full = correlate(arr1, arr2, method="direct", mode="full")

        for outer_split, inner_split in product(outer_splits, inner_splits):
            splits1 = split_shape(
                shape=arr1.shape, splits=outer_split, equal_shape=False
            )
            splits2 = split_shape(
                shape=arr2.shape, splits=inner_split, equal_shape=False
            )

            # Correlating every pair of sub-arrays and accumulating each result
            # at its index in the full output must reproduce ``full``.
            temp = np.zeros_like(full)
            for arr1_split, arr2_split in product(splits1, splits2):
                correlation = correlate(
                    arr1[arr1_split], arr2[arr2_split], method="direct", mode="full"
                )
                score_slice = compute_full_convolution_index(
                    arr1.shape, arr2.shape, arr1_split, arr2_split
                )
                temp[score_slice] += correlation

            assert np.allclose(temp, full)

    @pytest.mark.parametrize("convolution_mode", ["full", "valid", "same"])
    def test_apply_convolution_mode(self, convolution_mode):
        correlation = correlate(
            self.density.data, self.structure_density.data, method="direct", mode="full"
        )
        ret = apply_convolution_mode(
            arr=correlation,
            convolution_mode=convolution_mode,
            s1=self.density.shape,
            s2=self.structure_density.shape,
        )
        if convolution_mode == "full":
            expected_size = correlation.shape
        elif convolution_mode == "same":
            expected_size = self.density.shape
        else:
            expected_size = np.subtract(
                self.density.shape, self.structure_density.shape
            )
            expected_size += np.mod(self.structure_density.shape, 2)
        assert np.allclose(ret.shape, expected_size)

    def test_apply_convolution_mode_error(self):
        correlation = correlate(
            self.density.data, self.structure_density.data, method="direct", mode="full"
        )
        with pytest.raises(ValueError):
            _ = apply_convolution_mode(
                arr=correlation,
                convolution_mode=None,
                s1=self.density.shape,
                s2=self.structure_density.shape,
            )

    def test_handle_traceback(self):
        try:
            raise ValueError("Test error")
        except Exception:
            type_, value_, traceback_ = sys.exc_info()
            with pytest.raises(Exception, match="Test error"):
                _handle_traceback(type_, value_, traceback_)

    def test_pickle_io(self, tmp_path):
        # ``tmp_path`` is a per-test directory managed by pytest. Unlike the
        # bare ``mkstemp`` calls this replaces, it leaves no open file
        # descriptors or stray files behind.
        filename = str(tmp_path / "data")

        data = ["Hello", 123, np.array([1, 2, 3])]
        write_pickle(data=data, filename=filename)
        loaded_data = load_pickle(filename)
        assert all(np.array_equal(a, b) for a, b in zip(data, loaded_data))

        data = 42
        write_pickle(data=data, filename=filename)
        loaded_data = load_pickle(filename)
        assert loaded_data == data

        # A memory-mapped array, reopened in read/write mode, must round-trip.
        memmap_filename = str(tmp_path / "array_memmap")
        data = np.memmap(memmap_filename, dtype="float32", mode="w+", shape=(3,))
        data[:] = [1.1, 2.2, 3.3]
        data.flush()
        data = np.memmap(memmap_filename, dtype="float32", mode="r+", shape=(3,))

        pickle_filename = str(tmp_path / "array_pickle")
        write_pickle(data=data, filename=pickle_filename)
        loaded_data = load_pickle(pickle_filename)
        assert np.array_equal(loaded_data, data)

    @pytest.mark.parametrize(
        "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry, convention",
        [
            (30, 5, 360, None, (1, 0, 0), 1, None),
            (45, 10, 180, 15, (0, 1, 0), 2, "zyx"),
            (60, 15, 90, 30, (0, 0, 1), 4, "xyz"),
        ],
    )
    def test_get_rotations_around_vector(
        self,
        cone_angle,
        cone_sampling,
        axis_angle,
        axis_sampling,
        vector,
        n_symmetry,
        convention,
    ):
        result = get_rotations_around_vector(
            cone_angle,
            cone_sampling,
            axis_angle,
            axis_sampling,
            vector,
            n_symmetry,
            convention,
        )

        assert isinstance(result, np.ndarray)
        # Rotation matrices without a convention, Euler angles with one.
        if convention is None:
            assert result.shape[1:] == (3, 3)
        else:
            assert result.shape[1] == 3

    @pytest.mark.parametrize(
        "initial_vector, target_vector, convention",
        [
            ([1, 0, 0], [0, 1, 0], None),
            ([0, 1, 0], [0, 0, 1], "zyx"),
            ([1, 1, 1], [1, 0, 0], "xyz"),
        ],
    )
    def test_rotation_aligning_vectors(self, initial_vector, target_vector, convention):
        result = rotation_aligning_vectors(initial_vector, target_vector, convention)

        assert isinstance(result, np.ndarray)
        if convention is None:
            assert result.shape == (3, 3)
            assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
        else:
            assert len(result) == 3
            result = Rotation.from_euler(convention, result, degrees=True).as_matrix()
            assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)

        # The rotation must map the initial direction onto the target direction.
        rotated = np.dot(Rotation.from_matrix(result).as_matrix(), initial_vector)
        assert np.allclose(
            rotated / np.linalg.norm(rotated),
            target_vector / np.linalg.norm(target_vector),
            atol=1e-6,
        )

    def test_normalize_template_overflow_safe(self):
        template = be.random.random((10, 10)).astype(be.float32)
        mask = be.ones_like(template)
        n_observations = 100.0

        result = _normalize_template_overflow_safe(template, mask, n_observations)
        assert result.shape == template.shape
        assert result.dtype == template.dtype
        assert np.allclose(result.mean(), 0, atol=0.1)
        assert np.allclose(result.std(), 1, atol=0.1)

    def test_euler_conversion(self):
        rotation_matrix_initial = np.array(
            [
                [0.35355339, 0.61237244, -0.70710678],
                [-0.8660254, 0.5, -0.0],
                [0.35355339, 0.61237244, 0.70710678],
            ]
        )
        # Matrix -> Euler angles -> matrix must round-trip.
        euler_angles = euler_from_rotationmatrix(rotation_matrix_initial)
        rotation_matrix_converted = euler_to_rotationmatrix(euler_angles)
        assert np.allclose(
            rotation_matrix_initial, rotation_matrix_converted, atol=1e-6
        )