pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
- pytme-0.2.4.data/scripts/preprocess.py +148 -0
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
- pytme-0.2.4.dist-info/RECORD +119 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
- pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/match_template.py +97 -148
- scripts/postprocess.py +20 -29
- scripts/preprocess.py +116 -61
- scripts/preprocessor_gui.py +15 -23
- tests/__init__.py +0 -0
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +310 -0
- tests/test_backends.py +375 -0
- tests/test_density.py +508 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +162 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +276 -0
- tests/test_matching_utils.py +326 -0
- tests/test_orientations.py +173 -0
- tests/test_packaging.py +95 -0
- tests/test_parser.py +33 -0
- tests/test_structure.py +243 -0
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +9 -6
- tme/backends/__init__.py +1 -1
- tme/backends/_jax_utils.py +10 -8
- tme/backends/cupy_backend.py +2 -7
- tme/backends/jax_backend.py +35 -20
- tme/backends/npfftw_backend.py +3 -2
- tme/backends/pytorch_backend.py +10 -7
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +26 -12
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/matching_data.py +33 -24
- tme/matching_exhaustive.py +39 -20
- tme/matching_scores.py +5 -2
- tme/matching_utils.py +8 -2
- tme/orientations.py +26 -9
- tme/preprocessing/_utils.py +14 -14
- tme/preprocessing/composable_filter.py +5 -4
- tme/preprocessing/compose.py +4 -4
- tme/preprocessing/frequency_filters.py +32 -35
- tme/preprocessing/tilt_series.py +210 -148
- tme/preprocessor.py +24 -246
- tme/structure.py +14 -14
- pytme-0.2.2.dist-info/RECORD +0 -74
- tme/matching_memory.py +0 -383
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,283 @@
|
|
1
|
+
import tempfile
|
2
|
+
import subprocess
|
3
|
+
from shutil import rmtree
|
4
|
+
from os.path import exists
|
5
|
+
from os import remove, makedirs
|
6
|
+
|
7
|
+
import pytest
|
8
|
+
import numpy as np
|
9
|
+
from tme import Density
|
10
|
+
from tme.backends import backend as be
|
11
|
+
|
12
|
+
BACKEND_CLASSES = ["NumpyFFTWBackend", "PytorchBackend", "CupyBackend", "MLXBackend"]
BACKENDS_TO_TEST = []

# GPU variants of the CLI tests are only enabled when cupy reports a device.
test_gpu = (False,)
for backend_class in BACKEND_CLASSES:
    try:
        BackendClass = getattr(
            __import__("tme.backends", fromlist=[backend_class]), backend_class
        )
        BACKENDS_TO_TEST.append(BackendClass(device="cpu"))
        if backend_class == "CupyBackend":
            if BACKENDS_TO_TEST[-1].device_count() >= 1:
                test_gpu = (False, True)
    except ImportError:
        print(f"Couldn't import {backend_class}. Skipping...")


# Materialized as a tuple on purpose: pytest keeps the parametrize argument on
# the test function's mark and iterates it once per collection. A generator
# would be exhausted after the first collection, so any later collection of the
# same test (e.g. through a subclass, or a second in-process pytest run) would
# end up with an empty parameter set.
available_backends = tuple(x for x in be.available_backends() if x != "mlx")
|
30
|
+
|
31
|
+
|
32
|
+
def argdict_to_command(input_args, executable: str):
    """Render a mapping of command line flags into a shell command string.

    Parameters
    ----------
    input_args : dict
        Maps flag names (e.g. ``"-m"``) to values. ``None`` values are omitted,
        ``True`` emits the bare flag, ``False`` omits it, and any other value
        emits ``flag value`` with the value converted through ``str``.
    executable : str
        Program placed in front of the rendered arguments.

    Returns
    -------
    str
        Command suitable for ``subprocess.run(..., shell=True)``. Every token is
        shell-quoted, so values containing spaces or shell metacharacters are
        passed through as single arguments; plain tokens are left unchanged.
    """
    import shlex

    tokens = [executable]
    for key, value in input_args.items():
        if value is None:
            continue
        if isinstance(value, bool):
            # Booleans act as switches: present when True, absent when False.
            if value:
                tokens.append(key)
            continue
        tokens.extend([key, value])

    return shlex.join(str(token) for token in tokens)
|
46
|
+
|
47
|
+
|
48
|
+
class TestMatchTemplate:
    """End-to-end tests that drive the ``match_template.py`` command line script."""

    @classmethod
    def setup_class(cls):
        """Write a random target, template and their thresholded masks to MRC files."""
        np.random.seed(42)
        target = np.random.rand(20, 20, 20)
        template = np.random.rand(5, 5, 5)

        # Binary masks selecting voxels above 0.5.
        target_mask = 1.0 * (target > 0.5)
        template_mask = 1.0 * (template > 0.5)

        cls.target_path = cls._temporary_path(suffix=".mrc")
        cls.template_path = cls._temporary_path(suffix=".mrc")
        cls.target_mask_path = cls._temporary_path(suffix=".mrc")
        cls.template_mask_path = cls._temporary_path(suffix=".mrc")

        Density(target, sampling_rate=5).to_file(cls.target_path)
        Density(template, sampling_rate=5).to_file(cls.template_path)
        Density(target_mask, sampling_rate=5).to_file(cls.target_mask_path)
        Density(template_mask, sampling_rate=5).to_file(cls.template_mask_path)

    @classmethod
    def teardown_class(cls):
        """Remove the files written by ``setup_class``."""
        cls.try_delete(cls.target_path)
        cls.try_delete(cls.template_path)
        cls.try_delete(cls.target_mask_path)
        cls.try_delete(cls.template_mask_path)

    @staticmethod
    def _temporary_path(suffix: str) -> str:
        """Create an empty, persistent temporary file and return its path.

        The handle is closed deterministically by the ``with`` statement, while
        ``delete=False`` keeps the file on disk; the caller must remove it.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
            return handle.name

    @staticmethod
    def try_delete(file_path: str):
        """Best-effort removal of ``file_path``, whether a file or a directory.

        Errors are deliberately ignored so cleanup never masks test results.
        """
        try:
            remove(file_path)
        except Exception:
            pass
        try:
            rmtree(file_path, ignore_errors=True)
        except Exception:
            pass

    @staticmethod
    def run_matching(
        use_template_mask: bool,
        test_filter: bool,
        call_peaks: bool,
        target_path: str,
        template_path: str,
        template_mask_path: str,
        target_mask_path: str,
        use_target_mask: bool = False,
        backend: str = "numpyfftw",
    ):
        """Run ``match_template.py`` in a subprocess and return its output path.

        Parameters
        ----------
        use_template_mask, use_target_mask : bool
            Pass ``template_mask_path`` / ``target_mask_path`` to the script.
        test_filter : bool
            Enable lowpass, defocus, tilt-wedge and whitening options.
        call_peaks : bool
            Pass ``-p`` to the script.
        target_path, template_path, template_mask_path, target_mask_path : str
            Input files.
        backend : str
            Value forwarded to ``--backend``.

        Returns
        -------
        str
            Path of the output written by the script; the caller deletes it.
        """
        output_path = TestMatchTemplate._temporary_path(suffix=".pickle")

        argdict = {
            "-m": target_path,
            "-i": template_path,
            "-n": 1,
            "-a": 60,
            "-o": output_path,
            "--pad_edges": False,
            "--pad_fourier": False,
            "--backend": backend,
        }

        if use_template_mask:
            argdict["--template_mask"] = template_mask_path

        if use_target_mask:
            argdict["--target_mask"] = target_mask_path

        # ``test_gpu`` is determined at import time from the available devices.
        if backend in ("cupy", "pytorch") and True in test_gpu:
            argdict["--use_gpu"] = True

        if test_filter:
            argdict["--lowpass"] = 30
            argdict["--defocus"] = 3000
            argdict["--tilt_angles"] = "40,40:10"
            argdict["--wedge_axes"] = "0,2"
            argdict["--whiten"] = True

        if call_peaks:
            argdict["-p"] = True

        cmd = argdict_to_command(argdict, executable="match_template.py")
        ret = subprocess.run(cmd, capture_output=True, shell=True)
        # Surface the script's stderr when it fails.
        assert ret.returncode == 0, ret.stderr.decode(errors="replace")
        return output_path

    @pytest.mark.parametrize("backend", available_backends)
    @pytest.mark.parametrize("call_peaks", (False, True))
    @pytest.mark.parametrize("use_template_mask", (False, True))
    @pytest.mark.parametrize("test_filter", (False, True))
    def test_match_template(
        self,
        backend: str,
        call_peaks: bool,
        use_template_mask: bool,
        test_filter: bool,
    ):
        self.run_matching(
            use_template_mask=use_template_mask,
            use_target_mask=True,
            backend=backend,
            test_filter=test_filter,
            call_peaks=call_peaks,
            template_path=self.template_path,
            target_path=self.target_path,
            template_mask_path=self.template_mask_path,
            target_mask_path=self.target_mask_path,
        )
|
163
|
+
|
164
|
+
|
165
|
+
class TestPostprocessing(TestMatchTemplate):
|
166
|
+
@classmethod
|
167
|
+
def setup_class(cls):
|
168
|
+
super().setup_class()
|
169
|
+
|
170
|
+
matching_kwargs = {
|
171
|
+
"use_template_mask": False,
|
172
|
+
"use_target_mask": False,
|
173
|
+
"test_filter": False,
|
174
|
+
"template_path": cls.template_path,
|
175
|
+
"target_path": cls.target_path,
|
176
|
+
"template_mask_path": cls.template_mask_path,
|
177
|
+
"target_mask_path": cls.target_mask_path,
|
178
|
+
}
|
179
|
+
|
180
|
+
cls.score_pickle = cls.run_matching(
|
181
|
+
call_peaks=False,
|
182
|
+
**matching_kwargs,
|
183
|
+
)
|
184
|
+
cls.peak_pickle = cls.run_matching(call_peaks=True, **matching_kwargs)
|
185
|
+
cls.tempdir = tempfile.TemporaryDirectory().name
|
186
|
+
|
187
|
+
@classmethod
|
188
|
+
def teardown_class(cls):
|
189
|
+
cls.try_delete(cls.score_pickle)
|
190
|
+
cls.try_delete(cls.peak_pickle)
|
191
|
+
cls.try_delete(cls.tempdir)
|
192
|
+
|
193
|
+
@pytest.mark.parametrize("distance_cutoff_strategy", (0, 1, 2, 3))
|
194
|
+
@pytest.mark.parametrize("score_cutoff", (None, (1,), (0, 1), (None, 1), (0, None)))
|
195
|
+
@pytest.mark.parametrize("peak_oversampling", (False, 4))
|
196
|
+
def test_postprocess_score_orientations(
|
197
|
+
self,
|
198
|
+
peak_oversampling,
|
199
|
+
score_cutoff,
|
200
|
+
distance_cutoff_strategy,
|
201
|
+
):
|
202
|
+
self.try_delete(self.tempdir)
|
203
|
+
makedirs(self.tempdir, exist_ok=True)
|
204
|
+
|
205
|
+
argdict = {
|
206
|
+
"--input_file": self.score_pickle,
|
207
|
+
"--output_format": "orientations",
|
208
|
+
"--output_prefix": f"{self.tempdir}/temp",
|
209
|
+
"--peak_oversampling": peak_oversampling,
|
210
|
+
"--number_of_peaks": 3,
|
211
|
+
}
|
212
|
+
|
213
|
+
if score_cutoff is not None:
|
214
|
+
if len(score_cutoff) == 1:
|
215
|
+
argdict["--n_false_positives"] = 1
|
216
|
+
else:
|
217
|
+
min_score, max_score = score_cutoff
|
218
|
+
argdict["--minimum_score"] = min_score
|
219
|
+
argdict["--maximum_score"] = max_score
|
220
|
+
|
221
|
+
match distance_cutoff_strategy:
|
222
|
+
case 1:
|
223
|
+
argdict["--mask_edges"] = True
|
224
|
+
case 2:
|
225
|
+
argdict["--min_distance"] = 5
|
226
|
+
case 3:
|
227
|
+
argdict["--min_boundary_distance"] = 5
|
228
|
+
|
229
|
+
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
230
|
+
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
231
|
+
assert ret.returncode == 0
|
232
|
+
|
233
|
+
@pytest.mark.parametrize("input_format", ("score", "peaks"))
|
234
|
+
@pytest.mark.parametrize(
|
235
|
+
"output_format",
|
236
|
+
("orientations", "alignment", "backmapping", "average", "relion"),
|
237
|
+
)
|
238
|
+
def test_postproces_score_formats(self, input_format, output_format):
|
239
|
+
self.try_delete(self.tempdir)
|
240
|
+
makedirs(self.tempdir, exist_ok=True)
|
241
|
+
|
242
|
+
input_file = self.score_pickle
|
243
|
+
if input_format == "peaks":
|
244
|
+
input_file = self.peak_pickle
|
245
|
+
|
246
|
+
argdict = {
|
247
|
+
"--input_file": input_file,
|
248
|
+
"--output_format": output_format,
|
249
|
+
"--output_prefix": f"{self.tempdir}/temp",
|
250
|
+
"--number_of_peaks": 3,
|
251
|
+
"--peak_caller": "PeakCallerMaximumFilter",
|
252
|
+
}
|
253
|
+
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
254
|
+
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
255
|
+
|
256
|
+
match output_format:
|
257
|
+
case "orientations":
|
258
|
+
assert exists(f"{self.tempdir}/temp.tsv")
|
259
|
+
case "alignment":
|
260
|
+
assert exists(f"{self.tempdir}/temp_0.mrc")
|
261
|
+
case "backmapping":
|
262
|
+
assert exists(f"{self.tempdir}/temp_backmapped.mrc")
|
263
|
+
case "average":
|
264
|
+
assert exists(f"{self.tempdir}/temp_average.mrc")
|
265
|
+
case "relion":
|
266
|
+
assert exists(f"{self.tempdir}/temp.star")
|
267
|
+
|
268
|
+
assert ret.returncode == 0
|
269
|
+
|
270
|
+
def test_postprocess_score_local_optimization(self):
|
271
|
+
self.try_delete(self.tempdir)
|
272
|
+
makedirs(self.tempdir, exist_ok=True)
|
273
|
+
|
274
|
+
argdict = {
|
275
|
+
"--input_file": self.score_pickle,
|
276
|
+
"--output_format": "orientations",
|
277
|
+
"--output_prefix": f"{self.tempdir}/temp",
|
278
|
+
"--number_of_peaks": 1,
|
279
|
+
"--local_optimization": True,
|
280
|
+
}
|
281
|
+
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
282
|
+
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
283
|
+
assert ret.returncode == 0
|
@@ -0,0 +1,162 @@
|
|
1
|
+
from tempfile import mkstemp
|
2
|
+
|
3
|
+
import pytest
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
from tme.backends import backend as be
|
7
|
+
from tme.matching_data import MatchingData
|
8
|
+
|
9
|
+
|
10
|
+
class TestDensity:
    """Unit tests for ``tme.matching_data.MatchingData``."""

    def setup_method(self):
        """Build a target and a template, each holding a single box of ones."""
        target = np.zeros((50, 50, 50))
        target[20:30, 30:40, 12:17] = 1
        self.target = target

        template = np.zeros((50, 50, 50))
        template[15:25, 20:30, 2:7] = 1
        self.template = template

        self.rotations = np.random.rand(100, target.ndim, target.ndim).astype(
            np.float32
        )

    def teardown_method(self):
        self.target = None
        self.template = None
        self.coordinates = None
        self.coordinates_weights = None
        self.rotations = None

    def test_initialization(self):
        _ = MatchingData(target=self.target, template=self.template)

    @pytest.mark.parametrize("shape", [(10,), (10, 15), (10, 20, 30)])
    def test__shape_to_slice(self, shape):
        slices = MatchingData._shape_to_slice(shape=shape)
        assert len(slices) == len(shape)
        for current_slice, k in zip(slices, shape):
            assert current_slice.start == 0
            assert current_slice.stop == k

    @pytest.mark.parametrize("shape", [(10,), (10, 15), (10, 20, 30)])
    def test_slice_to_mesh(self, shape):
        slices = MatchingData._shape_to_slice(shape=shape)

        # Both explicit full-extent slices and ``None`` should cover the shape.
        for slice_variable in (slices, None):
            indices = MatchingData._slice_to_mesh(
                slice_variable=slice_variable, shape=shape
            )
            assert len(indices) == len(shape)
            for index, k in zip(indices, shape):
                assert index.min() == 0
                assert index.max() == k - 1

    def test__load_array(self):
        arr = MatchingData._load_array(self.target)
        assert np.allclose(arr, self.target)

    def test__load_array_memmap(self):
        import os

        file_descriptor, filename = mkstemp()
        # np.memmap reopens the file by name, so the descriptor from mkstemp is
        # closed immediately instead of leaking.
        os.close(file_descriptor)
        try:
            shape, dtype = self.target.shape, self.target.dtype
            arr_memmap = np.memmap(filename, mode="w+", dtype=dtype, shape=shape)
            arr_memmap[:] = self.target[:]
            arr_memmap.flush()

            arr = MatchingData._load_array(arr_memmap)
            assert np.allclose(arr, self.target)
        finally:
            os.remove(filename)

    def test_subset_array(self):
        matching_data = MatchingData(target=self.target, template=self.template)
        half_shape = np.divide(self.target.shape, 2).astype(int)
        slices = MatchingData._shape_to_slice(shape=half_shape)
        ret = matching_data.subset_array(
            arr=self.target, arr_slice=slices, padding=(2, 2, 2)
        )
        assert np.allclose(ret.shape, np.add(half_shape, 2))

    def test_subset_by_slice_none(self):
        matching_data = MatchingData(target=self.target, template=self.template)
        matching_data.rotations = self.rotations
        matching_data.target_mask = self.target
        matching_data.template_mask = self.template

        ret = matching_data.subset_by_slice()

        assert type(ret) is type(matching_data)
        assert np.allclose(ret.target, matching_data.target)
        assert np.allclose(ret.template, matching_data.template)
        assert np.allclose(ret.target_mask, matching_data.target_mask)
        assert np.allclose(ret.template_mask, matching_data.template_mask)

    def test_subset_by_slice(self):
        matching_data = MatchingData(target=self.target, template=self.template)
        matching_data.rotations = self.rotations
        matching_data.target_mask = self.target
        matching_data.template_mask = self.template

        target_half = np.divide(self.target.shape, 2).astype(int)
        template_half = np.divide(self.template.shape, 2).astype(int)
        target_slice = MatchingData._shape_to_slice(shape=target_half)
        template_slice = MatchingData._shape_to_slice(shape=template_half)
        ret = matching_data.subset_by_slice(
            target_slice=target_slice, template_slice=template_slice
        )
        assert type(ret) is type(matching_data)

        assert np.allclose(ret.target.shape, target_half)
        assert np.allclose(ret.template.shape, template_half[::-1])

    def test_rotations(self):
        matching_data = MatchingData(target=self.target, template=self.template)
        matching_data.rotations = self.rotations
        matching_data.target_mask = self.target

        assert np.allclose(matching_data.rotations, self.rotations)

        # A single matrix is promoted to a stack containing one rotation.
        matching_data.rotations = np.random.rand(self.target.ndim, self.target.ndim)
        assert np.allclose(
            matching_data.rotations.shape, (1, self.target.ndim, self.target.ndim)
        )

    def test_target(self):
        matching_data = MatchingData(target=self.target, template=self.template)

        assert np.allclose(matching_data.target, self.target)

    def test_template(self):
        matching_data = MatchingData(target=self.target, template=self.template)

        assert np.allclose(matching_data.template, be.reverse(self.template))

    def test_target_mask(self):
        matching_data = MatchingData(target=self.target, template=self.template)
        matching_data.target_mask = self.target

        assert np.allclose(matching_data.target_mask, self.target)

    def test_template_mask(self):
        matching_data = MatchingData(target=self.target, template=self.template)
        matching_data.template_mask = self.template

        assert np.allclose(matching_data.template_mask, be.reverse(self.template))

    @pytest.mark.parametrize("jobs", range(1, 50, 5))
    def test__split_rotations_on_jobs(self, jobs):
        matching_data = MatchingData(target=self.target, template=self.template)
        matching_data.rotations = self.rotations

        ret = matching_data._split_rotations_on_jobs(n_jobs=jobs)
        assert len(ret) == jobs
|
@@ -0,0 +1,162 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pytest
|
3
|
+
|
4
|
+
from scipy.ndimage import laplace
|
5
|
+
|
6
|
+
from tme.matching_data import MatchingData
|
7
|
+
from tme.memory import MATCHING_MEMORY_REGISTRY
|
8
|
+
from tme.matching_utils import get_rotation_matrices
|
9
|
+
from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
|
10
|
+
from tme.matching_exhaustive import (
|
11
|
+
scan,
|
12
|
+
scan_subsets,
|
13
|
+
MATCHING_EXHAUSTIVE_REGISTER,
|
14
|
+
register_matching_exhaustive,
|
15
|
+
)
|
16
|
+
|
17
|
+
|
18
|
+
class TestMatchExhaustive:
    """Tests for exhaustive template matching via ``scan`` and ``scan_subsets``."""

    def setup_method(self):
        # To be valid for splitting, the template needs to be fully inside the object
        target = np.zeros((80, 80, 80))
        target[25:31, 22:28, 12:16] = 1
        self.target = target

        self.template = np.zeros((41, 41, 35))
        self.template[20:26, 25:31, 17:21] = 1
        self.template_mask = np.ones_like(self.template)
        self.target_mask = np.ones_like(target)

        self.rotations = get_rotation_matrices(60)[0,]
        self.peak_position = np.array([25, 17, 12])

    def teardown_method(self):
        self.target = None
        self.template = None
        self.coordinates = None
        self.coordinates_weights = None
        self.rotations = None

    def _theoretical_score(self, score: str) -> float:
        """Return the expected score at the true peak for the given method.

        CC uses the template sum and LCC the summed squared Laplacian of the
        template; every other score is assumed to reach 1 at a perfect match.
        """
        if score == "CC":
            return self.template.sum()
        if score == "LCC":
            return (laplace(self.template) * laplace(self.template)).sum()
        return 1

    def _matching_data(self):
        """Build a ``MatchingData`` instance from the fixtures."""
        return MatchingData(
            target=self.target,
            template=self.template,
            target_mask=self.target_mask,
            template_mask=self.template_mask,
            rotations=self.rotations,
        )

    @pytest.mark.parametrize("score", list(MATCHING_EXHAUSTIVE_REGISTER.keys()))
    @pytest.mark.parametrize("n_jobs", (1, 2))
    def test_scan(self, score: str, n_jobs: int):
        setup, process = MATCHING_EXHAUSTIVE_REGISTER[score]
        ret = scan(
            matching_data=self._matching_data(),
            matching_setup=setup,
            matching_score=process,
            n_jobs=n_jobs,
            callback_class=MaxScoreOverRotations,
        )
        scores = ret[0]
        peak = np.unravel_index(np.argmax(scores), scores.shape)

        assert np.allclose(peak, self.peak_position)
        assert np.allclose(scores[peak], self._theoretical_score(score), rtol=0.05)

    @pytest.mark.parametrize("evaluate_peak", (False, True))
    @pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
    @pytest.mark.parametrize("job_schedule", ((2, 1), (1, 1)))
    @pytest.mark.parametrize("pad_fourier", (True, False))
    @pytest.mark.parametrize("pad_edge", (True, False))
    def test_scan_subset(
        self,
        score: str,
        job_schedule: tuple,
        evaluate_peak: bool,
        pad_fourier: bool,
        pad_edge: bool,
    ):
        setup, process = MATCHING_EXHAUSTIVE_REGISTER[score]

        target_splits = {}
        if job_schedule[0] == 2:
            # Presumably splits axis 0 of the target into two subsets.
            target_splits = {0: 2}

        callback_class = MaxScoreOverRotations if evaluate_peak else PeakCallerSort

        ret = scan_subsets(
            matching_data=self._matching_data(),
            matching_setup=setup,
            matching_score=process,
            target_splits=target_splits,
            job_schedule=job_schedule,
            callback_class=callback_class,
            pad_target_edges=pad_edge,
            pad_fourier=pad_fourier,
        )

        if evaluate_peak:
            scores = ret[0]
            peak = np.unravel_index(np.argmax(scores), scores.shape)
            achieved_score = scores[tuple(peak)]
        else:
            try:
                peak, achieved_score = ret[0][0], ret[2][0]
            except Exception:
                # Presumably no peak was reported; nothing to compare against.
                return None

        if not pad_edge:
            # To be valid, the match needs to be fully within the target subset
            return None

        assert np.allclose(peak, self.peak_position), f"Unexpected peak {peak}"
        assert np.allclose(achieved_score, self._theoretical_score(score), rtol=0.3)

    def test_register_matching_exhaustive(self):
        key = list(MATCHING_EXHAUSTIVE_REGISTER.keys())[0]
        setup, matching = MATCHING_EXHAUSTIVE_REGISTER[key]
        memory_class = MATCHING_MEMORY_REGISTRY[key]
        try:
            register_matching_exhaustive(
                matching="TEST",
                matching_setup=setup,
                matching_scoring=matching,
                memory_class=memory_class,
            )
        finally:
            # Registration mutates module-level registries; undo it so the test
            # is repeatable and does not leak into other tests.
            MATCHING_EXHAUSTIVE_REGISTER.pop("TEST", None)
            MATCHING_MEMORY_REGISTRY.pop("TEST", None)

    def test_register_matching_exhaustive_error(self):
        key = list(MATCHING_EXHAUSTIVE_REGISTER.keys())[0]
        setup, matching = MATCHING_EXHAUSTIVE_REGISTER[key]
        memory_class = MATCHING_MEMORY_REGISTRY[key]
        # Registering an already-used name must be rejected.
        with pytest.raises(ValueError):
            register_matching_exhaustive(
                matching=key,
                matching_setup=setup,
                matching_scoring=matching,
                memory_class=memory_class,
            )
|
@@ -0,0 +1,30 @@
|
|
1
|
+
import pytest
|
2
|
+
from importlib_resources import files
|
3
|
+
|
4
|
+
from tme import Density
|
5
|
+
from tme.memory import MATCHING_MEMORY_REGISTRY, estimate_ram_usage
|
6
|
+
|
7
|
+
# Directory holding the test fixtures shipped with the ``tests.data`` package.
BASEPATH = files("tests.data")


class TestMatchingMemory:
    """Smoke tests for ``tme.memory.estimate_ram_usage``."""

    @classmethod
    def setup_class(cls):
        """Load the reference map and rasterize a structure onto the same grid.

        The tests only read ``.shape`` from these densities, so they are built
        once per class instead of once per parametrized case.
        """
        cls.density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
        cls.structure_density = Density.from_structure(
            filename_or_structure=str(BASEPATH.joinpath("Structures/5khe.cif")),
            origin=cls.density.origin,
            shape=cls.density.shape,
            sampling_rate=cls.density.sampling_rate,
        )

    @pytest.mark.parametrize("analyzer_method", ["MaxScoreOverRotations", None])
    @pytest.mark.parametrize("matching_method", list(MATCHING_MEMORY_REGISTRY.keys()))
    @pytest.mark.parametrize("ncores", range(1, 10, 3))
    def test_estimate_ram_usage(self, matching_method, ncores, analyzer_method):
        estimate_ram_usage(
            shape1=self.density.shape,
            shape2=self.structure_density.shape,
            matching_method=matching_method,
            ncores=ncores,
            analyzer_method=analyzer_method,
        )
|