pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/match_template.py +8 -8
  2. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocess.py +22 -6
  3. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +9 -14
  4. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/METADATA +1 -1
  5. pytme-0.2.4.dist-info/RECORD +119 -0
  6. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
  7. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
  8. scripts/match_template.py +8 -8
  9. scripts/preprocess.py +22 -6
  10. scripts/preprocessor_gui.py +9 -14
  11. tests/__init__.py +0 -0
  12. tests/data/.DS_Store +0 -0
  13. tests/data/Blurring/.DS_Store +0 -0
  14. tests/data/Blurring/blob_width18.npy +0 -0
  15. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  16. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  17. tests/data/Blurring/hamming_width6.npy +0 -0
  18. tests/data/Blurring/kaiserb_width18.npy +0 -0
  19. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  20. tests/data/Blurring/mean_size5.npy +0 -0
  21. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  22. tests/data/Blurring/rank_rank3.npy +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Maps/emd_8621.mrc.gz +0 -0
  25. tests/data/README.md +2 -0
  26. tests/data/Raw/.DS_Store +0 -0
  27. tests/data/Raw/em_map.map +0 -0
  28. tests/data/Structures/.DS_Store +0 -0
  29. tests/data/Structures/1pdj.cif +3339 -0
  30. tests/data/Structures/1pdj.pdb +1429 -0
  31. tests/data/Structures/5khe.cif +3685 -0
  32. tests/data/Structures/5khe.ent +2210 -0
  33. tests/data/Structures/5khe.pdb +2210 -0
  34. tests/data/Structures/5uz4.cif +70548 -0
  35. tests/preprocessing/__init__.py +0 -0
  36. tests/preprocessing/test_compose.py +76 -0
  37. tests/preprocessing/test_frequency_filters.py +178 -0
  38. tests/preprocessing/test_preprocessor.py +136 -0
  39. tests/preprocessing/test_utils.py +79 -0
  40. tests/test_analyzer.py +310 -0
  41. tests/test_backends.py +375 -0
  42. tests/test_density.py +508 -0
  43. tests/test_extensions.py +130 -0
  44. tests/test_matching_cli.py +283 -0
  45. tests/test_matching_data.py +162 -0
  46. tests/test_matching_exhaustive.py +162 -0
  47. tests/test_matching_memory.py +30 -0
  48. tests/test_matching_optimization.py +276 -0
  49. tests/test_matching_utils.py +326 -0
  50. tests/test_orientations.py +173 -0
  51. tests/test_packaging.py +95 -0
  52. tests/test_parser.py +33 -0
  53. tests/test_structure.py +243 -0
  54. tme/__init__.py +0 -1
  55. tme/__version__.py +1 -1
  56. tme/backends/jax_backend.py +8 -7
  57. tme/data/scattering_factors.pickle +0 -0
  58. tme/density.py +11 -4
  59. tme/external/bindings.cpp +332 -0
  60. tme/matching_data.py +11 -9
  61. tme/matching_exhaustive.py +10 -8
  62. tme/matching_utils.py +1 -0
  63. tme/preprocessing/_utils.py +14 -14
  64. tme/preprocessing/composable_filter.py +0 -2
  65. tme/preprocessing/compose.py +4 -4
  66. tme/preprocessing/frequency_filters.py +32 -35
  67. tme/preprocessing/tilt_series.py +202 -118
  68. tme/preprocessor.py +24 -246
  69. tme/structure.py +14 -14
  70. pytme-0.2.3.dist-info/RECORD +0 -75
  71. tme/matching_memory.py +0 -383
  72. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
  73. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/postprocess.py +0 -0
  74. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
  75. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,283 @@
1
+ import tempfile
2
+ import subprocess
3
+ from shutil import rmtree
4
+ from os.path import exists
5
+ from os import remove, makedirs
6
+
7
+ import pytest
8
+ import numpy as np
9
+ from tme import Density
10
+ from tme.backends import backend as be
11
+
12
# Backend class names that are probed for availability at import time.
BACKEND_CLASSES = ["NumpyFFTWBackend", "PytorchBackend", "CupyBackend", "MLXBackend"]
# CPU instances of every backend class that could be imported.
BACKENDS_TO_TEST = []

# GPU settings to exercise. ``True`` is only added when the CuPy backend could
# be imported and reports at least one device.
test_gpu = (False,)
for backend_class in BACKEND_CLASSES:
    try:
        BackendClass = getattr(
            __import__("tme.backends", fromlist=[backend_class]), backend_class
        )
        BACKENDS_TO_TEST.append(BackendClass(device="cpu"))
        if (
            backend_class == "CupyBackend"
            and BACKENDS_TO_TEST[-1].device_count() >= 1
        ):
            test_gpu = (False, True)
    except ImportError:
        print(f"Couldn't import {backend_class}. Skipping...")


# Backend names used to parametrize the CLI tests; "mlx" is filtered out.
# Stored as a tuple instead of a generator so the collection is not exhausted
# after its first iteration (e.g. when used by more than one decorator).
available_backends = tuple(x for x in be.available_backends() if x != "mlx")
30
+
31
+
32
def argdict_to_command(input_args, executable: str):
    """Render an option mapping as a shell command line for ``executable``.

    Parameters
    ----------
    input_args : dict
        Maps option flags to their values. ``None`` values are skipped,
        ``True`` emits the bare flag, ``False`` emits nothing and every other
        value is emitted as ``flag value``.
    executable : str
        Command placed verbatim at the start of the command line.

    Returns
    -------
    str
        The command line. Flags and values are shell-quoted so that values
        containing whitespace or shell metacharacters stay a single argument
        when the string is executed with ``shell=True``.
    """
    # Local import keeps this helper self-contained.
    import shlex

    tokens = []
    for key, value in input_args.items():
        if value is None:
            continue
        if isinstance(value, bool):
            # Boolean options are switches: present when True, absent when False.
            if value:
                tokens.append(key)
            continue
        tokens.extend((key, value))

    # The executable is not quoted so that callers may keep passing a
    # multi-word command such as "python script.py".
    return " ".join([executable, *(shlex.quote(str(token)) for token in tokens)])
46
+
47
+
48
class TestMatchTemplate:
    """Run the ``match_template.py`` script on small random volumes."""

    @classmethod
    def setup_class(cls):
        """Write a random target, template and binary masks to ``.mrc`` files."""
        np.random.seed(42)
        target = np.random.rand(20, 20, 20)
        template = np.random.rand(5, 5, 5)

        target_mask = 1.0 * (target > 0.5)
        template_mask = 1.0 * (template > 0.5)

        cls.target_path = cls._reserve_tempfile(".mrc")
        cls.template_path = cls._reserve_tempfile(".mrc")
        cls.target_mask_path = cls._reserve_tempfile(".mrc")
        cls.template_mask_path = cls._reserve_tempfile(".mrc")

        Density(target, sampling_rate=5).to_file(cls.target_path)
        Density(template, sampling_rate=5).to_file(cls.template_path)
        Density(target_mask, sampling_rate=5).to_file(cls.target_mask_path)
        Density(template_mask, sampling_rate=5).to_file(cls.template_mask_path)

    @classmethod
    def teardown_class(cls):
        """Remove the input files created by :py:meth:`setup_class`."""
        cls.try_delete(cls.target_path)
        cls.try_delete(cls.template_path)
        cls.try_delete(cls.target_mask_path)
        cls.try_delete(cls.template_mask_path)

    @staticmethod
    def _reserve_tempfile(suffix: str) -> str:
        """Create an empty persistent temporary file and return its path.

        The handle is closed right away so no open file descriptor outlives
        the call; the file itself is kept (``delete=False``).
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
            return handle.name

    @staticmethod
    def try_delete(file_path: str):
        """Best-effort removal of ``file_path``, be it a file or a directory."""
        # Cleanup must never fail a test run, hence errors are ignored.
        try:
            remove(file_path)
        except Exception:
            pass
        try:
            rmtree(file_path, ignore_errors=True)
        except Exception:
            pass

    @staticmethod
    def run_matching(
        use_template_mask: bool,
        test_filter: bool,
        call_peaks: bool,
        target_path: str,
        template_path: str,
        template_mask_path: str,
        target_mask_path: str,
        use_target_mask: bool = False,
        backend: str = "numpyfftw",
    ):
        """Invoke ``match_template.py`` in a subprocess and return the output path.

        Parameters
        ----------
        use_template_mask : bool
            Pass ``template_mask_path`` via ``--template_mask``.
        test_filter : bool
            Add the lowpass, defocus, tilt angle, wedge axes and whitening
            options to the command.
        call_peaks : bool
            Add the ``-p`` switch.
        target_path, template_path : str
            Files passed via ``-m`` and ``-i`` respectively.
        template_mask_path, target_mask_path : str
            Mask files, only used when the corresponding flag is set.
        use_target_mask : bool, optional
            Pass ``target_mask_path`` via ``--target_mask``.
        backend : str, optional
            Value of the ``--backend`` option.

        Returns
        -------
        str
            Path of the file given to the script via ``-o``.

        Raises
        ------
        AssertionError
            If the script exits with a non-zero return code.
        """
        # Close the handle immediately; only the reserved path is needed.
        with tempfile.NamedTemporaryFile(delete=False, suffix="pickle") as handle:
            output_path = handle.name

        argdict = {
            "-m": target_path,
            "-i": template_path,
            "-n": 1,
            "-a": 60,
            "-o": output_path,
            "--pad_edges": False,
            "--pad_fourier": False,
            "--backend": backend,
        }

        if use_template_mask:
            argdict["--template_mask"] = template_mask_path

        if use_target_mask:
            argdict["--target_mask"] = target_mask_path

        if backend in ("cupy", "pytorch") and True in test_gpu:
            argdict["--use_gpu"] = True

        if test_filter:
            argdict["--lowpass"] = 30
            argdict["--defocus"] = 3000
            argdict["--tilt_angles"] = "40,40:10"
            argdict["--wedge_axes"] = "0,2"
            argdict["--whiten"] = True

        if call_peaks:
            argdict["-p"] = True

        cmd = argdict_to_command(argdict, executable="match_template.py")
        ret = subprocess.run(cmd, capture_output=True, shell=True)
        # Printed so the captured stdout/stderr show up in the failure report.
        print(ret)
        assert ret.returncode == 0
        return output_path

    @pytest.mark.parametrize("backend", available_backends)
    @pytest.mark.parametrize("call_peaks", (False, True))
    @pytest.mark.parametrize("use_template_mask", (False, True))
    @pytest.mark.parametrize("test_filter", (False, True))
    def test_match_template(
        self,
        backend: str,
        call_peaks: bool,
        use_template_mask: bool,
        test_filter: bool,
    ):
        """The script succeeds for every backend, mask and filter combination."""
        self.run_matching(
            use_template_mask=use_template_mask,
            use_target_mask=True,
            backend=backend,
            test_filter=test_filter,
            call_peaks=call_peaks,
            template_path=self.template_path,
            target_path=self.target_path,
            template_mask_path=self.template_mask_path,
            target_mask_path=self.target_mask_path,
        )
163
+
164
+
165
+ class TestPostprocessing(TestMatchTemplate):
166
+ @classmethod
167
+ def setup_class(cls):
168
+ super().setup_class()
169
+
170
+ matching_kwargs = {
171
+ "use_template_mask": False,
172
+ "use_target_mask": False,
173
+ "test_filter": False,
174
+ "template_path": cls.template_path,
175
+ "target_path": cls.target_path,
176
+ "template_mask_path": cls.template_mask_path,
177
+ "target_mask_path": cls.target_mask_path,
178
+ }
179
+
180
+ cls.score_pickle = cls.run_matching(
181
+ call_peaks=False,
182
+ **matching_kwargs,
183
+ )
184
+ cls.peak_pickle = cls.run_matching(call_peaks=True, **matching_kwargs)
185
+ cls.tempdir = tempfile.TemporaryDirectory().name
186
+
187
+ @classmethod
188
+ def teardown_class(cls):
189
+ cls.try_delete(cls.score_pickle)
190
+ cls.try_delete(cls.peak_pickle)
191
+ cls.try_delete(cls.tempdir)
192
+
193
+ @pytest.mark.parametrize("distance_cutoff_strategy", (0, 1, 2, 3))
194
+ @pytest.mark.parametrize("score_cutoff", (None, (1,), (0, 1), (None, 1), (0, None)))
195
+ @pytest.mark.parametrize("peak_oversampling", (False, 4))
196
+ def test_postprocess_score_orientations(
197
+ self,
198
+ peak_oversampling,
199
+ score_cutoff,
200
+ distance_cutoff_strategy,
201
+ ):
202
+ self.try_delete(self.tempdir)
203
+ makedirs(self.tempdir, exist_ok=True)
204
+
205
+ argdict = {
206
+ "--input_file": self.score_pickle,
207
+ "--output_format": "orientations",
208
+ "--output_prefix": f"{self.tempdir}/temp",
209
+ "--peak_oversampling": peak_oversampling,
210
+ "--number_of_peaks": 3,
211
+ }
212
+
213
+ if score_cutoff is not None:
214
+ if len(score_cutoff) == 1:
215
+ argdict["--n_false_positives"] = 1
216
+ else:
217
+ min_score, max_score = score_cutoff
218
+ argdict["--minimum_score"] = min_score
219
+ argdict["--maximum_score"] = max_score
220
+
221
+ match distance_cutoff_strategy:
222
+ case 1:
223
+ argdict["--mask_edges"] = True
224
+ case 2:
225
+ argdict["--min_distance"] = 5
226
+ case 3:
227
+ argdict["--min_boundary_distance"] = 5
228
+
229
+ cmd = argdict_to_command(argdict, executable="postprocess.py")
230
+ ret = subprocess.run(cmd, capture_output=True, shell=True)
231
+ assert ret.returncode == 0
232
+
233
+ @pytest.mark.parametrize("input_format", ("score", "peaks"))
234
+ @pytest.mark.parametrize(
235
+ "output_format",
236
+ ("orientations", "alignment", "backmapping", "average", "relion"),
237
+ )
238
+ def test_postproces_score_formats(self, input_format, output_format):
239
+ self.try_delete(self.tempdir)
240
+ makedirs(self.tempdir, exist_ok=True)
241
+
242
+ input_file = self.score_pickle
243
+ if input_format == "peaks":
244
+ input_file = self.peak_pickle
245
+
246
+ argdict = {
247
+ "--input_file": input_file,
248
+ "--output_format": output_format,
249
+ "--output_prefix": f"{self.tempdir}/temp",
250
+ "--number_of_peaks": 3,
251
+ "--peak_caller": "PeakCallerMaximumFilter",
252
+ }
253
+ cmd = argdict_to_command(argdict, executable="postprocess.py")
254
+ ret = subprocess.run(cmd, capture_output=True, shell=True)
255
+
256
+ match output_format:
257
+ case "orientations":
258
+ assert exists(f"{self.tempdir}/temp.tsv")
259
+ case "alignment":
260
+ assert exists(f"{self.tempdir}/temp_0.mrc")
261
+ case "backmapping":
262
+ assert exists(f"{self.tempdir}/temp_backmapped.mrc")
263
+ case "average":
264
+ assert exists(f"{self.tempdir}/temp_average.mrc")
265
+ case "relion":
266
+ assert exists(f"{self.tempdir}/temp.star")
267
+
268
+ assert ret.returncode == 0
269
+
270
+ def test_postprocess_score_local_optimization(self):
271
+ self.try_delete(self.tempdir)
272
+ makedirs(self.tempdir, exist_ok=True)
273
+
274
+ argdict = {
275
+ "--input_file": self.score_pickle,
276
+ "--output_format": "orientations",
277
+ "--output_prefix": f"{self.tempdir}/temp",
278
+ "--number_of_peaks": 1,
279
+ "--local_optimization": True,
280
+ }
281
+ cmd = argdict_to_command(argdict, executable="postprocess.py")
282
+ ret = subprocess.run(cmd, capture_output=True, shell=True)
283
+ assert ret.returncode == 0
@@ -0,0 +1,162 @@
1
+ from tempfile import mkstemp
2
+
3
+ import pytest
4
+ import numpy as np
5
+
6
+ from tme.backends import backend as be
7
+ from tme.matching_data import MatchingData
8
+
9
+
10
+ class TestDensity:
11
+ def setup_method(self):
12
+ target = np.zeros((50, 50, 50))
13
+ target[20:30, 30:40, 12:17] = 1
14
+
15
+ self.target = target
16
+ template = np.zeros((50, 50, 50))
17
+ template[15:25, 20:30, 2:7] = 1
18
+ self.template = template
19
+ self.rotations = np.random.rand(100, target.ndim, target.ndim).astype(
20
+ np.float32
21
+ )
22
+
23
+ def teardown_method(self):
24
+ self.target = None
25
+ self.template = None
26
+ self.coordinates = None
27
+ self.coordinates_weights = None
28
+ self.rotations = None
29
+
30
+ def test_initialization(self):
31
+ _ = MatchingData(target=self.target, template=self.template)
32
+
33
+ @pytest.mark.parametrize("shape", [(10,), (10, 15), (10, 20, 30)])
34
+ def test__shape_to_slice(self, shape):
35
+ slices = MatchingData._shape_to_slice(shape=shape)
36
+ assert len(slices) == len(shape)
37
+ for i, k in enumerate(shape):
38
+ assert slices[i].start == 0
39
+ assert slices[i].stop == k
40
+
41
+ @pytest.mark.parametrize("shape", [(10,), (10, 15), (10, 20, 30)])
42
+ def test_slice_to_mesh(self, shape):
43
+ if shape is not None:
44
+ slices = MatchingData._shape_to_slice(shape=shape)
45
+
46
+ indices = MatchingData._slice_to_mesh(slice_variable=slices, shape=shape)
47
+ assert len(indices) == len(shape)
48
+ for i, k in enumerate(shape):
49
+ assert indices[i].min() == 0
50
+ assert indices[i].max() == k - 1
51
+
52
+ indices = MatchingData._slice_to_mesh(slice_variable=None, shape=shape)
53
+ assert len(indices) == len(shape)
54
+ for i, k in enumerate(shape):
55
+ assert indices[i].min() == 0
56
+ assert indices[i].max() == k - 1
57
+
58
+ def test__load_array(self):
59
+ arr = MatchingData._load_array(self.target)
60
+ assert np.allclose(arr, self.target)
61
+
62
+ def test__load_array_memmap(self):
63
+ _, filename = mkstemp()
64
+ shape, dtype = self.target.shape, self.target.dtype
65
+ arr_memmap = np.memmap(filename, mode="w+", dtype=dtype, shape=shape)
66
+ arr_memmap[:] = self.target[:]
67
+ arr_memmap.flush()
68
+
69
+ arr = MatchingData._load_array(arr_memmap)
70
+ assert np.allclose(arr, self.target)
71
+
72
+ def test_subset_array(self):
73
+ matching_data = MatchingData(target=self.target, template=self.template)
74
+ slices = MatchingData._shape_to_slice(
75
+ shape=np.divide(self.target.shape, 2).astype(int)
76
+ )
77
+ ret = matching_data.subset_array(
78
+ arr=self.target, arr_slice=slices, padding=(2, 2, 2)
79
+ )
80
+ assert np.allclose(
81
+ ret.shape, np.add(np.divide(self.target.shape, 2).astype(int), 2)
82
+ )
83
+
84
+ def test_subset_by_slice_none(self):
85
+ matching_data = MatchingData(target=self.target, template=self.template)
86
+ matching_data.rotations = self.rotations
87
+ matching_data.target_mask = self.target
88
+ matching_data.template_mask = self.template
89
+
90
+ ret = matching_data.subset_by_slice()
91
+
92
+ assert type(ret) == type(matching_data)
93
+ assert np.allclose(ret.target, matching_data.target)
94
+ assert np.allclose(ret.template, matching_data.template)
95
+ assert np.allclose(ret.target_mask, matching_data.target_mask)
96
+ assert np.allclose(ret.template_mask, matching_data.template_mask)
97
+
98
+ def test_subset_by_slice(self):
99
+ matching_data = MatchingData(target=self.target, template=self.template)
100
+ matching_data.rotations = self.rotations
101
+ matching_data.target_mask = self.target
102
+ matching_data.template_mask = self.template
103
+
104
+ target_slice = MatchingData._shape_to_slice(
105
+ shape=np.divide(self.target.shape, 2).astype(int)
106
+ )
107
+ template_slice = MatchingData._shape_to_slice(
108
+ shape=np.divide(self.template.shape, 2).astype(int)
109
+ )
110
+ ret = matching_data.subset_by_slice(
111
+ target_slice=target_slice, template_slice=template_slice
112
+ )
113
+ assert type(ret) == type(matching_data)
114
+
115
+ assert np.allclose(
116
+ ret.target.shape, np.divide(self.target.shape, 2).astype(int)
117
+ )
118
+ assert np.allclose(
119
+ ret.template.shape, np.divide(self.target.shape, 2).astype(int)[::-1]
120
+ )
121
+
122
+ def test_rotations(self):
123
+ matching_data = MatchingData(target=self.target, template=self.template)
124
+ matching_data.rotations = self.rotations
125
+ matching_data.target_mask = self.target
126
+
127
+ assert np.allclose(matching_data.rotations, self.rotations)
128
+
129
+ matching_data.rotations = np.random.rand(self.target.ndim, self.target.ndim)
130
+ assert np.allclose(
131
+ matching_data.rotations.shape, (1, self.target.ndim, self.target.ndim)
132
+ )
133
+
134
+ def test_target(self):
135
+ matching_data = MatchingData(target=self.target, template=self.template)
136
+
137
+ assert np.allclose(matching_data.target, self.target)
138
+
139
+ def test_template(self):
140
+ matching_data = MatchingData(target=self.target, template=self.template)
141
+
142
+ assert np.allclose(matching_data.template, be.reverse(self.template))
143
+
144
+ def test_target_mask(self):
145
+ matching_data = MatchingData(target=self.target, template=self.template)
146
+ matching_data.target_mask = self.target
147
+
148
+ assert np.allclose(matching_data.target_mask, self.target)
149
+
150
+ def test_template_mask(self):
151
+ matching_data = MatchingData(target=self.target, template=self.template)
152
+ matching_data.template_mask = self.template
153
+
154
+ assert np.allclose(matching_data.template_mask, be.reverse(self.template))
155
+
156
+ @pytest.mark.parametrize("jobs", range(1, 50, 5))
157
+ def test__split_rotations_on_jobs(self, jobs):
158
+ matching_data = MatchingData(target=self.target, template=self.template)
159
+ matching_data.rotations = self.rotations
160
+
161
+ ret = matching_data._split_rotations_on_jobs(n_jobs=jobs)
162
+ assert len(ret) == jobs
@@ -0,0 +1,162 @@
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from scipy.ndimage import laplace
5
+
6
+ from tme.matching_data import MatchingData
7
+ from tme.memory import MATCHING_MEMORY_REGISTRY
8
+ from tme.matching_utils import get_rotation_matrices
9
+ from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
10
+ from tme.matching_exhaustive import (
11
+ scan,
12
+ scan_subsets,
13
+ MATCHING_EXHAUSTIVE_REGISTER,
14
+ register_matching_exhaustive,
15
+ )
16
+
17
+
18
+ class TestMatchExhaustive:
19
+ def setup_method(self):
20
+ # To be valid for splitting, the template needs to be fully inside the object
21
+ target = np.zeros((80, 80, 80))
22
+ target[25:31, 22:28, 12:16] = 1
23
+
24
+ self.target = target
25
+ self.template = np.zeros((41, 41, 35))
26
+ self.template[20:26, 25:31, 17:21] = 1
27
+ self.template_mask = np.ones_like(self.template)
28
+ self.target_mask = np.ones_like(target)
29
+
30
+ self.rotations = get_rotation_matrices(60)[0,]
31
+ self.peak_position = np.array([25, 17, 12])
32
+
33
+ def teardown_method(self):
34
+ self.target = None
35
+ self.template = None
36
+ self.coordinates = None
37
+ self.coordinates_weights = None
38
+ self.rotations = None
39
+
40
+ @pytest.mark.parametrize("score", list(MATCHING_EXHAUSTIVE_REGISTER.keys()))
41
+ @pytest.mark.parametrize("n_jobs", (1, 2))
42
+ def test_scan(self, score: str, n_jobs: int):
43
+ matching_data = MatchingData(
44
+ target=self.target,
45
+ template=self.template,
46
+ target_mask=self.target_mask,
47
+ template_mask=self.template_mask,
48
+ rotations=self.rotations,
49
+ )
50
+ setup, process = MATCHING_EXHAUSTIVE_REGISTER[score]
51
+ ret = scan(
52
+ matching_data=matching_data,
53
+ matching_setup=setup,
54
+ matching_score=process,
55
+ n_jobs=n_jobs,
56
+ callback_class=MaxScoreOverRotations,
57
+ )
58
+ scores = ret[0]
59
+ peak = np.unravel_index(np.argmax(scores), scores.shape)
60
+
61
+ theoretical_score = 1
62
+ if score == "CC":
63
+ theoretical_score = self.template.sum()
64
+ elif score == "LCC":
65
+ theoretical_score = (laplace(self.template) * laplace(self.template)).sum()
66
+
67
+ assert np.allclose(peak, self.peak_position)
68
+ assert np.allclose(scores[peak], theoretical_score, rtol=0.05)
69
+
70
+ @pytest.mark.parametrize("evaluate_peak", (False, True))
71
+ @pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
72
+ @pytest.mark.parametrize("job_schedule", ((2, 1), (1, 1)))
73
+ @pytest.mark.parametrize("pad_fourier", (True, False))
74
+ @pytest.mark.parametrize("pad_edge", (True, False))
75
+ def test_scan_subset(
76
+ self,
77
+ score: str,
78
+ job_schedule: int,
79
+ evaluate_peak: bool,
80
+ pad_fourier: bool,
81
+ pad_edge: bool,
82
+ ):
83
+ matching_data = MatchingData(
84
+ target=self.target,
85
+ template=self.template,
86
+ target_mask=self.target_mask,
87
+ template_mask=self.template_mask,
88
+ rotations=self.rotations,
89
+ )
90
+
91
+ setup, process = MATCHING_EXHAUSTIVE_REGISTER[score]
92
+
93
+ target_splits = {}
94
+ if job_schedule[0] == 2:
95
+ target_splits = {0: 2 if i != 0 else 2 for i in range(self.target.ndim)}
96
+
97
+ callback_class = PeakCallerSort
98
+ if evaluate_peak:
99
+ callback_class = MaxScoreOverRotations
100
+
101
+ ret = scan_subsets(
102
+ matching_data=matching_data,
103
+ matching_setup=setup,
104
+ matching_score=process,
105
+ target_splits=target_splits,
106
+ job_schedule=job_schedule,
107
+ callback_class=callback_class,
108
+ pad_target_edges=pad_edge,
109
+ pad_fourier=pad_fourier,
110
+ )
111
+
112
+ if evaluate_peak:
113
+ scores = ret[0]
114
+ peak = np.unravel_index(np.argmax(scores), scores.shape)
115
+ achieved_score = scores[tuple(peak)]
116
+ else:
117
+ try:
118
+ peak, achieved_score = ret[0][0], ret[2][0]
119
+ except Exception as e:
120
+ return None
121
+
122
+ if not pad_edge:
123
+ # To be valid, the match needs to be fully within the target subset
124
+ return None
125
+
126
+ theoretical_score = 1
127
+ if score == "CC":
128
+ theoretical_score = self.template.sum()
129
+ elif score == "LCC":
130
+ theoretical_score = (laplace(self.template) * laplace(self.template)).sum()
131
+
132
+ if not np.allclose(peak, self.peak_position):
133
+ print(peak)
134
+ assert False
135
+
136
+ assert np.allclose(achieved_score, theoretical_score, rtol=0.3)
137
+
138
+ def test_register_matching_exhaustive(self):
139
+ setup, matching = MATCHING_EXHAUSTIVE_REGISTER[
140
+ list(MATCHING_EXHAUSTIVE_REGISTER.keys())[0]
141
+ ]
142
+ memory_class = MATCHING_MEMORY_REGISTRY[
143
+ list(MATCHING_EXHAUSTIVE_REGISTER.keys())[0]
144
+ ]
145
+ register_matching_exhaustive(
146
+ matching="TEST",
147
+ matching_setup=setup,
148
+ matching_scoring=matching,
149
+ memory_class=memory_class,
150
+ )
151
+
152
+ def test_register_matching_exhaustive_error(self):
153
+ key = list(MATCHING_EXHAUSTIVE_REGISTER.keys())[0]
154
+ setup, matching = MATCHING_EXHAUSTIVE_REGISTER[key]
155
+ memory_class = MATCHING_MEMORY_REGISTRY[key]
156
+ with pytest.raises(ValueError):
157
+ register_matching_exhaustive(
158
+ matching=key,
159
+ matching_setup=setup,
160
+ matching_scoring=matching,
161
+ memory_class=memory_class,
162
+ )
@@ -0,0 +1,30 @@
1
+ import pytest
2
+ from importlib_resources import files
3
+
4
+ from tme import Density
5
+ from tme.memory import MATCHING_MEMORY_REGISTRY, estimate_ram_usage
6
+
7
+ BASEPATH = files("tests.data")
8
+
9
+
10
+ class TestMatchingMemory:
11
+ def setup_method(self):
12
+ self.density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
13
+ self.structure_density = Density.from_structure(
14
+ filename_or_structure=str(BASEPATH.joinpath("Structures/5khe.cif")),
15
+ origin=self.density.origin,
16
+ shape=self.density.shape,
17
+ sampling_rate=self.density.sampling_rate,
18
+ )
19
+
20
+ @pytest.mark.parametrize("analyzer_method", ["MaxScoreOverRotations", None])
21
+ @pytest.mark.parametrize("matching_method", list(MATCHING_MEMORY_REGISTRY.keys()))
22
+ @pytest.mark.parametrize("ncores", range(1, 10, 3))
23
+ def test_estimate_ram_usage(self, matching_method, ncores, analyzer_method):
24
+ estimate_ram_usage(
25
+ shape1=self.density.shape,
26
+ shape2=self.structure_density.shape,
27
+ matching_method=matching_method,
28
+ ncores=ncores,
29
+ analyzer_method=analyzer_method,
30
+ )