pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3.0.data/scripts/match_template.py +1106 -0
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
- pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
- pytme-0.3.0.dist-info/RECORD +126 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
- pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +349 -378
- pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +320 -190
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +85 -19
- scripts/pytme_runner.py +771 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -53
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +396 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -201
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +158 -28
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.dist-info/RECORD +0 -119
- pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
@@ -4,13 +4,15 @@ from typing import Tuple
|
|
4
4
|
|
5
5
|
from tme.backends import backend as be
|
6
6
|
from tme.filters._utils import compute_fourier_shape
|
7
|
-
from tme.filters import
|
7
|
+
from tme.filters import BandPassReconstructed, LinearWhiteningFilter
|
8
|
+
from tme.filters.bandpass import gaussian_bandpass, discrete_bandpass
|
9
|
+
from tme.filters._utils import fftfreqn
|
8
10
|
|
9
11
|
|
10
12
|
class TestBandPassFilter:
|
11
13
|
@pytest.fixture
|
12
14
|
def band_pass_filter(self):
|
13
|
-
return
|
15
|
+
return BandPassReconstructed()
|
14
16
|
|
15
17
|
@pytest.mark.parametrize(
|
16
18
|
"shape, lowpass, highpass, sampling_rate",
|
@@ -24,9 +26,13 @@ class TestBandPassFilter:
|
|
24
26
|
def test_discrete_bandpass(
|
25
27
|
self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
|
26
28
|
):
|
27
|
-
|
28
|
-
shape,
|
29
|
+
grid = fftfreqn(
|
30
|
+
shape=shape,
|
31
|
+
sampling_rate=0.5,
|
32
|
+
shape_is_real_fourier=False,
|
33
|
+
compute_euclidean_norm=True,
|
29
34
|
)
|
35
|
+
result = discrete_bandpass(grid, lowpass, highpass, sampling_rate)
|
30
36
|
assert isinstance(result, type(be.ones((1,))))
|
31
37
|
assert result.shape == shape
|
32
38
|
assert np.all((result >= 0) & (result <= 1))
|
@@ -43,9 +49,13 @@ class TestBandPassFilter:
|
|
43
49
|
def test_gaussian_bandpass(
|
44
50
|
self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
|
45
51
|
):
|
46
|
-
|
47
|
-
shape,
|
52
|
+
grid = fftfreqn(
|
53
|
+
shape=shape,
|
54
|
+
sampling_rate=0.5,
|
55
|
+
shape_is_real_fourier=False,
|
56
|
+
compute_euclidean_norm=True,
|
48
57
|
)
|
58
|
+
result = gaussian_bandpass(grid, lowpass, highpass, sampling_rate)
|
49
59
|
assert isinstance(result, type(be.ones((1,))))
|
50
60
|
assert result.shape == shape
|
51
61
|
assert np.all((result >= 0) & (result <= 1))
|
@@ -55,7 +65,7 @@ class TestBandPassFilter:
|
|
55
65
|
@pytest.mark.parametrize("shape_is_real_fourier", [True, False])
|
56
66
|
def test_call_method(
|
57
67
|
self,
|
58
|
-
band_pass_filter:
|
68
|
+
band_pass_filter: BandPassReconstructed,
|
59
69
|
use_gaussian: bool,
|
60
70
|
return_real_fourier: bool,
|
61
71
|
shape_is_real_fourier: bool,
|
@@ -68,22 +78,20 @@ class TestBandPassFilter:
|
|
68
78
|
|
69
79
|
assert isinstance(result, dict)
|
70
80
|
assert "data" in result
|
71
|
-
assert "sampling_rate" in result
|
72
81
|
assert "is_multiplicative_filter" in result
|
73
82
|
assert isinstance(result["data"], type(be.ones((1,))))
|
74
83
|
assert result["is_multiplicative_filter"] is True
|
75
84
|
|
76
|
-
def test_default_values(self, band_pass_filter:
|
85
|
+
def test_default_values(self, band_pass_filter: BandPassReconstructed):
|
77
86
|
assert band_pass_filter.lowpass is None
|
78
87
|
assert band_pass_filter.highpass is None
|
79
88
|
assert band_pass_filter.sampling_rate == 1
|
80
89
|
assert band_pass_filter.use_gaussian is True
|
81
90
|
assert band_pass_filter.return_real_fourier is False
|
82
|
-
assert band_pass_filter.shape_is_real_fourier is False
|
83
91
|
|
84
92
|
@pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
|
85
93
|
def test_return_real_fourier(self, shape: Tuple[int]):
|
86
|
-
bpf =
|
94
|
+
bpf = BandPassReconstructed(return_real_fourier=True)
|
87
95
|
result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
|
88
96
|
expected_shape = tuple(compute_fourier_shape(shape, False))
|
89
97
|
assert result["data"].shape == expected_shape
|
@@ -146,7 +154,11 @@ class TestLinearWhiteningFilter:
|
|
146
154
|
):
|
147
155
|
data = be.random.random(shape)
|
148
156
|
result = LinearWhiteningFilter()(
|
149
|
-
|
157
|
+
shape=shape,
|
158
|
+
data=data,
|
159
|
+
n_bins=n_bins,
|
160
|
+
batch_dimension=batch_dimension,
|
161
|
+
order=order,
|
150
162
|
)
|
151
163
|
|
152
164
|
assert isinstance(result, dict)
|
@@ -161,7 +173,9 @@ class TestLinearWhiteningFilter:
|
|
161
173
|
def test_call_method_with_data_rfft(self):
|
162
174
|
shape = (30, 30, 30)
|
163
175
|
data_rfft = be.fft.rfftn(be.random.random(shape))
|
164
|
-
result = LinearWhiteningFilter()(
|
176
|
+
result = LinearWhiteningFilter()(
|
177
|
+
shape=shape, data_rfft=data_rfft, return_real_fourier=True
|
178
|
+
)
|
165
179
|
|
166
180
|
assert isinstance(result, dict)
|
167
181
|
assert result.get("data", False) is not False
|
@@ -172,7 +186,7 @@ class TestLinearWhiteningFilter:
|
|
172
186
|
@pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
|
173
187
|
def test_filter_mask_range(self, shape: Tuple[int]):
|
174
188
|
data = be.random.random(shape)
|
175
|
-
result = LinearWhiteningFilter()(data=data)
|
189
|
+
result = LinearWhiteningFilter()(shape=shape, data=data)
|
176
190
|
|
177
191
|
filter_mask = result["data"]
|
178
192
|
assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
|
tests/test_analyzer.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1
|
-
from tempfile import mkstemp
|
2
|
-
|
3
1
|
import pytest
|
4
2
|
import numpy as np
|
5
3
|
|
@@ -24,6 +22,7 @@ PEAK_CALLER_CHILDREN = [
|
|
24
22
|
PeakCallerScipy,
|
25
23
|
PeakClustering,
|
26
24
|
]
|
25
|
+
np.random.seed(123)
|
27
26
|
|
28
27
|
|
29
28
|
class TestPeakCallers:
|
@@ -54,21 +53,23 @@ class TestPeakCallers:
|
|
54
53
|
@pytest.mark.parametrize("num_peaks", (1, 100))
|
55
54
|
@pytest.mark.parametrize("minimum_score", (None, 0.5))
|
56
55
|
def test__call__(self, peak_caller, num_peaks, minimum_score):
|
57
|
-
|
58
|
-
shape
|
59
|
-
num_peaks
|
60
|
-
min_distance
|
61
|
-
min_score
|
62
|
-
|
63
|
-
peak_caller(
|
56
|
+
kwargs = {
|
57
|
+
"shape": self.data.shape,
|
58
|
+
"num_peaks": num_peaks,
|
59
|
+
"min_distance": self.min_distance,
|
60
|
+
"min_score": minimum_score,
|
61
|
+
}
|
62
|
+
peak_caller = peak_caller(**kwargs)
|
63
|
+
state = peak_caller(
|
64
|
+
peak_caller.init_state(),
|
64
65
|
self.data.copy(),
|
65
66
|
rotation_matrix=self.rotation_matrix,
|
66
67
|
)
|
67
|
-
|
68
|
+
state = peak_caller.result(state)
|
68
69
|
if minimum_score is None:
|
69
|
-
assert len(
|
70
|
+
assert len(state[0] <= num_peaks)
|
70
71
|
else:
|
71
|
-
peaks =
|
72
|
+
peaks = state[0].astype(int)
|
72
73
|
print(self.data[tuple(peaks.T)])
|
73
74
|
assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
|
74
75
|
|
@@ -78,21 +79,20 @@ class TestPeakCallers:
|
|
78
79
|
peak_caller1 = peak_caller(
|
79
80
|
shape=self.data.shape, num_peaks=num_peaks, min_distance=self.min_distance
|
80
81
|
)
|
81
|
-
|
82
|
+
state1 = peak_caller1.init_state()
|
83
|
+
state1 = peak_caller1(state1, self.data, rotation_matrix=self.rotation_matrix)
|
82
84
|
|
83
85
|
peak_caller2 = peak_caller(
|
84
86
|
shape=self.data.shape, num_peaks=num_peaks, min_distance=self.min_distance
|
85
87
|
)
|
86
|
-
|
88
|
+
state2 = peak_caller2.init_state()
|
89
|
+
state2 = peak_caller2(state2, self.data, rotation_matrix=self.rotation_matrix)
|
87
90
|
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
num_peaks=num_peaks,
|
94
|
-
min_distance=self.min_distance,
|
95
|
-
)
|
91
|
+
states = [peak_caller1.result(state1), peak_caller2.result(state2)]
|
92
|
+
result = peak_caller.merge(
|
93
|
+
results=states,
|
94
|
+
num_peaks=num_peaks,
|
95
|
+
min_distance=self.min_distance,
|
96
96
|
)
|
97
97
|
assert [len(res) == 2 for res in result]
|
98
98
|
|
@@ -122,7 +122,9 @@ class TestRecursiveMasking:
|
|
122
122
|
rotation_space = self.rotation_space
|
123
123
|
rotation_mapping = self.rotation_mapping
|
124
124
|
|
125
|
-
peak_caller(
|
125
|
+
state = peak_caller.init_state()
|
126
|
+
state = peak_caller(
|
127
|
+
state,
|
126
128
|
self.data.copy(),
|
127
129
|
rotation_matrix=self.rotation_matrix,
|
128
130
|
mask=self.mask,
|
@@ -130,11 +132,10 @@ class TestRecursiveMasking:
|
|
130
132
|
rotation_mapping=rotation_mapping,
|
131
133
|
)
|
132
134
|
|
133
|
-
candidates = tuple(peak_caller)
|
134
135
|
if minimum_score is None:
|
135
|
-
assert len(
|
136
|
+
assert len(state[0] <= num_peaks)
|
136
137
|
else:
|
137
|
-
peaks =
|
138
|
+
peaks = state[0].astype(int)
|
138
139
|
assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
|
139
140
|
|
140
141
|
|
@@ -157,8 +158,9 @@ class TestMaxScoreOverRotations:
|
|
157
158
|
shape=self.data.shape,
|
158
159
|
use_memmap=use_memmap,
|
159
160
|
)
|
160
|
-
|
161
|
-
|
161
|
+
state = score_analyzer.init_state()
|
162
|
+
state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
|
163
|
+
res = score_analyzer.result(state)
|
162
164
|
assert np.allclose(res[0].shape, self.data.shape)
|
163
165
|
assert res[0].dtype == be._float_dtype
|
164
166
|
assert res[1].size == self.data.ndim
|
@@ -174,11 +176,13 @@ class TestMaxScoreOverRotations:
|
|
174
176
|
translation_offset=np.zeros(self.data.ndim, dtype=int),
|
175
177
|
use_memmap=use_memmap,
|
176
178
|
)
|
177
|
-
|
179
|
+
state = score_analyzer.init_state()
|
180
|
+
state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
|
178
181
|
|
179
182
|
data2 = self.data * 2
|
180
|
-
score_analyzer(data2, rotation_matrix=self.rotation_matrix)
|
181
|
-
scores, translation_offset, rotations, mapping =
|
183
|
+
score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
|
184
|
+
scores, translation_offset, rotations, mapping = score_analyzer.result(state)
|
185
|
+
|
182
186
|
assert np.all(scores >= score_threshold)
|
183
187
|
max_scores = np.maximum(self.data, data2)
|
184
188
|
max_scores = np.maximum(max_scores, score_threshold)
|
@@ -193,7 +197,8 @@ class TestMaxScoreOverRotations:
|
|
193
197
|
translation_offset=np.zeros(self.data.ndim, dtype=int),
|
194
198
|
use_memmap=use_memmap,
|
195
199
|
)
|
196
|
-
|
200
|
+
state1 = score_analyzer.init_state()
|
201
|
+
state1, score_analyzer(state1, self.data, rotation_matrix=self.rotation_matrix)
|
197
202
|
|
198
203
|
data2 = self.data * 2
|
199
204
|
score_analyzer2 = MaxScoreOverRotations(
|
@@ -202,12 +207,12 @@ class TestMaxScoreOverRotations:
|
|
202
207
|
translation_offset=np.zeros(self.data.ndim, dtype=int),
|
203
208
|
use_memmap=use_memmap,
|
204
209
|
)
|
205
|
-
|
206
|
-
|
207
|
-
|
210
|
+
state2 = score_analyzer2.init_state()
|
211
|
+
state2 = score_analyzer2(state2, data2, rotation_matrix=self.rotation_matrix)
|
212
|
+
states = [score_analyzer.result(state1), score_analyzer2.result(state2)]
|
208
213
|
|
209
214
|
ret = MaxScoreOverRotations.merge(
|
210
|
-
|
215
|
+
results=states, use_memmap=use_memmap, score_threshold=score_threshold
|
211
216
|
)
|
212
217
|
scores, translation, rotations, mapping = ret
|
213
218
|
assert np.all(scores >= score_threshold)
|
tests/test_backends.py
CHANGED
tests/test_matching_cli.py
CHANGED
@@ -6,26 +6,10 @@ from os import remove, makedirs
|
|
6
6
|
|
7
7
|
import pytest
|
8
8
|
import numpy as np
|
9
|
-
from tme import Density
|
9
|
+
from tme import Density, Orientations
|
10
10
|
from tme.backends import backend as be
|
11
11
|
|
12
|
-
|
13
|
-
BACKENDS_TO_TEST = []
|
14
|
-
|
15
|
-
test_gpu = (False,)
|
16
|
-
for backend_class in BACKEND_CLASSES:
|
17
|
-
try:
|
18
|
-
BackendClass = getattr(
|
19
|
-
__import__("tme.backends", fromlist=[backend_class]), backend_class
|
20
|
-
)
|
21
|
-
BACKENDS_TO_TEST.append(BackendClass(device="cpu"))
|
22
|
-
if backend_class == "CupyBackend":
|
23
|
-
if BACKENDS_TO_TEST[-1].device_count() >= 1:
|
24
|
-
test_gpu = (False, True)
|
25
|
-
except ImportError:
|
26
|
-
print(f"Couldn't import {backend_class}. Skipping...")
|
27
|
-
|
28
|
-
|
12
|
+
np.random.seed(42)
|
29
13
|
available_backends = (x for x in be.available_backends() if x != "mlx")
|
30
14
|
|
31
15
|
|
@@ -48,7 +32,6 @@ def argdict_to_command(input_args, executable: str):
|
|
48
32
|
class TestMatchTemplate:
|
49
33
|
@classmethod
|
50
34
|
def setup_class(cls):
|
51
|
-
np.random.seed(42)
|
52
35
|
target = np.random.rand(20, 20, 20)
|
53
36
|
template = np.random.rand(5, 5, 5)
|
54
37
|
|
@@ -65,6 +48,19 @@ class TestMatchTemplate:
|
|
65
48
|
cls.template_mask_path = tempfile.NamedTemporaryFile(
|
66
49
|
delete=False, suffix=".mrc"
|
67
50
|
).name
|
51
|
+
cls.tempdir = tempfile.TemporaryDirectory().name
|
52
|
+
makedirs(cls.tempdir, exist_ok=True)
|
53
|
+
|
54
|
+
orientations = Orientations(
|
55
|
+
translations=((10, 10, 10), (12, 10, 15)),
|
56
|
+
rotations=((0, 0, 0), (45, 12, 90)),
|
57
|
+
scores=(0, 0),
|
58
|
+
details=(-1, -1),
|
59
|
+
)
|
60
|
+
cls.orientations_path = tempfile.NamedTemporaryFile(
|
61
|
+
delete=False, suffix=".star"
|
62
|
+
).name
|
63
|
+
orientations.to_file(cls.orientations_path)
|
68
64
|
|
69
65
|
Density(target, sampling_rate=5).to_file(cls.target_path)
|
70
66
|
Density(template, sampling_rate=5).to_file(cls.template_path)
|
@@ -76,6 +72,8 @@ class TestMatchTemplate:
|
|
76
72
|
cls.try_delete(cls.template_path)
|
77
73
|
cls.try_delete(cls.target_mask_path)
|
78
74
|
cls.try_delete(cls.template_mask_path)
|
75
|
+
cls.try_delete(cls.orientations_path)
|
76
|
+
cls.try_delete(cls.tempdir)
|
79
77
|
|
80
78
|
@staticmethod
|
81
79
|
def try_delete(file_path: str):
|
@@ -88,8 +86,8 @@ class TestMatchTemplate:
|
|
88
86
|
except Exception:
|
89
87
|
pass
|
90
88
|
|
91
|
-
@staticmethod
|
92
89
|
def run_matching(
|
90
|
+
self,
|
93
91
|
use_template_mask: bool,
|
94
92
|
test_filter: bool,
|
95
93
|
call_peaks: bool,
|
@@ -99,6 +97,7 @@ class TestMatchTemplate:
|
|
99
97
|
target_mask_path: str,
|
100
98
|
use_target_mask: bool = False,
|
101
99
|
backend: str = "numpyfftw",
|
100
|
+
test_rejection_sampling: bool = False,
|
102
101
|
):
|
103
102
|
output_path = tempfile.NamedTemporaryFile(delete=False, suffix="pickle").name
|
104
103
|
|
@@ -108,25 +107,24 @@ class TestMatchTemplate:
|
|
108
107
|
"-n": 1,
|
109
108
|
"-a": 60,
|
110
109
|
"-o": output_path,
|
111
|
-
"--
|
112
|
-
"--pad_fourier": False,
|
110
|
+
"--pad-edges": False,
|
113
111
|
"--backend": backend,
|
114
112
|
}
|
115
113
|
|
116
114
|
if use_template_mask:
|
117
|
-
argdict["--
|
115
|
+
argdict["--template-mask"] = template_mask_path
|
118
116
|
|
119
117
|
if use_target_mask:
|
120
|
-
argdict["--
|
118
|
+
argdict["--target-mask"] = target_mask_path
|
121
119
|
|
122
|
-
if
|
123
|
-
argdict["--
|
120
|
+
if test_rejection_sampling:
|
121
|
+
argdict["--orientations"] = self.orientations_path
|
124
122
|
|
125
123
|
if test_filter:
|
126
124
|
argdict["--lowpass"] = 30
|
127
125
|
argdict["--defocus"] = 3000
|
128
|
-
argdict["--
|
129
|
-
argdict["--
|
126
|
+
argdict["--tilt-angles"] = "40,40"
|
127
|
+
argdict["--wedge-axes"] = "2,0"
|
130
128
|
argdict["--whiten"] = True
|
131
129
|
|
132
130
|
if call_peaks:
|
@@ -142,13 +140,18 @@ class TestMatchTemplate:
|
|
142
140
|
@pytest.mark.parametrize("call_peaks", (False, True))
|
143
141
|
@pytest.mark.parametrize("use_template_mask", (False, True))
|
144
142
|
@pytest.mark.parametrize("test_filter", (False, True))
|
143
|
+
@pytest.mark.parametrize("test_rejection_sampling", (False, True))
|
145
144
|
def test_match_template(
|
146
145
|
self,
|
147
146
|
backend: bool,
|
148
147
|
call_peaks: bool,
|
149
148
|
use_template_mask: bool,
|
150
149
|
test_filter: bool,
|
150
|
+
test_rejection_sampling: bool,
|
151
151
|
):
|
152
|
+
if backend == "jax" and (call_peaks or test_rejection_sampling):
|
153
|
+
return None
|
154
|
+
|
152
155
|
self.run_matching(
|
153
156
|
use_template_mask=use_template_mask,
|
154
157
|
use_target_mask=True,
|
@@ -159,6 +162,7 @@ class TestMatchTemplate:
|
|
159
162
|
target_path=self.target_path,
|
160
163
|
template_mask_path=self.template_mask_path,
|
161
164
|
target_mask_path=self.target_mask_path,
|
165
|
+
test_rejection_sampling=test_rejection_sampling,
|
162
166
|
)
|
163
167
|
|
164
168
|
|
@@ -175,20 +179,20 @@ class TestPostprocessing(TestMatchTemplate):
|
|
175
179
|
"target_path": cls.target_path,
|
176
180
|
"template_mask_path": cls.template_mask_path,
|
177
181
|
"target_mask_path": cls.target_mask_path,
|
182
|
+
"test_rejection_sampling": False,
|
178
183
|
}
|
179
184
|
|
180
185
|
cls.score_pickle = cls.run_matching(
|
186
|
+
cls,
|
181
187
|
call_peaks=False,
|
182
188
|
**matching_kwargs,
|
183
189
|
)
|
184
|
-
cls.peak_pickle = cls.run_matching(call_peaks=True, **matching_kwargs)
|
185
|
-
cls.tempdir = tempfile.TemporaryDirectory().name
|
190
|
+
cls.peak_pickle = cls.run_matching(cls, call_peaks=True, **matching_kwargs)
|
186
191
|
|
187
192
|
@classmethod
|
188
193
|
def teardown_class(cls):
|
189
194
|
cls.try_delete(cls.score_pickle)
|
190
195
|
cls.try_delete(cls.peak_pickle)
|
191
|
-
cls.try_delete(cls.tempdir)
|
192
196
|
|
193
197
|
@pytest.mark.parametrize("distance_cutoff_strategy", (0, 1, 2, 3))
|
194
198
|
@pytest.mark.parametrize("score_cutoff", (None, (1,), (0, 1), (None, 1), (0, None)))
|
@@ -203,28 +207,28 @@ class TestPostprocessing(TestMatchTemplate):
|
|
203
207
|
makedirs(self.tempdir, exist_ok=True)
|
204
208
|
|
205
209
|
argdict = {
|
206
|
-
"--
|
207
|
-
"--
|
208
|
-
"--
|
209
|
-
"--
|
210
|
-
"--
|
210
|
+
"--input-file": self.score_pickle,
|
211
|
+
"--output-format": "orientations",
|
212
|
+
"--output-prefix": f"{self.tempdir}/temp",
|
213
|
+
"--peak-oversampling": peak_oversampling,
|
214
|
+
"--num-peaks": 3,
|
211
215
|
}
|
212
216
|
|
213
217
|
if score_cutoff is not None:
|
214
218
|
if len(score_cutoff) == 1:
|
215
|
-
argdict["--
|
219
|
+
argdict["--n-false-positives"] = 1
|
216
220
|
else:
|
217
221
|
min_score, max_score = score_cutoff
|
218
|
-
argdict["--
|
219
|
-
argdict["--
|
222
|
+
argdict["--min-score"] = min_score
|
223
|
+
argdict["--max-score"] = max_score
|
220
224
|
|
221
225
|
match distance_cutoff_strategy:
|
222
226
|
case 1:
|
223
|
-
argdict["--
|
227
|
+
argdict["--mask-edges"] = True
|
224
228
|
case 2:
|
225
|
-
argdict["--
|
229
|
+
argdict["--min-distance"] = 5
|
226
230
|
case 3:
|
227
|
-
argdict["--
|
231
|
+
argdict["--min-boundary-distance"] = 5
|
228
232
|
|
229
233
|
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
230
234
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
@@ -244,14 +248,15 @@ class TestPostprocessing(TestMatchTemplate):
|
|
244
248
|
input_file = self.peak_pickle
|
245
249
|
|
246
250
|
argdict = {
|
247
|
-
"--
|
248
|
-
"--
|
249
|
-
"--
|
250
|
-
"--
|
251
|
-
"--
|
251
|
+
"--input-file": input_file,
|
252
|
+
"--output-format": output_format,
|
253
|
+
"--output-prefix": f"{self.tempdir}/temp",
|
254
|
+
"--num-peaks": 3,
|
255
|
+
"--peak-caller": "PeakCallerMaximumFilter",
|
252
256
|
}
|
253
257
|
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
254
258
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
259
|
+
print(ret)
|
255
260
|
|
256
261
|
match output_format:
|
257
262
|
case "orientations":
|
@@ -264,7 +269,8 @@ class TestPostprocessing(TestMatchTemplate):
|
|
264
269
|
assert exists(f"{self.tempdir}/temp.star")
|
265
270
|
case "relion5":
|
266
271
|
assert exists(f"{self.tempdir}/temp.star")
|
267
|
-
|
272
|
+
case "pickle":
|
273
|
+
assert exists(f"{self.tempdir}/temp.pickle")
|
268
274
|
assert ret.returncode == 0
|
269
275
|
|
270
276
|
def test_postprocess_score_local_optimization(self):
|
@@ -272,12 +278,62 @@ class TestPostprocessing(TestMatchTemplate):
|
|
272
278
|
makedirs(self.tempdir, exist_ok=True)
|
273
279
|
|
274
280
|
argdict = {
|
275
|
-
"--
|
276
|
-
"--
|
277
|
-
"--
|
278
|
-
"--
|
279
|
-
"--
|
281
|
+
"--input-file": self.score_pickle,
|
282
|
+
"--output-format": "orientations",
|
283
|
+
"--output-prefix": f"{self.tempdir}/temp",
|
284
|
+
"--num-peaks": 1,
|
285
|
+
"--local-optimization": True,
|
280
286
|
}
|
281
287
|
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
282
288
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
283
289
|
assert ret.returncode == 0
|
290
|
+
|
291
|
+
|
292
|
+
class TestEstimateMemoryUsage(TestMatchTemplate):
|
293
|
+
@classmethod
|
294
|
+
def setup_class(cls):
|
295
|
+
super().setup_class()
|
296
|
+
|
297
|
+
@pytest.mark.parametrize("ncores", (1, 4, 8))
|
298
|
+
@pytest.mark.parametrize("pad_edges", (False, True))
|
299
|
+
def test_estimation(self, ncores, pad_edges):
|
300
|
+
|
301
|
+
argdict = {
|
302
|
+
"-m": self.target_path,
|
303
|
+
"-i": self.template_path,
|
304
|
+
"--ncores": ncores,
|
305
|
+
"--pad-edges": pad_edges,
|
306
|
+
"--score": "FLCSphericalMask",
|
307
|
+
}
|
308
|
+
|
309
|
+
cmd = argdict_to_command(argdict, executable="estimate_memory_usage.py")
|
310
|
+
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
311
|
+
assert ret.returncode == 0
|
312
|
+
|
313
|
+
|
314
|
+
class TestPreprocess(TestMatchTemplate):
|
315
|
+
@classmethod
|
316
|
+
def setup_class(cls):
|
317
|
+
super().setup_class()
|
318
|
+
|
319
|
+
@pytest.mark.parametrize("backend", available_backends)
|
320
|
+
@pytest.mark.parametrize("align_axis", (False, True))
|
321
|
+
@pytest.mark.parametrize("invert_contrast", (False, True))
|
322
|
+
def test_estimation(self, backend, align_axis, invert_contrast):
|
323
|
+
|
324
|
+
argdict = {
|
325
|
+
"-m": self.target_path,
|
326
|
+
"--backend": backend,
|
327
|
+
"--lowpass": 40,
|
328
|
+
"--sampling-rate": 5,
|
329
|
+
"-o": f"{self.tempdir}/out.mrc",
|
330
|
+
}
|
331
|
+
if align_axis:
|
332
|
+
argdict["--align-axis"] = 2
|
333
|
+
|
334
|
+
if invert_contrast:
|
335
|
+
argdict["--invert-contrast"] = True
|
336
|
+
|
337
|
+
cmd = argdict_to_command(argdict, executable="preprocess.py")
|
338
|
+
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
339
|
+
assert ret.returncode == 0
|
tests/test_matching_data.py
CHANGED
@@ -7,7 +7,7 @@ from tme.backends import backend as be
|
|
7
7
|
from tme.matching_data import MatchingData
|
8
8
|
|
9
9
|
|
10
|
-
class
|
10
|
+
class TestMatchingData:
|
11
11
|
def setup_method(self):
|
12
12
|
target = np.zeros((50, 50, 50))
|
13
13
|
target[20:30, 30:40, 12:17] = 1
|
@@ -87,9 +87,9 @@ class TestDensity:
|
|
87
87
|
matching_data.target_mask = self.target
|
88
88
|
matching_data.template_mask = self.template
|
89
89
|
|
90
|
-
ret = matching_data.subset_by_slice()
|
90
|
+
ret, offset = matching_data.subset_by_slice()
|
91
91
|
|
92
|
-
assert
|
92
|
+
assert isinstance(ret, type(matching_data))
|
93
93
|
assert np.allclose(ret.target, matching_data.target)
|
94
94
|
assert np.allclose(ret.template, matching_data.template)
|
95
95
|
assert np.allclose(ret.target_mask, matching_data.target_mask)
|
@@ -107,10 +107,10 @@ class TestDensity:
|
|
107
107
|
template_slice = MatchingData._shape_to_slice(
|
108
108
|
shape=np.divide(self.template.shape, 2).astype(int)
|
109
109
|
)
|
110
|
-
ret = matching_data.subset_by_slice(
|
110
|
+
ret, offset = matching_data.subset_by_slice(
|
111
111
|
target_slice=target_slice, template_slice=template_slice
|
112
112
|
)
|
113
|
-
assert
|
113
|
+
assert isinstance(ret, type(matching_data))
|
114
114
|
|
115
115
|
assert np.allclose(
|
116
116
|
ret.target.shape, np.divide(self.target.shape, 2).astype(int)
|
@@ -7,7 +7,6 @@ from tme.matching_data import MatchingData
|
|
7
7
|
from tme.memory import MATCHING_MEMORY_REGISTRY
|
8
8
|
from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
|
9
9
|
from tme.matching_exhaustive import (
|
10
|
-
scan,
|
11
10
|
scan_subsets,
|
12
11
|
MATCHING_EXHAUSTIVE_REGISTER,
|
13
12
|
register_matching_exhaustive,
|
@@ -36,7 +35,7 @@ class TestMatchExhaustive:
|
|
36
35
|
self.coordinates_weights = None
|
37
36
|
self.rotations = None
|
38
37
|
|
39
|
-
@pytest.mark.parametrize("evaluate_peak", (
|
38
|
+
@pytest.mark.parametrize("evaluate_peak", (True,))
|
40
39
|
@pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
|
41
40
|
@pytest.mark.parametrize("job_schedule", ((2, 1),))
|
42
41
|
@pytest.mark.parametrize("pad_edge", (False, True))
|
@@ -24,6 +24,7 @@ coordinate_to_coordinate = [
|
|
24
24
|
for k, v in MATCHING_OPTIMIZATION_REGISTER.items()
|
25
25
|
if issubclass(v, _MatchCoordinatesToCoordinates)
|
26
26
|
]
|
27
|
+
np.random.seed(42)
|
27
28
|
|
28
29
|
|
29
30
|
class TestMatchDensityToDensity:
|
@@ -52,9 +53,7 @@ class TestMatchDensityToDensity:
|
|
52
53
|
|
53
54
|
@pytest.mark.parametrize("method", density_to_density)
|
54
55
|
def test_call(self, method):
|
55
|
-
|
56
|
-
score = instance()
|
57
|
-
assert isinstance(score, float)
|
56
|
+
self.test_initialization(method=method, notest=True)()
|
58
57
|
|
59
58
|
|
60
59
|
class TestMatchDensityToCoordinates:
|
@@ -97,9 +96,7 @@ class TestMatchDensityToCoordinates:
|
|
97
96
|
|
98
97
|
@pytest.mark.parametrize("method", coordinate_to_density)
|
99
98
|
def test_call(self, method):
|
100
|
-
|
101
|
-
score = instance()
|
102
|
-
assert isinstance(score, float)
|
99
|
+
self.test_initialization(method=method, notest=True)()
|
103
100
|
|
104
101
|
|
105
102
|
class TestMatchCoordinateToCoordinates:
|
@@ -135,9 +132,7 @@ class TestMatchCoordinateToCoordinates:
|
|
135
132
|
|
136
133
|
@pytest.mark.parametrize("method", coordinate_to_coordinate)
|
137
134
|
def test_call(self, method):
|
138
|
-
|
139
|
-
score = instance()
|
140
|
-
assert isinstance(score, float)
|
135
|
+
self.test_initialization(method=method, notest=True)()
|
141
136
|
|
142
137
|
|
143
138
|
class TestOptimizeMatch:
|