pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tests/test_matching_optimization.py ADDED
@@ -0,0 +1,226 @@
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from tme.rotations import euler_from_rotationmatrix
5
+ from tme.matching_optimization import (
6
+ MATCHING_OPTIMIZATION_REGISTER,
7
+ register_matching_optimization,
8
+ _MatchCoordinatesToDensity,
9
+ _MatchCoordinatesToCoordinates,
10
+ optimize_match,
11
+ create_score_object,
12
+ )
13
+
14
def _registered_scores(base_class):
    """Return the names of registered optimization scores deriving from ``base_class``.

    Parameters
    ----------
    base_class : type
        Base class used to filter the classes in the registry.

    Returns
    -------
    list of str
        Registry keys whose classes are subclasses of ``base_class``, in
        registry iteration order.
    """
    return [
        name
        for name, score_class in MATCHING_OPTIMIZATION_REGISTER.items()
        if issubclass(score_class, base_class)
    ]


# Density-to-density scores are listed explicitly rather than derived from the registry.
density_to_density = ["FLC"]

# Scores that match template coordinates against a target density.
coordinate_to_density = _registered_scores(_MatchCoordinatesToDensity)

# Scores that match template coordinates against target coordinates.
coordinate_to_coordinate = _registered_scores(_MatchCoordinatesToCoordinates)
27
+
28
+
29
class TestMatchDensityToDensity:
    """Smoke tests for density-to-density optimization scores."""

    def setup_method(self):
        # One rectangular block of ones inside an otherwise empty 50^3 volume.
        volume = np.zeros((50, 50, 50))
        volume[20:30, 30:40, 12:17] = 1

        self.target = volume
        self.template = volume.copy()
        self.template_mask = np.ones_like(volume)

    def teardown_method(self):
        for attribute in ("target", "template", "template_mask"):
            setattr(self, attribute, None)

    @pytest.mark.parametrize("method", density_to_density)
    def test_initialization(self, method: str, notest: bool = False):
        """Build the score object for ``method``.

        When ``notest`` is True the instance is returned so that other tests
        can reuse this construction; under pytest it returns None.
        """
        score_object = create_score_object(
            score=method,
            target=self.target,
            template=self.template,
            template_mask=self.template_mask,
        )
        return score_object if notest else None

    @pytest.mark.parametrize("method", density_to_density)
    def test_call(self, method):
        score_object = self.test_initialization(method=method, notest=True)
        assert isinstance(score_object(), float)
58
+
59
+
60
class TestMatchDensityToCoordinates:
    """Smoke tests for scores matching template coordinates to a target density."""

    def setup_method(self):
        data = np.zeros((50, 50, 50))
        data[20:30, 30:40, 12:17] = 1
        self.target = data
        self.target_mask_density = data > 0
        # Indices of the non-zero voxels, stacked to shape (ndim, n_points),
        # weighted by their voxel values.
        self.coordinates = np.array(np.where(self.target > 0))
        self.coordinates_weights = self.target[tuple(self.coordinates)]

        # A local seeded generator keeps the fixture reproducible without
        # resetting the global NumPy RNG that other tests rely on. Drawing
        # without replacement gives a mask of distinct coordinates covering
        # half of the template points.
        rng = np.random.default_rng(42)
        n_points = self.coordinates.shape[1]
        random_pixels = rng.choice(n_points, size=n_points // 2, replace=False)
        self.coordinates_mask = self.coordinates[:, random_pixels]

        self.origin = np.zeros(self.coordinates.shape[0])
        self.sampling_rate = np.ones(self.coordinates.shape[0])

    def teardown_method(self):
        self.target = None
        self.target_mask_density = None
        self.coordinates = None
        self.coordinates_weights = None
        self.coordinates_mask = None
        self.origin = None
        self.sampling_rate = None

    @pytest.mark.parametrize("method", coordinate_to_density)
    def test_initialization(self, method: str, notest: bool = False):
        """Build the score object for ``method``.

        Returns the instance when ``notest`` is True so that ``test_call`` can
        reuse it; returns None when run directly by pytest.
        """
        instance = create_score_object(
            score=method,
            target=self.target,
            target_mask=self.target_mask_density,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
            template_mask_coordinates=self.coordinates_mask,
        )
        if notest:
            return instance

    @pytest.mark.parametrize("method", coordinate_to_density)
    def test_call(self, method):
        instance = self.test_initialization(method=method, notest=True)
        score = instance()
        assert isinstance(score, float)
103
+
104
+
105
class TestMatchCoordinateToCoordinates:
    """Smoke tests for scores matching template coordinates to target coordinates."""

    def setup_method(self):
        grid = np.zeros((50, 50, 50))
        grid[20:30, 30:40, 12:17] = 1
        occupied = grid > 0

        # Target and template share the same point cloud: indices of the
        # non-zero voxels, shape (ndim, n_points), weighted by voxel value.
        self.target_coordinates = np.array(np.where(occupied))
        self.target_weights = grid[tuple(self.target_coordinates)]

        self.coordinates = np.array(np.where(occupied))
        self.coordinates_weights = grid[tuple(self.coordinates)]

        ndim = self.coordinates.shape[0]
        self.origin = np.zeros(ndim)
        self.sampling_rate = np.ones(ndim)

    def teardown_method(self):
        for attribute in (
            "target_coordinates",
            "target_weights",
            "coordinates",
            "coordinates_weights",
        ):
            setattr(self, attribute, None)

    @pytest.mark.parametrize("method", coordinate_to_coordinate)
    def test_initialization(self, method: str, notest: bool = False):
        """Build the score object for ``method``.

        The instance is returned only when ``notest`` is True, which lets
        ``test_call`` reuse this construction.
        """
        score_object = create_score_object(
            score=method,
            target_coordinates=self.target_coordinates,
            target_weights=self.target_weights,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
        )
        return score_object if notest else None

    @pytest.mark.parametrize("method", coordinate_to_coordinate)
    def test_call(self, method):
        score_object = self.test_initialization(method=method, notest=True)
        assert isinstance(score_object(), float)
141
+
142
+
143
class TestOptimizeMatch:
    """Tests for ``optimize_match`` across optimizers and bound settings."""

    def setup_method(self):
        data = np.zeros((50, 50, 50))
        data[20:30, 30:40, 12:17] = 1
        self.target = data
        self.coordinates = np.array(np.where(self.target > 0))
        self.coordinates_weights = self.target[tuple(self.coordinates)]

        self.origin = np.zeros(self.coordinates.shape[0])
        self.sampling_rate = np.ones(self.coordinates.shape[0])

        score_class = MATCHING_OPTIMIZATION_REGISTER["CrossCorrelation"]
        self.score_object = score_class(
            target=self.target,
            template_coordinates=self.coordinates,
            template_weights=self.coordinates_weights,
        )

    def teardown_method(self):
        self.target = None
        self.coordinates = None
        self.coordinates_weights = None
        self.score_object = None

    @staticmethod
    def _assert_within(values, bounds):
        """Assert each entry of ``values`` lies within its ``(lower, upper)`` pair in ``bounds``."""
        lower_bound = np.array([bound[0] for bound in bounds])
        upper_bound = np.array([bound[1] for bound in bounds])
        assert np.all(np.logical_and(values >= lower_bound, values <= upper_bound))

    @pytest.mark.parametrize(
        "method", ("differential_evolution", "basinhopping", "minimize")
    )
    @pytest.mark.parametrize("bound_translation", (True, False))
    @pytest.mark.parametrize("bound_rotation", (True, False))
    def test_call(self, method, bound_translation, bound_rotation):
        ndim = self.target.ndim
        # The boolean parameters choose whether any bounds are passed at all.
        rotation_bounds = (
            tuple((-90, 90) for _ in range(ndim)) if bound_rotation else None
        )
        translation_bounds = (
            tuple((-5, 5) for _ in range(ndim)) if bound_translation else None
        )

        translation, rotation, score = optimize_match(
            score_object=self.score_object,
            optimization_method=method,
            bounds_rotation=rotation_bounds,
            bounds_translation=translation_bounds,
            maxiter=10,
        )
        assert translation.size == ndim
        assert rotation.shape[0] == ndim
        assert rotation.shape[1] == ndim
        assert isinstance(score, float)

        if translation_bounds is not None:
            self._assert_within(translation, translation_bounds)

        if rotation_bounds is not None:
            self._assert_within(euler_from_rotationmatrix(rotation), rotation_bounds)

    def test_call_error(self):
        # Call without unpacking the result: a tuple-unpacking ValueError
        # would otherwise satisfy pytest.raises even if the unknown method
        # name were silently accepted.
        with pytest.raises(ValueError):
            optimize_match(
                score_object=self.score_object,
                optimization_method="RAISERROR",
                maxiter=10,
            )
214
+
215
+
216
class TestUtils:
    """Tests for registering custom matching-optimization scores."""

    def test_register_matching_optimization(self):
        existing_name = next(iter(MATCHING_OPTIMIZATION_REGISTER))
        score_class = MATCHING_OPTIMIZATION_REGISTER[existing_name]
        # The registry is module-level shared state; snapshot it so this
        # test does not leak "new_score" into other tests.
        snapshot = dict(MATCHING_OPTIMIZATION_REGISTER)
        try:
            register_matching_optimization(
                match_name="new_score",
                match_class=score_class,
            )
            assert MATCHING_OPTIMIZATION_REGISTER["new_score"] is score_class
        finally:
            MATCHING_OPTIMIZATION_REGISTER.clear()
            MATCHING_OPTIMIZATION_REGISTER.update(snapshot)

    def test_register_matching_optimization_error(self):
        with pytest.raises(ValueError):
            register_matching_optimization(match_name="new_score", match_class=None)
tests/test_matching_utils.py ADDED
@@ -0,0 +1,189 @@
1
+ from tempfile import mkstemp
2
+ from importlib_resources import files
3
+
4
+ import pytest
5
+ import numpy as np
6
+ from scipy.signal import correlate
7
+
8
+ from tme import Density
9
+ from tme.backends import backend as be
10
+ from tme.memory import MATCHING_MEMORY_REGISTRY
11
+ from tme.matching_utils import (
12
+ compute_parallelization_schedule,
13
+ elliptical_mask,
14
+ box_mask,
15
+ tube_mask,
16
+ create_mask,
17
+ scramble_phases,
18
+ apply_convolution_mode,
19
+ write_pickle,
20
+ load_pickle,
21
+ _normalize_template_overflow_safe,
22
+ )
23
+
24
# Resource root of the bundled test fixtures (e.g. ``Raw/em_map.map`` and
# ``Structures/5khe.cif``), resolved through importlib_resources.
BASEPATH = files("tests.data")
25
+
26
+
27
class TestMatchingUtils:
    """Tests for helpers in ``tme.matching_utils``."""

    def setup_method(self):
        self.density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
        # Density built from the 5khe structure on the same grid (origin,
        # shape, sampling rate) as the map above.
        self.structure_density = Density.from_structure(
            filename_or_structure=str(BASEPATH.joinpath("Structures/5khe.cif")),
            origin=self.density.origin,
            shape=self.density.shape,
            sampling_rate=self.density.sampling_rate,
        )

    @pytest.mark.parametrize("matching_method", list(MATCHING_MEMORY_REGISTRY.keys()))
    @pytest.mark.parametrize("max_cores", range(1, 10, 3))
    @pytest.mark.parametrize("max_ram", [1e5, 1e7, 1e9])
    def test_compute_parallelization_schedule(
        self, matching_method, max_cores, max_ram
    ):
        max_cores, max_ram = int(max_cores), int(max_ram)
        compute_parallelization_schedule(
            self.density.shape,
            self.structure_density.shape,
            matching_method=matching_method,
            max_cores=max_cores,
            max_ram=max_ram,
            max_splits=256,
        )

    def test_create_mask(self):
        create_mask(
            mask_type="ellipse",
            shape=self.density.shape,
            radius=5,
            center=np.divide(self.density.shape, 2),
        )

    def test_create_mask_error(self):
        with pytest.raises(ValueError):
            create_mask(mask_type=None)

    def test_elliptical_mask(self):
        elliptical_mask(
            shape=self.density.shape,
            radius=5,
            center=np.divide(self.density.shape, 2),
        )

    def test_box_mask(self):
        box_mask(
            shape=self.density.shape,
            height=[5, 10, 20],
            center=np.divide(self.density.shape, 2),
        )

    def test_tube_mask(self):
        tube_mask(
            shape=self.density.shape,
            outer_radius=10,
            inner_radius=5,
            height=5,
            base_center=np.divide(self.density.shape, 2),
            symmetry_axis=1,
        )

    def test_tube_mask_error(self):
        # NOTE(review): all three cases pass inner_radius > outer_radius, so
        # the second and third presumably raise for that reason alone and do
        # not isolate the oversized-height / invalid-symmetry_axis checks —
        # confirm whether they were meant to use a valid radius pair.
        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=5,
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=1,
            )

        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=10 * np.max(self.density.shape),
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=1,
            )

        with pytest.raises(ValueError):
            tube_mask(
                shape=self.density.shape,
                outer_radius=5,
                inner_radius=10,
                height=10 * np.max(self.density.shape),
                base_center=np.divide(self.density.shape, 2),
                symmetry_axis=len(self.density.shape) + 1,
            )

    def test_scramble_phases(self):
        scramble_phases(arr=self.density.data, noise_proportion=0.5)

    @pytest.mark.parametrize("convolution_mode", ["full", "valid", "same"])
    def test_apply_convolution_mode(self, convolution_mode):
        correlation = correlate(
            self.density.data, self.structure_density.data, method="direct", mode="full"
        )
        ret = apply_convolution_mode(
            arr=correlation,
            convolution_mode=convolution_mode,
            s1=self.density.shape,
            s2=self.structure_density.shape,
        )
        if convolution_mode == "full":
            expected_size = correlation.shape
        elif convolution_mode == "same":
            expected_size = self.density.shape
        else:
            expected_size = np.subtract(
                self.density.shape, self.structure_density.shape
            )
            expected_size += np.mod(self.structure_density.shape, 2)
        assert np.allclose(ret.shape, expected_size)

    def test_apply_convolution_mode_error(self):
        correlation = correlate(
            self.density.data, self.structure_density.data, method="direct", mode="full"
        )
        with pytest.raises(ValueError):
            _ = apply_convolution_mode(
                arr=correlation,
                convolution_mode=None,
                s1=self.density.shape,
                s2=self.structure_density.shape,
            )

    def test_pickle_io(self):
        import os

        created_files = []

        def _temp_filename():
            # mkstemp returns an open OS-level descriptor along with the path;
            # only the path is used, so close the descriptor immediately.
            handle, path = mkstemp()
            os.close(handle)
            created_files.append(path)
            return path

        data = loaded_data = None
        try:
            filename = _temp_filename()
            data = ["Hello", 123, np.array([1, 2, 3])]
            write_pickle(data=data, filename=filename)
            loaded_data = load_pickle(filename)
            # Check lengths as well: zip() silently stops at the shorter input.
            assert len(loaded_data) == len(data)
            assert all(np.array_equal(a, b) for a, b in zip(data, loaded_data))

            data = 42
            write_pickle(data=data, filename=filename)
            loaded_data = load_pickle(filename)
            assert loaded_data == data

            memmap_filename = _temp_filename()
            data = np.memmap(memmap_filename, dtype="float32", mode="w+", shape=(3,))
            data[:] = [1.1, 2.2, 3.3]
            data.flush()
            data = np.memmap(memmap_filename, dtype="float32", mode="r+", shape=(3,))
            filename = _temp_filename()
            write_pickle(data=data, filename=filename)
            loaded_data = load_pickle(filename)
            assert np.array_equal(loaded_data, data)
        finally:
            # Drop array references first (memmaps keep their backing file
            # open), then delete every temporary file this test created.
            data = loaded_data = None
            for path in created_files:
                if os.path.exists(path):
                    os.remove(path)

    def test_normalize_template_overflow_safe(self):
        template = be.random.random((10, 10)).astype(be.float32)
        mask = be.ones_like(template)
        n_observations = 100.0

        result = _normalize_template_overflow_safe(template, mask, n_observations)
        assert result.shape == template.shape
        assert result.dtype == template.dtype
        assert np.allclose(result.mean(), 0, atol=0.1)
        assert np.allclose(result.std(), 1, atol=0.1)
tests/test_orientations.py ADDED
@@ -0,0 +1,175 @@
1
+ from tempfile import mkstemp
2
+
3
+ import pytest
4
+ import numpy as np
5
+
6
+ from tme import Orientations
7
+
8
+
9
class TestDensity:
    """Tests for ``Orientations``: construction, file I/O and extraction slices.

    NOTE(review): the class name looks like a leftover from the Density tests;
    it is kept unchanged so test IDs stay stable.
    """

    def setup_method(self):
        self.translations = np.random.rand(100, 3).astype(np.float32)
        self.rotations = np.random.rand(100, 3).astype(np.float32)
        self.scores = np.random.rand(100).astype(np.float32)
        self.details = np.random.rand(100).astype(np.float32)

        self.orientations = Orientations(
            translations=self.translations,
            rotations=self.rotations,
            scores=self.scores,
            details=self.details,
        )
        # Paths handed out by _temp_file; deleted in teardown_method.
        self._temp_files = []

    def teardown_method(self):
        import os

        for path in self._temp_files:
            if os.path.exists(path):
                os.remove(path)
        self._temp_files = []

        self.translations = None
        self.rotations = None
        self.scores = None
        self.details = None
        self.orientations = None

    def _temp_file(self, suffix: str = "") -> str:
        """Return the path of a new temporary file that is removed on teardown."""
        import os

        handle, path = mkstemp(suffix=suffix)
        # Only the path is needed; close the descriptor mkstemp opened.
        os.close(handle)
        self._temp_files.append(path)
        return path

    def test_initialization(self):
        orientations = Orientations(
            translations=self.translations,
            rotations=self.rotations,
            scores=self.scores,
            details=self.details,
        )

        assert np.array_equal(self.translations, orientations.translations)
        assert np.array_equal(self.rotations, orientations.rotations)
        assert np.array_equal(self.scores, orientations.scores)
        assert np.array_equal(self.details, orientations.details)

    def test_initialization_type(self):
        orientations = Orientations(
            translations=self.translations.astype(int),
            rotations=self.rotations.astype(int),
            scores=self.scores.astype(int),
            details=self.details.astype(int),
        )
        assert np.issubdtype(orientations.translations.dtype, np.floating)
        assert np.issubdtype(orientations.rotations.dtype, np.floating)
        assert np.issubdtype(orientations.scores.dtype, np.floating)
        assert np.issubdtype(orientations.details.dtype, np.floating)

    def test_initialization_error(self):
        with pytest.raises(ValueError):
            _ = Orientations(
                translations=self.translations,
                rotations=np.random.rand(self.translations.shape[0] + 1),
                scores=self.scores,
                details=self.details,
            )

        with pytest.raises(ValueError):
            _ = Orientations(
                translations=np.random.rand(self.translations.shape[0]),
                rotations=np.random.rand(self.translations.shape[0] + 1),
                scores=self.scores,
                details=self.details,
            )
            # NOTE(review): this call is never reached because the call above
            # is expected to raise; move it into its own pytest.raises block
            # if it is meant to be checked.
            _ = Orientations(
                translations=self.translations,
                rotations=np.random.rand(self.translations.shape[0]),
                scores=self.scores,
                details=self.details,
            )

    @pytest.mark.parametrize("file_format", ("text", "relion", "tbl"))
    def test_to_file(self, file_format: str):
        # Smoke test: passes as long as writing does not raise.
        output_file = self._temp_file(suffix=f".{file_format}")
        self.orientations.to_file(output_file)

    @pytest.mark.parametrize("file_format", ("text", "star", "tbl"))
    def test_from_file(self, file_format: str):
        output_file = self._temp_file(suffix=f".{file_format}")
        self.orientations.to_file(output_file)
        orientations_new = Orientations.from_file(output_file)

        assert np.allclose(
            self.orientations.translations, orientations_new.translations
        )
        assert np.allclose(
            self.orientations.rotations, orientations_new.rotations, atol=1e-3
        )

    @pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
    @pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
    def test_file_format_io(self, input_format: str, output_format: str):
        # Smoke test: write in one format, read back, write in another format.
        output_file = self._temp_file(suffix=f".{input_format}")
        output_file2 = self._temp_file(suffix=f".{output_format}")

        self.orientations.to_file(output_file)
        orientations_new = Orientations.from_file(output_file)
        orientations_new.to_file(output_file2)

    @pytest.mark.parametrize("drop_oob", (True, False))
    @pytest.mark.parametrize("shape", (10, 40, 80))
    @pytest.mark.parametrize("odd", (True, False))
    def test_extraction(self, shape: int, drop_oob: bool, odd: bool):
        if odd:
            # Even sizes are bumped to the next odd size; odd sizes are unchanged.
            shape = shape + (1 - shape % 2)

        data = np.random.rand(50, 50, 50)
        translations = np.array([[25, 25, 25], [15, 25, 35], [35, 25, 15], [0, 15, 49]])
        orientations = Orientations(
            translations=translations,
            rotations=np.random.rand(*translations.shape),
            scores=np.random.rand(translations.shape[0]),
            details=np.random.rand(translations.shape[0]),
        )
        extraction_shape = np.repeat(np.array(shape), data.ndim)
        orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
            target_shape=data.shape,
            extraction_shape=extraction_shape,
            drop_out_of_box=drop_oob,
            return_orientations=True,
        )
        assert orientations.translations.shape[0] == len(cand_slices)
        assert len(cand_slices) == len(obs_slices)

        cand_slices2, obs_slices2 = orientations.get_extraction_slices(
            target_shape=data.shape,
            extraction_shape=extraction_shape,
            drop_out_of_box=drop_oob,
            return_orientations=False,
        )
        assert cand_slices == cand_slices2
        assert obs_slices == obs_slices2

        # Check whether extraction slices are pasted in center
        out = np.zeros(extraction_shape, dtype=data.dtype)
        center = np.divide(extraction_shape, 2).astype(int)
        for index, (cand_slice, obs_slice) in enumerate(zip(cand_slices, obs_slices)):
            out[cand_slice] = data[obs_slice]
            assert np.allclose(
                out[tuple(center)],
                data[tuple(orientations.translations[index].astype(int))],
            )

    @pytest.mark.parametrize(
        "order", (("x", "y", "z"), ("z", "y", "x"), ("y", "x", "z"))
    )
    def test_txt_sort(self, order: tuple):
        output_file = self._temp_file(suffix=".tsv")
        translations = ((50, 30, 20), (10, 5, 30))

        # Header row gives the column order; the parser is expected to
        # return translations with columns sorted alphabetically by name.
        with open(output_file, mode="w", encoding="utf-8") as ofile:
            ofile.write("\t".join([str(x) for x in order]) + "\n")
            for translation in translations:
                ofile.write("\t".join([str(x) for x in translation]) + "\n")

        translations = np.array(translations).astype(np.float32)
        orientations = Orientations.from_file(output_file)

        out_order = zip(order, range(len(order)))
        out_order = tuple(
            x[1] for x in sorted(out_order, key=lambda x: x[0], reverse=False)
        )

        assert np.array_equal(translations[..., out_order], orientations.translations)
tests/test_parser.py ADDED
@@ -0,0 +1,33 @@
1
+ import pytest
2
+ from importlib_resources import files
3
+
4
+ from tme.parser import Parser, PDBParser
5
+
6
+
7
class TestParser:
    """Tests for the base ``Parser`` and the mapping-style ``PDBParser`` API."""

    def setup_method(self):
        fixture = files("tests.data").joinpath("Structures/5khe.pdb")
        self.pdb_file = str(fixture)

    def teardown_method(self):
        self.pdb_file = None

    def _parse(self):
        """Parse the 5khe PDB fixture with ``PDBParser``."""
        return PDBParser(self.pdb_file)

    def test_initialize_parser_error(self):
        # Constructing the base Parser directly is expected to raise TypeError.
        with pytest.raises(TypeError):
            _ = Parser(self.pdb_file)

    def test_parser_get(self):
        parsed = self._parse()
        assert parsed.get("NOTPRESENT", None) is None
        assert parsed.get("record_type", None) is not None

    def test_parser_keys(self):
        parsed = self._parse()
        assert parsed.keys() == parsed._data.keys()

    def test_parser_values(self):
        parsed = self._parse()
        # Compared via str(): presumably _data is a dict, whose value views do
        # not compare by content — confirm.
        assert str(parsed.values()) == str(parsed._data.values())

    def test_parser_items(self):
        parsed = self._parse()
        assert parsed.items() == parsed._data.items()