pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/match_template.py +163 -201
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +48 -39
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +3 -4
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +14 -14
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/RECORD +54 -50
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +1 -0
- pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/match_template.py +163 -201
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +48 -39
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +3 -4
- scripts/pytme_runner.py +769 -0
- scripts/refine_matches.py +0 -1
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +1 -4
- tme/analyzer/aggregation.py +15 -6
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +39 -113
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +16 -15
- tme/backends/cupy_backend.py +9 -13
- tme/backends/jax_backend.py +19 -16
- tme/backends/npfftw_backend.py +27 -25
- tme/backends/pytorch_backend.py +4 -0
- tme/density.py +5 -4
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +117 -67
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +88 -105
- tme/filters/whitening.py +1 -6
- tme/matching_data.py +24 -36
- tme/matching_exhaustive.py +14 -11
- tme/matching_scores.py +21 -12
- tme/matching_utils.py +13 -6
- tme/orientations.py +13 -3
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
scripts/refine_matches.py
CHANGED
@@ -4,13 +4,15 @@ from typing import Tuple
|
|
4
4
|
|
5
5
|
from tme.backends import backend as be
|
6
6
|
from tme.filters._utils import compute_fourier_shape
|
7
|
-
from tme.filters import
|
7
|
+
from tme.filters import BandPassReconstructed, LinearWhiteningFilter
|
8
|
+
from tme.filters.bandpass import gaussian_bandpass, discrete_bandpass
|
9
|
+
from tme.filters._utils import fftfreqn
|
8
10
|
|
9
11
|
|
10
12
|
class TestBandPassFilter:
|
11
13
|
@pytest.fixture
|
12
14
|
def band_pass_filter(self):
|
13
|
-
return
|
15
|
+
return BandPassReconstructed()
|
14
16
|
|
15
17
|
@pytest.mark.parametrize(
|
16
18
|
"shape, lowpass, highpass, sampling_rate",
|
@@ -24,9 +26,13 @@ class TestBandPassFilter:
|
|
24
26
|
def test_discrete_bandpass(
|
25
27
|
self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
|
26
28
|
):
|
27
|
-
|
28
|
-
shape,
|
29
|
+
grid = fftfreqn(
|
30
|
+
shape=shape,
|
31
|
+
sampling_rate=0.5,
|
32
|
+
shape_is_real_fourier=False,
|
33
|
+
compute_euclidean_norm=True,
|
29
34
|
)
|
35
|
+
result = discrete_bandpass(grid, lowpass, highpass, sampling_rate)
|
30
36
|
assert isinstance(result, type(be.ones((1,))))
|
31
37
|
assert result.shape == shape
|
32
38
|
assert np.all((result >= 0) & (result <= 1))
|
@@ -43,9 +49,13 @@ class TestBandPassFilter:
|
|
43
49
|
def test_gaussian_bandpass(
|
44
50
|
self, shape: Tuple[int], lowpass: float, highpass: float, sampling_rate: float
|
45
51
|
):
|
46
|
-
|
47
|
-
shape,
|
52
|
+
grid = fftfreqn(
|
53
|
+
shape=shape,
|
54
|
+
sampling_rate=0.5,
|
55
|
+
shape_is_real_fourier=False,
|
56
|
+
compute_euclidean_norm=True,
|
48
57
|
)
|
58
|
+
result = gaussian_bandpass(grid, lowpass, highpass, sampling_rate)
|
49
59
|
assert isinstance(result, type(be.ones((1,))))
|
50
60
|
assert result.shape == shape
|
51
61
|
assert np.all((result >= 0) & (result <= 1))
|
@@ -55,7 +65,7 @@ class TestBandPassFilter:
|
|
55
65
|
@pytest.mark.parametrize("shape_is_real_fourier", [True, False])
|
56
66
|
def test_call_method(
|
57
67
|
self,
|
58
|
-
band_pass_filter:
|
68
|
+
band_pass_filter: BandPassReconstructed,
|
59
69
|
use_gaussian: bool,
|
60
70
|
return_real_fourier: bool,
|
61
71
|
shape_is_real_fourier: bool,
|
@@ -72,17 +82,16 @@ class TestBandPassFilter:
|
|
72
82
|
assert isinstance(result["data"], type(be.ones((1,))))
|
73
83
|
assert result["is_multiplicative_filter"] is True
|
74
84
|
|
75
|
-
def test_default_values(self, band_pass_filter:
|
85
|
+
def test_default_values(self, band_pass_filter: BandPassReconstructed):
|
76
86
|
assert band_pass_filter.lowpass is None
|
77
87
|
assert band_pass_filter.highpass is None
|
78
88
|
assert band_pass_filter.sampling_rate == 1
|
79
89
|
assert band_pass_filter.use_gaussian is True
|
80
90
|
assert band_pass_filter.return_real_fourier is False
|
81
|
-
assert band_pass_filter.shape_is_real_fourier is False
|
82
91
|
|
83
92
|
@pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
|
84
93
|
def test_return_real_fourier(self, shape: Tuple[int]):
|
85
|
-
bpf =
|
94
|
+
bpf = BandPassReconstructed(return_real_fourier=True)
|
86
95
|
result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
|
87
96
|
expected_shape = tuple(compute_fourier_shape(shape, False))
|
88
97
|
assert result["data"].shape == expected_shape
|
tests/test_analyzer.py
CHANGED
@@ -97,125 +97,125 @@ class TestPeakCallers:
|
|
97
97
|
assert [len(res) == 2 for res in result]
|
98
98
|
|
99
99
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
100
|
+
class TestRecursiveMasking:
|
101
|
+
def setup_method(self):
|
102
|
+
self.num_peaks = 100
|
103
|
+
self.min_distance = 5
|
104
|
+
self.data = np.random.rand(100, 100, 100)
|
105
|
+
self.rotation_matrix = np.eye(3)
|
106
|
+
self.mask = np.random.rand(20, 20, 20)
|
107
|
+
self.rotation_space = np.zeros_like(self.data)
|
108
|
+
self.rotation_mapping = {0: (0, 0, 0)}
|
109
|
+
|
110
|
+
@pytest.mark.parametrize("num_peaks", (1, 100))
|
111
|
+
@pytest.mark.parametrize("compute_rotation", (True, False))
|
112
|
+
@pytest.mark.parametrize("minimum_score", (None, 0.5))
|
113
|
+
def test__call__(self, num_peaks, compute_rotation, minimum_score):
|
114
|
+
peak_caller = PeakCallerRecursiveMasking(
|
115
|
+
shape=self.data.shape,
|
116
|
+
num_peaks=num_peaks,
|
117
|
+
min_distance=self.min_distance,
|
118
|
+
min_score=minimum_score,
|
119
|
+
)
|
120
|
+
rotation_space, rotation_mapping = None, None
|
121
|
+
if compute_rotation:
|
122
|
+
rotation_space = self.rotation_space
|
123
|
+
rotation_mapping = self.rotation_mapping
|
124
|
+
|
125
|
+
state = peak_caller.init_state()
|
126
|
+
state = peak_caller(
|
127
|
+
state,
|
128
|
+
self.data.copy(),
|
129
|
+
rotation_matrix=self.rotation_matrix,
|
130
|
+
mask=self.mask,
|
131
|
+
rotation_space=rotation_space,
|
132
|
+
rotation_mapping=rotation_mapping,
|
133
|
+
)
|
134
|
+
|
135
|
+
if minimum_score is None:
|
136
|
+
assert len(state[0] <= num_peaks)
|
137
|
+
else:
|
138
|
+
peaks = state[0].astype(int)
|
139
|
+
assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
|
140
|
+
|
141
|
+
|
142
|
+
class TestMaxScoreOverRotations:
|
143
|
+
def setup_method(self):
|
144
|
+
self.num_peaks = 100
|
145
|
+
self.min_distance = 5
|
146
|
+
self.data = np.random.rand(100, 100, 100)
|
147
|
+
self.rotation_matrix = np.eye(3)
|
148
|
+
|
149
|
+
def test_initialization(self):
|
150
|
+
_ = MaxScoreOverRotations(
|
151
|
+
shape=self.data.shape,
|
152
|
+
translation_offset=np.zeros(self.data.ndim, dtype=int),
|
153
|
+
)
|
154
|
+
|
155
|
+
@pytest.mark.parametrize("use_memmap", [False, True])
|
156
|
+
def test__iter__(self, use_memmap: bool):
|
157
|
+
score_analyzer = MaxScoreOverRotations(
|
158
|
+
shape=self.data.shape,
|
159
|
+
use_memmap=use_memmap,
|
160
|
+
)
|
161
|
+
state = score_analyzer.init_state()
|
162
|
+
state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
|
163
|
+
res = score_analyzer.result(state)
|
164
|
+
assert np.allclose(res[0].shape, self.data.shape)
|
165
|
+
assert res[0].dtype == be._float_dtype
|
166
|
+
assert res[1].size == self.data.ndim
|
167
|
+
assert np.allclose(res[2].shape, self.data.shape)
|
168
|
+
assert len(res) == 4
|
169
|
+
|
170
|
+
@pytest.mark.parametrize("use_memmap", [False, True])
|
171
|
+
@pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
|
172
|
+
def test__call__(self, use_memmap: bool, score_threshold: float):
|
173
|
+
score_analyzer = MaxScoreOverRotations(
|
174
|
+
shape=self.data.shape,
|
175
|
+
score_threshold=score_threshold,
|
176
|
+
translation_offset=np.zeros(self.data.ndim, dtype=int),
|
177
|
+
use_memmap=use_memmap,
|
178
|
+
)
|
179
|
+
state = score_analyzer.init_state()
|
180
|
+
state = score_analyzer(state, self.data, rotation_matrix=self.rotation_matrix)
|
181
|
+
|
182
|
+
data2 = self.data * 2
|
183
|
+
score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
|
184
|
+
scores, translation_offset, rotations, mapping = score_analyzer.result(state)
|
185
|
+
|
186
|
+
assert np.all(scores >= score_threshold)
|
187
|
+
max_scores = np.maximum(self.data, data2)
|
188
|
+
max_scores = np.maximum(max_scores, score_threshold)
|
189
|
+
assert np.allclose(scores, max_scores)
|
190
|
+
|
191
|
+
@pytest.mark.parametrize("use_memmap", [False, True])
|
192
|
+
@pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
|
193
|
+
def test_merge(self, use_memmap: bool, score_threshold: float):
|
194
|
+
score_analyzer = MaxScoreOverRotations(
|
195
|
+
shape=self.data.shape,
|
196
|
+
score_threshold=score_threshold,
|
197
|
+
translation_offset=np.zeros(self.data.ndim, dtype=int),
|
198
|
+
use_memmap=use_memmap,
|
199
|
+
)
|
200
|
+
state1 = score_analyzer.init_state()
|
201
|
+
state1, score_analyzer(state1, self.data, rotation_matrix=self.rotation_matrix)
|
202
|
+
|
203
|
+
data2 = self.data * 2
|
204
|
+
score_analyzer2 = MaxScoreOverRotations(
|
205
|
+
shape=self.data.shape,
|
206
|
+
score_threshold=score_threshold,
|
207
|
+
translation_offset=np.zeros(self.data.ndim, dtype=int),
|
208
|
+
use_memmap=use_memmap,
|
209
|
+
)
|
210
|
+
state2 = score_analyzer2.init_state()
|
211
|
+
state2 = score_analyzer2(state2, data2, rotation_matrix=self.rotation_matrix)
|
212
|
+
states = [score_analyzer.result(state1), score_analyzer2.result(state2)]
|
213
|
+
|
214
|
+
ret = MaxScoreOverRotations.merge(
|
215
|
+
results=states, use_memmap=use_memmap, score_threshold=score_threshold
|
216
|
+
)
|
217
|
+
scores, translation, rotations, mapping = ret
|
218
|
+
assert np.all(scores >= score_threshold)
|
219
|
+
max_scores = np.maximum(self.data, data2)
|
220
|
+
max_scores = np.maximum(max_scores, score_threshold)
|
221
|
+
assert np.allclose(scores, max_scores)
|
tests/test_backends.py
CHANGED
tests/test_matching_cli.py
CHANGED
@@ -107,15 +107,15 @@ class TestMatchTemplate:
|
|
107
107
|
"-n": 1,
|
108
108
|
"-a": 60,
|
109
109
|
"-o": output_path,
|
110
|
-
"--
|
110
|
+
"--pad-edges": False,
|
111
111
|
"--backend": backend,
|
112
112
|
}
|
113
113
|
|
114
114
|
if use_template_mask:
|
115
|
-
argdict["--
|
115
|
+
argdict["--template-mask"] = template_mask_path
|
116
116
|
|
117
117
|
if use_target_mask:
|
118
|
-
argdict["--
|
118
|
+
argdict["--target-mask"] = target_mask_path
|
119
119
|
|
120
120
|
if test_rejection_sampling:
|
121
121
|
argdict["--orientations"] = self.orientations_path
|
@@ -123,8 +123,8 @@ class TestMatchTemplate:
|
|
123
123
|
if test_filter:
|
124
124
|
argdict["--lowpass"] = 30
|
125
125
|
argdict["--defocus"] = 3000
|
126
|
-
argdict["--
|
127
|
-
argdict["--
|
126
|
+
argdict["--tilt-angles"] = "40,40"
|
127
|
+
argdict["--wedge-axes"] = "2,0"
|
128
128
|
argdict["--whiten"] = True
|
129
129
|
|
130
130
|
if call_peaks:
|
@@ -207,28 +207,28 @@ class TestPostprocessing(TestMatchTemplate):
|
|
207
207
|
makedirs(self.tempdir, exist_ok=True)
|
208
208
|
|
209
209
|
argdict = {
|
210
|
-
"--
|
211
|
-
"--
|
212
|
-
"--
|
213
|
-
"--
|
214
|
-
"--
|
210
|
+
"--input-file": self.score_pickle,
|
211
|
+
"--output-format": "orientations",
|
212
|
+
"--output-prefix": f"{self.tempdir}/temp",
|
213
|
+
"--peak-oversampling": peak_oversampling,
|
214
|
+
"--num-peaks": 3,
|
215
215
|
}
|
216
216
|
|
217
217
|
if score_cutoff is not None:
|
218
218
|
if len(score_cutoff) == 1:
|
219
|
-
argdict["--
|
219
|
+
argdict["--n-false-positives"] = 1
|
220
220
|
else:
|
221
221
|
min_score, max_score = score_cutoff
|
222
|
-
argdict["--
|
223
|
-
argdict["--
|
222
|
+
argdict["--min-score"] = min_score
|
223
|
+
argdict["--max-score"] = max_score
|
224
224
|
|
225
225
|
match distance_cutoff_strategy:
|
226
226
|
case 1:
|
227
|
-
argdict["--
|
227
|
+
argdict["--mask-edges"] = True
|
228
228
|
case 2:
|
229
|
-
argdict["--
|
229
|
+
argdict["--min-distance"] = 5
|
230
230
|
case 3:
|
231
|
-
argdict["--
|
231
|
+
argdict["--min-boundary-distance"] = 5
|
232
232
|
|
233
233
|
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
234
234
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
@@ -248,11 +248,11 @@ class TestPostprocessing(TestMatchTemplate):
|
|
248
248
|
input_file = self.peak_pickle
|
249
249
|
|
250
250
|
argdict = {
|
251
|
-
"--
|
252
|
-
"--
|
253
|
-
"--
|
254
|
-
"--
|
255
|
-
"--
|
251
|
+
"--input-file": input_file,
|
252
|
+
"--output-format": output_format,
|
253
|
+
"--output-prefix": f"{self.tempdir}/temp",
|
254
|
+
"--num-peaks": 3,
|
255
|
+
"--peak-caller": "PeakCallerMaximumFilter",
|
256
256
|
}
|
257
257
|
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
258
258
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
@@ -278,11 +278,11 @@ class TestPostprocessing(TestMatchTemplate):
|
|
278
278
|
makedirs(self.tempdir, exist_ok=True)
|
279
279
|
|
280
280
|
argdict = {
|
281
|
-
"--
|
282
|
-
"--
|
283
|
-
"--
|
284
|
-
"--
|
285
|
-
"--
|
281
|
+
"--input-file": self.score_pickle,
|
282
|
+
"--output-format": "orientations",
|
283
|
+
"--output-prefix": f"{self.tempdir}/temp",
|
284
|
+
"--num-peaks": 1,
|
285
|
+
"--local-optimization": True,
|
286
286
|
}
|
287
287
|
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
288
288
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
@@ -302,7 +302,7 @@ class TestEstimateMemoryUsage(TestMatchTemplate):
|
|
302
302
|
"-m": self.target_path,
|
303
303
|
"-i": self.template_path,
|
304
304
|
"--ncores": ncores,
|
305
|
-
"--
|
305
|
+
"--pad-edges": pad_edges,
|
306
306
|
"--score": "FLCSphericalMask",
|
307
307
|
}
|
308
308
|
|
@@ -325,14 +325,14 @@ class TestPreprocess(TestMatchTemplate):
|
|
325
325
|
"-m": self.target_path,
|
326
326
|
"--backend": backend,
|
327
327
|
"--lowpass": 40,
|
328
|
-
"--
|
328
|
+
"--sampling-rate": 5,
|
329
329
|
"-o": f"{self.tempdir}/out.mrc",
|
330
330
|
}
|
331
331
|
if align_axis:
|
332
|
-
argdict["--
|
332
|
+
argdict["--align-axis"] = 2
|
333
333
|
|
334
334
|
if invert_contrast:
|
335
|
-
argdict["--
|
335
|
+
argdict["--invert-contrast"] = True
|
336
336
|
|
337
337
|
cmd = argdict_to_command(argdict, executable="preprocess.py")
|
338
338
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
tests/test_matching_data.py
CHANGED
@@ -7,7 +7,7 @@ from tme.backends import backend as be
|
|
7
7
|
from tme.matching_data import MatchingData
|
8
8
|
|
9
9
|
|
10
|
-
class
|
10
|
+
class TestMatchingData:
|
11
11
|
def setup_method(self):
|
12
12
|
target = np.zeros((50, 50, 50))
|
13
13
|
target[20:30, 30:40, 12:17] = 1
|
@@ -87,9 +87,9 @@ class TestDensity:
|
|
87
87
|
matching_data.target_mask = self.target
|
88
88
|
matching_data.template_mask = self.template
|
89
89
|
|
90
|
-
ret = matching_data.subset_by_slice()
|
90
|
+
ret, offset = matching_data.subset_by_slice()
|
91
91
|
|
92
|
-
assert
|
92
|
+
assert isinstance(ret, type(matching_data))
|
93
93
|
assert np.allclose(ret.target, matching_data.target)
|
94
94
|
assert np.allclose(ret.template, matching_data.template)
|
95
95
|
assert np.allclose(ret.target_mask, matching_data.target_mask)
|
@@ -107,10 +107,10 @@ class TestDensity:
|
|
107
107
|
template_slice = MatchingData._shape_to_slice(
|
108
108
|
shape=np.divide(self.template.shape, 2).astype(int)
|
109
109
|
)
|
110
|
-
ret = matching_data.subset_by_slice(
|
110
|
+
ret, offset = matching_data.subset_by_slice(
|
111
111
|
target_slice=target_slice, template_slice=template_slice
|
112
112
|
)
|
113
|
-
assert
|
113
|
+
assert isinstance(ret, type(matching_data))
|
114
114
|
|
115
115
|
assert np.allclose(
|
116
116
|
ret.target.shape, np.divide(self.target.shape, 2).astype(int)
|
tests/test_matching_utils.py
CHANGED
@@ -139,7 +139,7 @@ class TestMatchingUtils:
|
|
139
139
|
expected_size = np.subtract(
|
140
140
|
self.density.shape, self.structure_density.shape
|
141
141
|
)
|
142
|
-
expected_size +=
|
142
|
+
expected_size += 1
|
143
143
|
assert np.allclose(ret.shape, expected_size)
|
144
144
|
|
145
145
|
def test_apply_convolution_mode_error(self):
|
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.3.b0"
|
1
|
+
__version__ = "0.3.b0.post1"
|
tme/analyzer/__init__.py
CHANGED
tme/analyzer/_utils.py
CHANGED
@@ -47,10 +47,7 @@ def _convmode_to_shape(
|
|
47
47
|
if convolution_mode == "same":
|
48
48
|
output_shape = targetshape
|
49
49
|
elif convolution_mode == "valid":
|
50
|
-
output_shape = be.
|
51
|
-
be.subtract(targetshape, templateshape),
|
52
|
-
be.mod(templateshape, 2),
|
53
|
-
)
|
50
|
+
output_shape = be.subtract(targetshape, templateshape) + 1
|
54
51
|
return be.to_backend_array(output_shape)
|
55
52
|
|
56
53
|
|
tme/analyzer/aggregation.py
CHANGED
@@ -57,22 +57,23 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
57
57
|
The following achieves the minimal definition of a :py:class:`MaxScoreOverRotations`
|
58
58
|
instance
|
59
59
|
|
60
|
+
>>> import numpy as np
|
60
61
|
>>> from tme.analyzer import MaxScoreOverRotations
|
61
|
-
>>> analyzer = MaxScoreOverRotations(shape
|
62
|
+
>>> analyzer = MaxScoreOverRotations(shape=(50, 50))
|
62
63
|
|
63
64
|
The following simulates a template matching run by creating random data for a range
|
64
65
|
of rotations and sending it to ``analyzer`` via its __call__ method
|
65
66
|
|
66
|
-
|
67
|
+
>>> state = analyzer.init_state()
|
67
68
|
>>> for rotation_number in range(10):
|
68
69
|
>>> scores = np.random.rand(50,50)
|
69
70
|
>>> rotation = np.random.rand(scores.ndim, scores.ndim)
|
70
|
-
>>> state
|
71
|
+
>>> state = analyzer(state, scores=scores, rotation_matrix=rotation)
|
71
72
|
|
72
73
|
The aggregated scores can be extracted by invoking the result method of
|
73
74
|
``analyzer``
|
74
75
|
|
75
|
-
>>> results = analyzer.result()
|
76
|
+
>>> results = analyzer.result(state)
|
76
77
|
|
77
78
|
The ``results`` tuple contains (1) the maximum scores for each translation,
|
78
79
|
(2) an offset which is relevant when merging results from split template matching
|
@@ -100,6 +101,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
100
101
|
shm_handler: object = None,
|
101
102
|
use_memmap: bool = False,
|
102
103
|
inversion_mapping: bool = False,
|
104
|
+
jax_mode: bool = False,
|
103
105
|
**kwargs,
|
104
106
|
):
|
105
107
|
self._use_memmap = use_memmap
|
@@ -107,6 +109,10 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
107
109
|
self._shape = tuple(int(x) for x in shape)
|
108
110
|
self._inversion_mapping = inversion_mapping
|
109
111
|
|
112
|
+
self._jax_mode = jax_mode
|
113
|
+
if self._jax_mode:
|
114
|
+
self._inversion_mapping = False
|
115
|
+
|
110
116
|
if offset is None:
|
111
117
|
offset = be.zeros(len(self._shape), be._int_dtype)
|
112
118
|
self._offset = be.astype(be.to_backend_array(offset), int)
|
@@ -138,6 +144,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
138
144
|
state: Tuple,
|
139
145
|
scores: BackendArray,
|
140
146
|
rotation_matrix: BackendArray,
|
147
|
+
**kwargs,
|
141
148
|
) -> Tuple:
|
142
149
|
"""
|
143
150
|
Update the parameter store.
|
@@ -167,6 +174,8 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
167
174
|
rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
|
168
175
|
if self._inversion_mapping:
|
169
176
|
rotation_mapping[rotation_index] = rotation_matrix
|
177
|
+
elif self._jax_mode:
|
178
|
+
rotation_index = kwargs.get("rotation_index", 0)
|
170
179
|
else:
|
171
180
|
rotation = be.tobytes(rotation_matrix)
|
172
181
|
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
@@ -738,5 +747,5 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
738
747
|
cls._invert_rmap(master_rotation_mapping),
|
739
748
|
)
|
740
749
|
|
741
|
-
def
|
742
|
-
return
|
750
|
+
def result(self, state: Tuple, **kwargs) -> Tuple:
|
751
|
+
return state
|