pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
- pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
- scripts/estimate_ram_usage.py +97 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +30 -41
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +96 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +158 -390
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_orientations.py +0 -12
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +64 -18
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +79 -40
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +58 -5
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +24 -13
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
tests/data/.DS_Store
ADDED
Binary file
|
Binary file
|
Binary file
|
tests/data/Raw/.DS_Store
ADDED
Binary file
|
Binary file
|
@@ -51,6 +51,24 @@ class TestPreprocessUtils:
|
|
51
51
|
assert fgrid.shape == tuple(tilt_shape)
|
52
52
|
assert fgrid.max() <= np.sqrt(1 / sampling_rate * len(shape))
|
53
53
|
|
54
|
+
@pytest.mark.parametrize("shape", ((15, 15, 15), (31, 31, 31), (64, 64, 64)))
|
55
|
+
@pytest.mark.parametrize("sampling_rate", (0.5, 1, 2))
|
56
|
+
@pytest.mark.parametrize("angle", (-5, 0, 5))
|
57
|
+
def test_freqgrid_comparison(self, shape, sampling_rate, angle):
|
58
|
+
grid = frequency_grid_at_angle(
|
59
|
+
shape=shape,
|
60
|
+
angle=angle,
|
61
|
+
sampling_rate=sampling_rate,
|
62
|
+
opening_axis=2,
|
63
|
+
tilt_axis=0,
|
64
|
+
)
|
65
|
+
grid2 = fftfreqn(
|
66
|
+
shape=shape[1:], sampling_rate=sampling_rate, compute_euclidean_norm=True
|
67
|
+
)
|
68
|
+
|
69
|
+
# These should be equal for cubical input shapes
|
70
|
+
assert np.allclose(grid, grid2)
|
71
|
+
|
54
72
|
@pytest.mark.parametrize("n", [10, 100, 1000])
|
55
73
|
@pytest.mark.parametrize("sampling_rate", range(1, 4))
|
56
74
|
def test_fftfreqn(self, n, sampling_rate):
|
tests/test_analyzer.py
CHANGED
@@ -165,7 +165,6 @@ class TestMaxScoreOverRotations:
|
|
165
165
|
assert res[0].dtype == be._float_dtype
|
166
166
|
assert res[1].size == self.data.ndim
|
167
167
|
assert np.allclose(res[2].shape, self.data.shape)
|
168
|
-
assert len(res) == 4
|
169
168
|
|
170
169
|
@pytest.mark.parametrize("use_memmap", [False, True])
|
171
170
|
@pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
|
@@ -181,7 +180,7 @@ class TestMaxScoreOverRotations:
|
|
181
180
|
|
182
181
|
data2 = self.data * 2
|
183
182
|
score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
|
184
|
-
scores,
|
183
|
+
scores, offset, rotations, mapping, *_ = score_analyzer.result(state)
|
185
184
|
|
186
185
|
assert np.all(scores >= score_threshold)
|
187
186
|
max_scores = np.maximum(self.data, data2)
|
@@ -214,7 +213,7 @@ class TestMaxScoreOverRotations:
|
|
214
213
|
ret = MaxScoreOverRotations.merge(
|
215
214
|
results=states, use_memmap=use_memmap, score_threshold=score_threshold
|
216
215
|
)
|
217
|
-
scores, translation, rotations, mapping = ret
|
216
|
+
scores, translation, rotations, mapping, *_ = ret
|
218
217
|
assert np.all(scores >= score_threshold)
|
219
218
|
max_scores = np.maximum(self.data, data2)
|
220
219
|
max_scores = np.maximum(max_scores, score_threshold)
|
tests/test_backends.py
CHANGED
@@ -292,16 +292,10 @@ class TestBackends:
|
|
292
292
|
|
293
293
|
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
294
294
|
@pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
|
295
|
-
def
|
295
|
+
def test_fft(self, backend, fast_shape):
|
296
296
|
_, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
|
297
297
|
fast_shape, (1 for _ in range(len(fast_shape)))
|
298
298
|
)
|
299
|
-
rfftn, irfftn = backend.build_fft(
|
300
|
-
fwd_shape=fast_shape,
|
301
|
-
inv_shape=fast_ft_shape,
|
302
|
-
real_dtype=backend._float_dtype,
|
303
|
-
cmpl_dtype=backend._complex_dtype,
|
304
|
-
)
|
305
299
|
arr = np.random.rand(*fast_shape)
|
306
300
|
out = np.zeros(fast_ft_shape)
|
307
301
|
|
@@ -310,11 +304,11 @@ class TestBackends:
|
|
310
304
|
backend.to_backend_array(out), backend._complex_dtype
|
311
305
|
)
|
312
306
|
|
313
|
-
rfftn(
|
307
|
+
backend.rfftn(
|
314
308
|
backend.astype(backend.to_backend_array(arr), backend._float_dtype),
|
315
309
|
complex_arr,
|
316
310
|
)
|
317
|
-
irfftn(complex_arr, real_arr)
|
311
|
+
backend.irfftn(complex_arr, real_arr)
|
318
312
|
assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)
|
319
313
|
|
320
314
|
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
tests/test_density.py
CHANGED
@@ -98,7 +98,6 @@ class TestDensity:
|
|
98
98
|
assert np.allclose(density.data, self.density.data)
|
99
99
|
assert np.allclose(density.sampling_rate, self.density.sampling_rate)
|
100
100
|
assert np.allclose(density.origin, self.density.origin)
|
101
|
-
assert density.metadata == self.density.metadata
|
102
101
|
|
103
102
|
def test_from_file_baseline(self):
|
104
103
|
self.test_to_file(gzip=False)
|
tests/test_extensions.py
CHANGED
@@ -53,7 +53,6 @@ class TestExtensions:
|
|
53
53
|
@pytest.mark.parametrize("min_distance", [0, 5, 10])
|
54
54
|
def test_find_candidate_indices(self, dimension, dtype, min_distance):
|
55
55
|
coordinates = COORDINATES[dimension].astype(dtype)
|
56
|
-
print(coordinates.shape)
|
57
56
|
|
58
57
|
min_distance = np.array([min_distance]).astype(dtype)[0]
|
59
58
|
|
tests/test_matching_utils.py
CHANGED
@@ -10,9 +10,6 @@ from tme.backends import backend as be
|
|
10
10
|
from tme.memory import MATCHING_MEMORY_REGISTRY
|
11
11
|
from tme.matching_utils import (
|
12
12
|
compute_parallelization_schedule,
|
13
|
-
elliptical_mask,
|
14
|
-
box_mask,
|
15
|
-
tube_mask,
|
16
13
|
create_mask,
|
17
14
|
scramble_phases,
|
18
15
|
apply_convolution_mode,
|
@@ -50,73 +47,26 @@ class TestMatchingUtils:
|
|
50
47
|
max_splits=256,
|
51
48
|
)
|
52
49
|
|
53
|
-
|
50
|
+
@pytest.mark.parametrize("mask_type", ["ellipse", "box", "tube", "membrane"])
|
51
|
+
def test_create_mask(self, mask_type: str):
|
54
52
|
create_mask(
|
55
|
-
mask_type=
|
53
|
+
mask_type=mask_type,
|
56
54
|
shape=self.density.shape,
|
57
55
|
radius=5,
|
58
56
|
center=np.divide(self.density.shape, 2),
|
57
|
+
height=np.max(self.density.shape) // 2,
|
58
|
+
size=np.divide(self.density.shape, 2).astype(int),
|
59
|
+
thickness=2,
|
60
|
+
separation=2,
|
61
|
+
symmetry_axis=1,
|
62
|
+
inner_radius=5,
|
63
|
+
outer_radius=10,
|
59
64
|
)
|
60
65
|
|
61
66
|
def test_create_mask_error(self):
|
62
67
|
with pytest.raises(ValueError):
|
63
68
|
create_mask(mask_type=None)
|
64
69
|
|
65
|
-
def test_elliptical_mask(self):
|
66
|
-
elliptical_mask(
|
67
|
-
shape=self.density.shape,
|
68
|
-
radius=5,
|
69
|
-
center=np.divide(self.density.shape, 2),
|
70
|
-
)
|
71
|
-
|
72
|
-
def test_box_mask(self):
|
73
|
-
box_mask(
|
74
|
-
shape=self.density.shape,
|
75
|
-
height=[5, 10, 20],
|
76
|
-
center=np.divide(self.density.shape, 2),
|
77
|
-
)
|
78
|
-
|
79
|
-
def test_tube_mask(self):
|
80
|
-
tube_mask(
|
81
|
-
shape=self.density.shape,
|
82
|
-
outer_radius=10,
|
83
|
-
inner_radius=5,
|
84
|
-
height=5,
|
85
|
-
base_center=np.divide(self.density.shape, 2),
|
86
|
-
symmetry_axis=1,
|
87
|
-
)
|
88
|
-
|
89
|
-
def test_tube_mask_error(self):
|
90
|
-
with pytest.raises(ValueError):
|
91
|
-
tube_mask(
|
92
|
-
shape=self.density.shape,
|
93
|
-
outer_radius=5,
|
94
|
-
inner_radius=10,
|
95
|
-
height=5,
|
96
|
-
base_center=np.divide(self.density.shape, 2),
|
97
|
-
symmetry_axis=1,
|
98
|
-
)
|
99
|
-
|
100
|
-
with pytest.raises(ValueError):
|
101
|
-
tube_mask(
|
102
|
-
shape=self.density.shape,
|
103
|
-
outer_radius=5,
|
104
|
-
inner_radius=10,
|
105
|
-
height=10 * np.max(self.density.shape),
|
106
|
-
base_center=np.divide(self.density.shape, 2),
|
107
|
-
symmetry_axis=1,
|
108
|
-
)
|
109
|
-
|
110
|
-
with pytest.raises(ValueError):
|
111
|
-
tube_mask(
|
112
|
-
shape=self.density.shape,
|
113
|
-
outer_radius=5,
|
114
|
-
inner_radius=10,
|
115
|
-
height=10 * np.max(self.density.shape),
|
116
|
-
base_center=np.divide(self.density.shape, 2),
|
117
|
-
symmetry_axis=len(self.density.shape) + 1,
|
118
|
-
)
|
119
|
-
|
120
70
|
def test_scramble_phases(self):
|
121
71
|
scramble_phases(arr=self.density.data, noise_proportion=0.5)
|
122
72
|
|
tests/test_orientations.py
CHANGED
@@ -95,18 +95,6 @@ class TestDensity:
|
|
95
95
|
self.orientations.rotations, orientations_new.rotations, atol=1e-3
|
96
96
|
)
|
97
97
|
|
98
|
-
@pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
|
99
|
-
@pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
|
100
|
-
def test_file_format_io(self, input_format: str, output_format: str):
|
101
|
-
_, output_file = mkstemp(suffix=f".{input_format}")
|
102
|
-
_, output_file2 = mkstemp(suffix=f".{output_format}")
|
103
|
-
|
104
|
-
self.orientations.to_file(output_file)
|
105
|
-
orientations_new = Orientations.from_file(output_file)
|
106
|
-
orientations_new.to_file(output_file2)
|
107
|
-
|
108
|
-
assert True
|
109
|
-
|
110
98
|
@pytest.mark.parametrize("drop_oob", (True, False))
|
111
99
|
@pytest.mark.parametrize("shape", (10, 40, 80))
|
112
100
|
@pytest.mark.parametrize("odd", (True, False))
|
tests/test_rotations.py
CHANGED
@@ -8,6 +8,7 @@ from tme import Density
|
|
8
8
|
from scipy.spatial.transform import Rotation
|
9
9
|
from scipy.signal import correlate
|
10
10
|
|
11
|
+
from tme.mask import elliptical_mask
|
11
12
|
from tme.rotations import (
|
12
13
|
euler_from_rotationmatrix,
|
13
14
|
euler_to_rotationmatrix,
|
@@ -16,7 +17,6 @@ from tme.rotations import (
|
|
16
17
|
get_rotation_matrices,
|
17
18
|
)
|
18
19
|
from tme.matching_utils import (
|
19
|
-
elliptical_mask,
|
20
20
|
split_shape,
|
21
21
|
compute_full_convolution_index,
|
22
22
|
)
|
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.3.
|
1
|
+
__version__ = "0.3.1"
|
tme/analyzer/_utils.py
CHANGED
@@ -93,22 +93,22 @@ def cart_to_score(
|
|
93
93
|
templateshape = be.to_backend_array(templateshape)
|
94
94
|
convolution_shape = be.to_backend_array(convolution_shape)
|
95
95
|
|
96
|
-
# Compute removed padding
|
97
96
|
output_shape = _convmode_to_shape(
|
98
97
|
convolution_mode=convolution_mode,
|
99
98
|
targetshape=targetshape,
|
100
99
|
templateshape=templateshape,
|
101
100
|
convolution_shape=convolution_shape,
|
102
101
|
)
|
103
|
-
valid_positions = be.multiply(positions >= 0, positions < output_shape)
|
104
|
-
valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
|
105
102
|
|
103
|
+
# Offset from padding the target
|
106
104
|
starts = be.astype(
|
107
105
|
be.divide(be.subtract(convolution_shape, output_shape), 2),
|
108
106
|
be._int_dtype,
|
109
107
|
)
|
110
|
-
|
111
108
|
positions = be.add(positions, starts)
|
109
|
+
|
110
|
+
valid_positions = be.multiply(positions >= 0, positions < fast_shape)
|
111
|
+
valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
|
112
112
|
if fourier_shift is not None:
|
113
113
|
fourier_shift = be.to_backend_array(fourier_shift)
|
114
114
|
positions = be.subtract(positions, fourier_shift)
|
tme/analyzer/aggregation.py
CHANGED
@@ -132,12 +132,14 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
132
132
|
- scores : BackendArray of shape `self._shape` filled with `score_threshold`.
|
133
133
|
- rotations : BackendArray of shape `self._shape` filled with -1.
|
134
134
|
- rotation_mapping : dict, empty mapping from rotation bytes to indices.
|
135
|
+
- ssum : BackendArray, accumulator for sum of squared scores.
|
135
136
|
"""
|
136
137
|
scores = be.full(
|
137
138
|
shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
|
138
139
|
)
|
139
140
|
rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
|
140
|
-
|
141
|
+
ssum = be.full((1), dtype=be._float_dtype, fill_value=0)
|
142
|
+
return scores, rotations, {}, ssum
|
141
143
|
|
142
144
|
def __call__(
|
143
145
|
self,
|
@@ -156,6 +158,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
156
158
|
- scores : BackendArray, current maximum scores.
|
157
159
|
- rotations : BackendArray, current rotation indices.
|
158
160
|
- rotation_mapping : dict, mapping from rotation bytes to indices.
|
161
|
+
- ssum : BackendArray, accumulator for sum of squared scores.
|
159
162
|
scores : BackendArray
|
160
163
|
Array of new scores to update analyzer with.
|
161
164
|
rotation_matrix : BackendArray
|
@@ -168,7 +171,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
168
171
|
# be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
|
169
172
|
# If the analyzer is not shared and each rotation is unique, we can
|
170
173
|
# use index to rotation mapping and invert prior to merging.
|
171
|
-
prev_scores, rotations, rotation_mapping = state
|
174
|
+
prev_scores, rotations, rotation_mapping, ssum = state
|
172
175
|
|
173
176
|
rotation_index = len(rotation_mapping)
|
174
177
|
rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
|
@@ -180,13 +183,14 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
180
183
|
rotation = be.tobytes(rotation_matrix)
|
181
184
|
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
182
185
|
|
186
|
+
ssum = be.add(ssum, be.ssum(scores), out=ssum)
|
183
187
|
scores, rotations = be.max_score_over_rotations(
|
184
188
|
scores=scores,
|
185
189
|
max_scores=prev_scores,
|
186
190
|
rotations=rotations,
|
187
191
|
rotation_index=rotation_index,
|
188
192
|
)
|
189
|
-
return scores, rotations, rotation_mapping
|
193
|
+
return scores, rotations, rotation_mapping, ssum
|
190
194
|
|
191
195
|
@staticmethod
|
192
196
|
def _invert_rmap(rotation_mapping: dict) -> dict:
|
@@ -224,6 +228,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
224
228
|
- scores : BackendArray, current maximum scores.
|
225
229
|
- rotations : BackendArray, current rotation indices.
|
226
230
|
- rotation_mapping : dict, mapping from rotation indices to matrices.
|
231
|
+
- ssum : BackendArray, accumulator for sum of squared scores.
|
227
232
|
targetshape : Tuple[int], optional
|
228
233
|
Shape of the target for convolution mode correction.
|
229
234
|
templateshape : Tuple[int], optional
|
@@ -240,9 +245,9 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
240
245
|
Returns
|
241
246
|
-------
|
242
247
|
tuple
|
243
|
-
Final result tuple (scores, offset, rotations, rotation_mapping).
|
248
|
+
Final result tuple (scores, offset, rotations, rotation_mapping, ssum).
|
244
249
|
"""
|
245
|
-
scores, rotations, rotation_mapping = state
|
250
|
+
scores, rotations, rotation_mapping, ssum = state
|
246
251
|
|
247
252
|
# Apply postprocessing if parameters are provided
|
248
253
|
if fourier_shift is not None:
|
@@ -269,11 +274,13 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
269
274
|
if self._inversion_mapping:
|
270
275
|
rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
|
271
276
|
|
277
|
+
n_rotations = max(len(rotation_mapping), 1)
|
272
278
|
return (
|
273
279
|
scores,
|
274
280
|
be.to_numpy_array(self._offset),
|
275
281
|
rotations,
|
276
282
|
self._invert_rmap(rotation_mapping),
|
283
|
+
be.to_numpy_array(ssum) / (scores.size * n_rotations),
|
277
284
|
)
|
278
285
|
|
279
286
|
def _harmonize_states(states: List[Tuple]):
|
@@ -287,18 +294,18 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
287
294
|
if states[i] is None:
|
288
295
|
continue
|
289
296
|
|
290
|
-
scores, offset, rotations, rotation_mapping = states[i]
|
297
|
+
scores, offset, rotations, rotation_mapping, ssum = states[i]
|
291
298
|
if out_shape is None:
|
292
299
|
out_shape = np.zeros(scores.ndim, int)
|
293
300
|
out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
|
294
301
|
|
295
302
|
new_param = {}
|
296
303
|
for key, value in rotation_mapping.items():
|
297
|
-
rotation_bytes =
|
304
|
+
rotation_bytes = np.asarray(value).tobytes()
|
298
305
|
new_param[rotation_bytes] = key
|
299
306
|
if rotation_bytes not in new_rotation_mapping:
|
300
307
|
new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
|
301
|
-
states[i] = (scores, offset, rotations, new_param)
|
308
|
+
states[i] = (scores, offset, rotations, new_param, ssum)
|
302
309
|
out_shape = tuple(int(x) for x in out_shape)
|
303
310
|
return new_rotation_mapping, out_shape, states
|
304
311
|
|
@@ -329,11 +336,10 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
329
336
|
if len(results) == 1:
|
330
337
|
ret = results[0]
|
331
338
|
if use_memmap:
|
332
|
-
scores, offset, rotations, rotation_mapping = ret
|
339
|
+
scores, offset, rotations, rotation_mapping, ssum = ret
|
333
340
|
scores = array_to_memmap(scores)
|
334
341
|
rotations = array_to_memmap(rotations)
|
335
|
-
ret = (scores, offset, rotations, rotation_mapping)
|
336
|
-
|
342
|
+
ret = (scores, offset, rotations, rotation_mapping, ssum)
|
337
343
|
return ret
|
338
344
|
|
339
345
|
# Determine output array shape and create consistent rotation map
|
@@ -368,6 +374,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
368
374
|
)
|
369
375
|
rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
|
370
376
|
|
377
|
+
total_ssum = 0
|
371
378
|
for i in range(len(results)):
|
372
379
|
if results[i] is None:
|
373
380
|
continue
|
@@ -385,7 +392,9 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
385
392
|
shape=out_shape,
|
386
393
|
dtype=rotations_dtype,
|
387
394
|
)
|
388
|
-
scores, offset, rotations, rotation_mapping = results[i]
|
395
|
+
scores, offset, rotations, rotation_mapping, ssum = results[i]
|
396
|
+
|
397
|
+
total_ssum = np.add(total_ssum, ssum)
|
389
398
|
stops = np.add(offset, scores.shape).astype(int)
|
390
399
|
indices = tuple(slice(*pos) for pos in zip(offset, stops))
|
391
400
|
|
@@ -428,6 +437,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
428
437
|
np.zeros(scores_out.ndim, dtype=int),
|
429
438
|
rotations_out,
|
430
439
|
cls._invert_rmap(master_rotation_mapping),
|
440
|
+
total_ssum / len(results),
|
431
441
|
)
|
432
442
|
|
433
443
|
|
@@ -545,13 +555,19 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
545
555
|
)
|
546
556
|
|
547
557
|
def __call__(
|
548
|
-
self,
|
558
|
+
self,
|
559
|
+
state: Tuple,
|
560
|
+
scores: BackendArray,
|
561
|
+
rotation_matrix: BackendArray,
|
562
|
+
**kwargs,
|
549
563
|
) -> Tuple:
|
550
564
|
mask = self._get_constraint(rotation_matrix)
|
551
565
|
mask = self._get_score_mask(mask=mask, scores=scores)
|
552
566
|
|
553
567
|
scores = be.multiply(scores, mask, out=scores)
|
554
|
-
return super().__call__(
|
568
|
+
return super().__call__(
|
569
|
+
state, scores=scores, rotation_matrix=rotation_matrix, **kwargs
|
570
|
+
)
|
555
571
|
|
556
572
|
def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
|
557
573
|
"""
|
@@ -636,7 +652,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
636
652
|
return scores, rotations, {}
|
637
653
|
|
638
654
|
def __call__(
|
639
|
-
self,
|
655
|
+
self,
|
656
|
+
state,
|
657
|
+
scores: BackendArray,
|
658
|
+
rotation_matrix: BackendArray,
|
659
|
+
**kwargs,
|
640
660
|
) -> Tuple:
|
641
661
|
prev_scores, rotations, rotation_mapping = state
|
642
662
|
|
tme/analyzer/peaks.py
CHANGED
@@ -228,10 +228,15 @@ class PeakCaller(AbstractAnalyzer):
|
|
228
228
|
translations = be.full(
|
229
229
|
(self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
230
230
|
)
|
231
|
+
|
232
|
+
rdim = len(self.shape)
|
233
|
+
if self.batch_dims:
|
234
|
+
rdim -= len(self.batch_dims)
|
235
|
+
|
231
236
|
rotations = be.full(
|
232
|
-
(self.num_peaks,
|
237
|
+
(self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
|
233
238
|
)
|
234
|
-
for i in range(
|
239
|
+
for i in range(rdim):
|
235
240
|
rotations[:, i, i] = 1.0
|
236
241
|
|
237
242
|
scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
|
@@ -750,8 +755,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
750
755
|
Dictionary mapping values in rotations to Euler angles.
|
751
756
|
By default None
|
752
757
|
min_score : float
|
753
|
-
Minimum score value to consider.
|
754
|
-
by :py:attr:`PeakCaller.num_peaks`.
|
758
|
+
Minimum score value to consider.
|
755
759
|
|
756
760
|
Returns
|
757
761
|
-------
|
@@ -774,10 +778,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
774
778
|
mask = be.to_backend_array(mask)
|
775
779
|
mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
|
776
780
|
|
777
|
-
|
778
|
-
if min_score is not None:
|
779
|
-
peak_limit = be.size(scores)
|
780
|
-
else:
|
781
|
+
if min_score is None:
|
781
782
|
min_score = be.min(scores) - 1
|
782
783
|
|
783
784
|
_scores = be.zeros(scores.shape, dtype=scores.dtype)
|
@@ -815,7 +816,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
815
816
|
score_mask = mask_buffer[tmpl_slice] <= 0.1
|
816
817
|
|
817
818
|
_scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
|
818
|
-
if len(peaks) >=
|
819
|
+
if len(peaks) >= self.num_peaks:
|
819
820
|
break
|
820
821
|
|
821
822
|
return be.to_backend_array(peaks), None
|
@@ -851,7 +852,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
851
852
|
|
852
853
|
rotation = rotation_mapping[rotation_space[tuple(peak)]]
|
853
854
|
|
854
|
-
#
|
855
|
+
# Old versions of rotation mapping contained Euler angles
|
855
856
|
if rotation.ndim != 2:
|
856
857
|
rotation = be.to_backend_array(
|
857
858
|
euler_to_rotationmatrix(be.to_numpy_array(rotation))
|
tme/backends/_jax_utils.py
CHANGED
@@ -10,14 +10,14 @@ from typing import Tuple
|
|
10
10
|
from functools import partial
|
11
11
|
|
12
12
|
import jax.numpy as jnp
|
13
|
-
from jax import pmap, lax, vmap
|
13
|
+
from jax import pmap, lax, vmap, jit
|
14
14
|
|
15
15
|
from ..types import BackendArray
|
16
16
|
from ..backends import backend as be
|
17
17
|
from ..matching_utils import normalize_template as _normalize_template
|
18
18
|
|
19
19
|
|
20
|
-
__all__ = ["scan"]
|
20
|
+
__all__ = ["scan", "setup_scan"]
|
21
21
|
|
22
22
|
|
23
23
|
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
@@ -112,11 +112,56 @@ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
|
|
112
112
|
return arr
|
113
113
|
|
114
114
|
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
)
|
115
|
+
def _mask_scores(arr, mask):
|
116
|
+
return arr.at[:].multiply(mask)
|
117
|
+
|
118
|
+
|
119
|
+
def _select_config(analyzer_kwargs, device_idx):
|
120
|
+
return analyzer_kwargs[device_idx]
|
121
|
+
|
122
|
+
|
123
|
+
def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
|
124
|
+
"""Create separate scan function with initialized analyzer for each device"""
|
125
|
+
device_scans = [
|
126
|
+
partial(
|
127
|
+
scan,
|
128
|
+
fast_shape=fast_shape,
|
129
|
+
rotate_mask=rotate_mask,
|
130
|
+
analyzer=callback_class(**device_config),
|
131
|
+
)
|
132
|
+
for device_config in analyzer_kwargs
|
133
|
+
]
|
134
|
+
|
135
|
+
@partial(
|
136
|
+
pmap,
|
137
|
+
in_axes=(0,) + (None,) * 6,
|
138
|
+
axis_name="batch",
|
139
|
+
)
|
140
|
+
def scan_combined(
|
141
|
+
target,
|
142
|
+
template,
|
143
|
+
template_mask,
|
144
|
+
rotations,
|
145
|
+
template_filter,
|
146
|
+
target_filter,
|
147
|
+
score_mask,
|
148
|
+
):
|
149
|
+
return lax.switch(
|
150
|
+
lax.axis_index("batch"),
|
151
|
+
device_scans,
|
152
|
+
target,
|
153
|
+
template,
|
154
|
+
template_mask,
|
155
|
+
rotations,
|
156
|
+
template_filter,
|
157
|
+
target_filter,
|
158
|
+
score_mask,
|
159
|
+
)
|
160
|
+
|
161
|
+
return scan_combined
|
162
|
+
|
163
|
+
|
164
|
+
@partial(jit, static_argnums=(7, 8, 9))
|
120
165
|
def scan(
|
121
166
|
target: BackendArray,
|
122
167
|
template: BackendArray,
|
@@ -124,8 +169,10 @@ def scan(
|
|
124
169
|
rotations: BackendArray,
|
125
170
|
template_filter: BackendArray,
|
126
171
|
target_filter: BackendArray,
|
172
|
+
score_mask: BackendArray,
|
127
173
|
fast_shape: Tuple[int],
|
128
174
|
rotate_mask: bool,
|
175
|
+
analyzer: object,
|
129
176
|
) -> Tuple[BackendArray, BackendArray]:
|
130
177
|
eps = jnp.finfo(template.dtype).resolution
|
131
178
|
|
@@ -150,8 +197,12 @@ def scan(
|
|
150
197
|
if template_filter.shape != ():
|
151
198
|
_template_filter_func = _apply_fourier_filter
|
152
199
|
|
200
|
+
_score_mask_func = _identity
|
201
|
+
if score_mask.shape != ():
|
202
|
+
_score_mask_func = _mask_scores
|
203
|
+
|
153
204
|
def _sample_transform(ret, rotation_matrix):
|
154
|
-
|
205
|
+
state, index = ret
|
155
206
|
template_rot, template_mask_rot = be.rigid_transform(
|
156
207
|
arr=template,
|
157
208
|
arr_mask=template_mask,
|
@@ -176,15 +227,10 @@ def scan(
|
|
176
227
|
n_observations=n_observations,
|
177
228
|
eps=eps,
|
178
229
|
)
|
179
|
-
|
180
|
-
scores, max_scores, rotations, index
|
181
|
-
)
|
182
|
-
return (max_scores, rotations, index + 1), None
|
230
|
+
scores = _score_mask_func(scores, score_mask)
|
183
231
|
|
184
|
-
|
185
|
-
|
186
|
-
(score_space, rotation_space, _), _ = lax.scan(
|
187
|
-
_sample_transform, (score_space, rotation_space, 0), rotations
|
188
|
-
)
|
232
|
+
state = analyzer(state, scores, rotation_matrix, rotation_index=index)
|
233
|
+
return (state, index + 1), None
|
189
234
|
|
190
|
-
|
235
|
+
(state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
|
236
|
+
return state
|