pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3b0.post1.data/scripts/estimate_memory_usage.py +76 -0
- pytme-0.3b0.post1.data/scripts/match_template.py +1098 -0
- {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +318 -189
- {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +21 -31
- {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +12 -12
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +21 -20
- pytme-0.3b0.post1.dist-info/RECORD +126 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +2 -1
- pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +76 -0
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +224 -0
- scripts/match_template.py +341 -378
- pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
- scripts/postprocess.py +318 -189
- scripts/preprocess.py +21 -31
- scripts/preprocessor_gui.py +12 -12
- scripts/pytme_runner.py +769 -0
- scripts/refine_matches.py +625 -0
- tests/preprocessing/test_frequency_filters.py +28 -14
- tests/test_analyzer.py +41 -36
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +109 -54
- tests/test_matching_data.py +5 -5
- tests/test_matching_exhaustive.py +1 -2
- tests/test_matching_optimization.py +4 -9
- tests/test_matching_utils.py +1 -1
- tests/test_orientations.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +26 -21
- tme/analyzer/aggregation.py +395 -222
- tme/analyzer/base.py +127 -0
- tme/analyzer/peaks.py +189 -204
- tme/analyzer/proxy.py +123 -0
- tme/backends/__init__.py +4 -3
- tme/backends/_cupy_utils.py +25 -24
- tme/backends/_jax_utils.py +20 -18
- tme/backends/cupy_backend.py +13 -26
- tme/backends/jax_backend.py +24 -23
- tme/backends/matching_backend.py +4 -3
- tme/backends/mlx_backend.py +4 -3
- tme/backends/npfftw_backend.py +34 -30
- tme/backends/pytorch_backend.py +18 -4
- tme/cli.py +126 -0
- tme/density.py +9 -7
- tme/filters/__init__.py +3 -3
- tme/filters/_utils.py +36 -10
- tme/filters/bandpass.py +229 -188
- tme/filters/compose.py +5 -4
- tme/filters/ctf.py +516 -254
- tme/filters/reconstruction.py +91 -32
- tme/filters/wedge.py +196 -135
- tme/filters/whitening.py +37 -42
- tme/matching_data.py +28 -39
- tme/matching_exhaustive.py +31 -27
- tme/matching_optimization.py +5 -4
- tme/matching_scores.py +25 -15
- tme/matching_utils.py +54 -9
- tme/memory.py +4 -3
- tme/orientations.py +22 -9
- tme/parser.py +114 -33
- tme/preprocessor.py +6 -5
- tme/rotations.py +10 -7
- tme/structure.py +4 -3
- pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
- pytme-0.2.9.post1.dist-info/RECORD +0 -119
- pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
- scripts/estimate_ram_usage.py +0 -97
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
- {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/analyzer/proxy.py
ADDED
@@ -0,0 +1,123 @@
|
|
1
|
+
"""
|
2
|
+
Implements SharedAnalyzerProxy to manage shared memory of Analyzer instances
|
3
|
+
across different tasks.
|
4
|
+
|
5
|
+
This is primarily useful for CPU template matching, where parallelization can
|
6
|
+
be performed over rotations, rather than subsections of a large input volume.
|
7
|
+
|
8
|
+
Copyright (c) 2025 European Molecular Biology Laboratory
|
9
|
+
|
10
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
11
|
+
"""
|
12
|
+
|
13
|
+
from typing import Tuple
|
14
|
+
from multiprocessing import Manager
|
15
|
+
from multiprocessing.shared_memory import SharedMemory
|
16
|
+
|
17
|
+
from ..backends import backend as be
|
18
|
+
|
19
|
+
__all__ = ["StatelessSharedAnalyzerProxy", "SharedAnalyzerProxy"]
|
20
|
+
|
21
|
+
|
22
|
+
class StatelessSharedAnalyzerProxy:
    """
    Proxy that wraps functional analyzers for concurrent access via shared memory.

    Enables multiple processes/threads to safely update the same analyzer
    while preserving the functional interface of the underlying analyzer.

    Parameters
    ----------
    analyzer_class : type
        Analyzer class to instantiate.
    analyzer_params : dict
        Keyword arguments passed to ``analyzer_class``.
    """

    def __init__(self, analyzer_class: type, analyzer_params: dict):
        self._shared = False
        # Calls are dispatched through ``_process`` so that ``init_state`` can
        # switch to the lock-protected path once the state is shared.
        self._process = self._direct_call

        self._analyzer = analyzer_class(**analyzer_params)

    def __call__(self, state, *args, **kwargs):
        """Forward ``state`` and all arguments to the analyzer and return its result."""
        return self._process(state, *args, **kwargs)

    def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
        """
        Create the initial analyzer state.

        Parameters
        ----------
        shm_handler : optional
            Handler forwarded to ``be.to_sharedarr``. If given, the state is
            moved into shared memory and subsequent calls are serialized
            through a manager-backed lock.
        *args, **kwargs
            Currently unused.

        Returns
        -------
        tuple
            The analyzer state, converted to shared objects if
            ``shm_handler`` was given.
        """
        state = self._analyzer.init_state()
        if shm_handler is not None:
            self._shared = True
            # One manager server process backs both the shared dicts and the
            # lock. The proxies it hands out hold a reference to it, which
            # keeps it alive in this process.
            manager = Manager()
            state = self._to_shared(state, shm_handler, manager=manager)

            self._lock = manager.Lock()
            self._process = self._thread_safe_call
        return state

    def _to_shared(self, state: Tuple, shm_handler, manager=None) -> Tuple:
        """
        Convert the elements of ``state`` into shareable objects.

        Backend arrays are converted via ``be.to_sharedarr`` and dicts are
        replaced by manager-backed dict proxies. All other elements are
        passed through unchanged.

        Parameters
        ----------
        state : tuple
            Analyzer state.
        shm_handler
            Handler forwarded to ``be.to_sharedarr``.
        manager : optional
            Started ``multiprocessing`` manager used to create dict proxies.
            A new one is started lazily if a dict is encountered and none
            was given.

        Returns
        -------
        tuple
            State with arrays and dicts replaced by their shared counterparts.
        """
        backend_arr = type(be.zeros((1), dtype=be._float_dtype))

        ret = []
        for v in state:
            if isinstance(v, backend_arr):
                v = be.to_sharedarr(v, shm_handler)
            elif isinstance(v, dict):
                if manager is None:
                    manager = Manager()
                # Pass the mapping positionally: ``dict(**v)`` raises a
                # TypeError for non-string keys (e.g. integer indices).
                v = manager.dict(v)
            ret.append(v)
        return tuple(ret)

    def _shared_to_object(self, shared: object):
        """
        Convert a shared-memory array handle back to a backend array.

        Returns ``shared`` unchanged unless the proxy is in shared mode and
        ``shared`` is a non-empty tuple whose first element is a
        ``SharedMemory`` instance. In that case the tuple is passed to
        ``be.from_sharedarr``.
        """
        if not self._shared:
            return shared
        if isinstance(shared, tuple) and len(shared):
            if isinstance(shared[0], SharedMemory):
                return be.from_sharedarr(shared)
        return shared

    def _thread_safe_call(self, state, *args, **kwargs):
        """Call the analyzer while holding the manager-backed lock."""
        with self._lock:
            state = tuple(self._shared_to_object(x) for x in state)
            return self._direct_call(state, *args, **kwargs)

    def _direct_call(self, state, *args, **kwargs):
        """Call the analyzer without locking."""
        return self._analyzer(state, *args, **kwargs)

    def result(self, state, **kwargs):
        """
        Extract the final result from ``state`` via the analyzer's ``result``.

        In shared mode, shared-memory array handles are first converted back
        to backend arrays.
        """
        final_state = state
        if self._shared:
            # Convert shared-memory handles back to backend arrays.
            # NOTE(review): no copy is made here; confirm the arrays remain
            # valid (or are copied by the analyzer's result) after the shared
            # memory handler is closed.
            final_state = tuple(self._shared_to_object(x) for x in final_state)
        return self._analyzer.result(final_state, **kwargs)

    def merge(self, *args, **kwargs):
        """Delegate to the analyzer's ``merge``."""
        return self._analyzer.merge(*args, **kwargs)
|
90
|
+
|
91
|
+
|
92
|
+
class SharedAnalyzerProxy(StatelessSharedAnalyzerProxy):
    """
    Child of :py:class:`StatelessSharedAnalyzerProxy` that is aware
    of the current analyzer state to emulate the previous analyzer interface.

    The state is created on construction and stored on the instance, so
    ``__call__`` and ``result`` do not take a state argument.

    Parameters
    ----------
    analyzer_class : type
        Analyzer class to instantiate.
    analyzer_params : dict
        Keyword arguments passed to ``analyzer_class``.
    shm_handler : optional
        Shared memory handler forwarded to ``init_state``. Ignored if the
        analyzer's ``shareable`` attribute is falsy.
    **kwargs
        Currently unused.
    """

    def __init__(
        self,
        analyzer_class: type,
        analyzer_params: dict,
        shm_handler: type = None,
        **kwargs,
    ):
        super().__init__(
            analyzer_class=analyzer_class,
            analyzer_params=analyzer_params,
        )
        # Analyzers that cannot be shared keep a process-local state.
        if not self._analyzer.shareable:
            shm_handler = None
        self.init_state(shm_handler)

    def init_state(self, shm_handler=None, *args, **kwargs) -> Tuple:
        """
        Create the analyzer state, store it on the instance and return it.

        See :py:meth:`StatelessSharedAnalyzerProxy.init_state` for parameters.
        """
        self._state = super().init_state(shm_handler, *args, **kwargs)
        return self._state

    def __call__(self, *args, **kwargs):
        """Update the stored state with the given arguments."""
        state = super().__call__(self._state, *args, **kwargs)
        # In shared mode the returned state is discarded and the stored shared
        # objects are kept. NOTE(review): this assumes the analyzer updates
        # shared arrays/dicts in place; confirm for each analyzer.
        if not self._shared:
            self._state = state

    def result(self, **kwargs):
        """Extract the final result from the stored state."""
        return super().result(self._state, **kwargs)
|
tme/backends/__init__.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
pyTME backend manager.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from typing import Dict, List
|
tme/backends/_cupy_utils.py
CHANGED
@@ -1,34 +1,35 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Utility functions for cupy backend.
|
2
3
|
|
3
|
-
|
4
|
-
|
5
|
-
|
4
|
+
The functions spline_filter, _prepad_for_spline_filter, _filter_input,
|
5
|
+
_get_coord_affine_batched and affine_transform are largely copied from
|
6
|
+
cupyx.scipy.ndimage which operates under the following license
|
6
7
|
|
7
|
-
|
8
|
-
|
8
|
+
Copyright (c) 2015 Preferred Infrastructure, Inc.
|
9
|
+
Copyright (c) 2015 Preferred Networks, Inc.
|
9
10
|
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
11
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
12
|
+
of this software and associated documentation files (the "Software"), to deal
|
13
|
+
in the Software without restriction, including without limitation the rights
|
14
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
15
|
+
copies of the Software, and to permit persons to whom the Software is
|
16
|
+
furnished to do so, subject to the following conditions:
|
16
17
|
|
17
|
-
|
18
|
-
|
18
|
+
The above copyright notice and this permission notice shall be included in
|
19
|
+
all copies or substantial portions of the Software.
|
19
20
|
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
21
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
22
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
23
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
24
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
25
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
26
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
27
|
+
THE SOFTWARE.
|
27
28
|
|
28
|
-
|
29
|
-
|
29
|
+
I have since extended the functionality of the cupyx.scipy.ndimage functions
|
30
|
+
in question to support batched inputs.
|
30
31
|
|
31
|
-
|
32
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
32
33
|
"""
|
33
34
|
|
34
35
|
import numpy
|
tme/backends/_jax_utils.py
CHANGED
@@ -1,24 +1,28 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Utility functions for jax backend.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from typing import Tuple
|
9
10
|
from functools import partial
|
10
11
|
|
11
12
|
import jax.numpy as jnp
|
12
|
-
from jax import pmap, lax
|
13
|
+
from jax import pmap, lax, vmap
|
13
14
|
|
14
15
|
from ..types import BackendArray
|
15
16
|
from ..backends import backend as be
|
16
17
|
from ..matching_utils import normalize_template as _normalize_template
|
17
18
|
|
18
19
|
|
20
|
+
__all__ = ["scan"]
|
21
|
+
|
22
|
+
|
19
23
|
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
20
24
|
"""
|
21
|
-
Computes :py:meth:`tme.
|
25
|
+
Computes :py:meth:`tme.matching_scores.cc_setup`.
|
22
26
|
"""
|
23
27
|
template_ft = jnp.fft.rfftn(template, s=template.shape)
|
24
28
|
template_ft = template_ft.at[:].multiply(ft_target)
|
@@ -27,18 +31,17 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
|
27
31
|
|
28
32
|
|
29
33
|
def _flc_scoring(
|
30
|
-
template: BackendArray,
|
31
|
-
template_mask: BackendArray,
|
32
34
|
ft_target: BackendArray,
|
33
35
|
ft_target2: BackendArray,
|
36
|
+
template: BackendArray,
|
37
|
+
template_mask: BackendArray,
|
34
38
|
n_observations: BackendArray,
|
35
39
|
eps: float,
|
36
40
|
**kwargs,
|
37
41
|
) -> BackendArray:
|
38
42
|
"""
|
39
|
-
Computes :py:meth:`tme.
|
43
|
+
Computes :py:meth:`tme.matching_scores.flc_scoring`.
|
40
44
|
"""
|
41
|
-
correlation = _correlate(template=template, ft_target=ft_target)
|
42
45
|
inv_denominator = _reciprocal_target_std(
|
43
46
|
ft_target=ft_target,
|
44
47
|
ft_target2=ft_target2,
|
@@ -46,18 +49,17 @@ def _flc_scoring(
|
|
46
49
|
eps=eps,
|
47
50
|
n_observations=n_observations,
|
48
51
|
)
|
49
|
-
|
50
|
-
return correlation
|
52
|
+
return _flcSphere_scoring(ft_target, template, inv_denominator)
|
51
53
|
|
52
54
|
|
53
55
|
def _flcSphere_scoring(
|
54
|
-
template: BackendArray,
|
55
56
|
ft_target: BackendArray,
|
57
|
+
template: BackendArray,
|
56
58
|
inv_denominator: BackendArray,
|
57
59
|
**kwargs,
|
58
60
|
) -> BackendArray:
|
59
61
|
"""
|
60
|
-
Computes :py:meth:`tme.
|
62
|
+
Computes :py:meth:`tme.matching_scores.corr_scoring`.
|
61
63
|
"""
|
62
64
|
correlation = _correlate(template=template, ft_target=ft_target)
|
63
65
|
correlation = correlation.at[:].multiply(inv_denominator)
|
@@ -76,7 +78,7 @@ def _reciprocal_target_std(
|
|
76
78
|
|
77
79
|
See Also
|
78
80
|
--------
|
79
|
-
:py:meth:`tme.
|
81
|
+
:py:meth:`tme.matching_scores.flc_scoring`.
|
80
82
|
"""
|
81
83
|
ft_shape = template_mask.shape
|
82
84
|
ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
|
@@ -162,12 +164,12 @@ def scan(
|
|
162
164
|
template_rot = _normalize_template(
|
163
165
|
template_rot, template_mask_rot, n_observations
|
164
166
|
)
|
165
|
-
|
166
|
-
|
167
|
+
rot_pad = be.topleft_pad(template_rot, fast_shape)
|
168
|
+
mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
|
167
169
|
|
168
170
|
scores = scoring_func(
|
169
|
-
template=
|
170
|
-
template_mask=
|
171
|
+
template=rot_pad,
|
172
|
+
template_mask=mask_rot_pad,
|
171
173
|
ft_target=ft_target,
|
172
174
|
ft_target2=ft_target2,
|
173
175
|
inv_denominator=inv_denominator,
|
tme/backends/cupy_backend.py
CHANGED
@@ -1,17 +1,15 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Backend using cupy for template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
|
-
import warnings
|
9
9
|
from importlib.util import find_spec
|
10
10
|
from contextlib import contextmanager
|
11
11
|
from typing import Tuple, Callable, List
|
12
12
|
|
13
|
-
import numpy as np
|
14
|
-
|
15
13
|
from .npfftw_backend import NumpyFFTWBackend
|
16
14
|
from ..types import CupyArray, NDArray, shm_type
|
17
15
|
|
@@ -113,16 +111,6 @@ class CupyBackend(NumpyFFTWBackend):
|
|
113
111
|
def unravel_index(self, indices, shape):
|
114
112
|
return self._array_backend.unravel_index(indices=indices, dims=shape)
|
115
113
|
|
116
|
-
def unique(self, ar, axis=None, *args, **kwargs):
|
117
|
-
if axis is None:
|
118
|
-
return self._array_backend.unique(ar=ar, axis=axis, *args, **kwargs)
|
119
|
-
|
120
|
-
warnings.warn("Axis argument not yet supported in CupY, falling back to NumPy.")
|
121
|
-
ret = np.unique(ar=self.to_numpy_array(ar), axis=axis, *args, **kwargs)
|
122
|
-
if not isinstance(ret, tuple):
|
123
|
-
return self.to_backend_array(ret)
|
124
|
-
return tuple(self.to_backend_array(k) for k in ret)
|
125
|
-
|
126
114
|
def build_fft(
|
127
115
|
self,
|
128
116
|
fwd_shape: Tuple[int],
|
@@ -155,15 +143,14 @@ class CupyBackend(NumpyFFTWBackend):
|
|
155
143
|
def rfftn(
|
156
144
|
arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
|
157
145
|
) -> CupyArray:
|
158
|
-
return self.rfftn(arr, s=s, axes=fwd_axes)
|
146
|
+
return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)
|
159
147
|
|
160
148
|
def irfftn(
|
161
149
|
arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
|
162
150
|
) -> CupyArray:
|
163
|
-
return self.irfftn(arr, s=s, axes=inv_axes)
|
151
|
+
return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)
|
164
152
|
|
165
153
|
PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
|
166
|
-
|
167
154
|
return rfftn, irfftn
|
168
155
|
|
169
156
|
def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
|
@@ -248,13 +235,13 @@ class CupyBackend(NumpyFFTWBackend):
|
|
248
235
|
)
|
249
236
|
return None
|
250
237
|
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
238
|
+
if data.ndim == 3 and cache and self.texture_available:
|
239
|
+
# Device memory pool (should) come to rescue performance
|
240
|
+
temp = self.zeros(data.shape, data.dtype)
|
241
|
+
texture = self._get_texture(data, order=order, prefilter=prefilter)
|
242
|
+
texture.affine(transform_m=matrix, profile=False, output=temp)
|
243
|
+
output[out_slice] = temp
|
244
|
+
return None
|
258
245
|
|
259
246
|
self.affine_transform(
|
260
247
|
input=data,
|
tme/backends/jax_backend.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Backend using jax for template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023-2024 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from functools import wraps
|
9
|
-
from typing import Tuple, List, Callable
|
10
|
+
from typing import Tuple, List, Callable, Dict
|
10
11
|
|
11
12
|
from ..types import BackendArray
|
12
13
|
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
@@ -50,12 +51,6 @@ class JaxBackend(NumpyFFTWBackend):
|
|
50
51
|
)
|
51
52
|
self.scipy = jsp
|
52
53
|
self._create_ufuncs()
|
53
|
-
try:
|
54
|
-
from ._jax_utils import scan as _
|
55
|
-
|
56
|
-
self.scan = self._scan
|
57
|
-
except Exception:
|
58
|
-
pass
|
59
54
|
|
60
55
|
def from_sharedarr(self, arr: BackendArray) -> BackendArray:
|
61
56
|
return arr
|
@@ -188,7 +183,18 @@ class JaxBackend(NumpyFFTWBackend):
|
|
188
183
|
rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
|
189
184
|
return max_scores, rotations
|
190
185
|
|
191
|
-
def
|
186
|
+
def compute_convolution_shapes(
|
187
|
+
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
188
|
+
) -> Tuple[List[int], List[int], List[int]]:
|
189
|
+
from scipy.fft import next_fast_len
|
190
|
+
|
191
|
+
convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
|
192
|
+
fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
|
193
|
+
fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
|
194
|
+
|
195
|
+
return convolution_shape, fast_shape, fast_ft_shape
|
196
|
+
|
197
|
+
def scan(
|
192
198
|
self,
|
193
199
|
matching_data: type,
|
194
200
|
splits: Tuple[Tuple[slice, slice]],
|
@@ -213,9 +219,9 @@ class JaxBackend(NumpyFFTWBackend):
|
|
213
219
|
conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
|
214
220
|
target_shape=self.to_numpy_array(target_shape),
|
215
221
|
template_shape=self.to_numpy_array(matching_data._template.shape),
|
216
|
-
|
222
|
+
batch_mask=self.to_numpy_array(matching_data._batch_mask),
|
223
|
+
pad_target=pad_target,
|
217
224
|
)
|
218
|
-
|
219
225
|
analyzer_args = {
|
220
226
|
"convolution_mode": convolution_mode,
|
221
227
|
"fourier_shift": shift,
|
@@ -245,19 +251,18 @@ class JaxBackend(NumpyFFTWBackend):
|
|
245
251
|
|
246
252
|
targets, translation_offsets = [], []
|
247
253
|
for target_split, template_split in split_subset:
|
248
|
-
base = matching_data.subset_by_slice(
|
254
|
+
base, translation_offset = matching_data.subset_by_slice(
|
249
255
|
target_slice=target_split,
|
250
256
|
target_pad=target_pad,
|
251
257
|
template_slice=template_split,
|
252
258
|
)
|
253
|
-
translation_offsets.append(
|
259
|
+
translation_offsets.append(translation_offset)
|
254
260
|
targets.append(self.topleft_pad(base._target, fast_shape))
|
255
261
|
|
256
262
|
if create_filter:
|
257
263
|
filter_args = {
|
258
264
|
"data_rfft": self.fft.rfftn(targets[0]),
|
259
265
|
"return_real_fourier": True,
|
260
|
-
"shape_is_real_fourier": False,
|
261
266
|
}
|
262
267
|
|
263
268
|
if create_template_filter:
|
@@ -287,15 +292,11 @@ class JaxBackend(NumpyFFTWBackend):
|
|
287
292
|
|
288
293
|
for index in range(scores.shape[0]):
|
289
294
|
temp = callback_class(
|
290
|
-
shape=scores.shape,
|
291
|
-
scores=scores[index],
|
292
|
-
rotations=rotations[index],
|
293
|
-
thread_safe=False,
|
295
|
+
shape=scores[index].shape,
|
294
296
|
offset=translation_offsets[index],
|
295
297
|
)
|
296
|
-
|
297
|
-
ret.append(
|
298
|
-
|
298
|
+
state = (scores[index], rotations[index], rotation_mapping)
|
299
|
+
ret.append(temp.result(state, **analyzer_args))
|
299
300
|
return ret
|
300
301
|
|
301
302
|
def get_available_memory(self) -> int:
|
tme/backends/matching_backend.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Strategy pattern to allow for flexible array / FFT backends.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from abc import ABC, abstractmethod
|
tme/backends/mlx_backend.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Backend using Apple's MLX library for template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
from typing import Tuple, List, Callable
|
tme/backends/npfftw_backend.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Backend using numpy and pyFFTW for template matching.
|
2
3
|
|
3
|
-
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
5
|
|
5
|
-
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
7
|
"""
|
7
8
|
|
8
9
|
import os
|
@@ -159,8 +160,9 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
159
160
|
temp = self._array_backend.zeros(1, dtype=dtype)
|
160
161
|
return temp.nbytes
|
161
162
|
|
162
|
-
|
163
|
-
|
163
|
+
def astype(self, arr, dtype: Type) -> NDArray:
|
164
|
+
if self._array_backend.iscomplexobj(arr):
|
165
|
+
arr = arr.real
|
164
166
|
return arr.astype(dtype)
|
165
167
|
|
166
168
|
@staticmethod
|
@@ -396,33 +398,33 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
396
398
|
out_mask: NDArray = None,
|
397
399
|
order: int = 3,
|
398
400
|
cache: bool = False,
|
401
|
+
batched: bool = False,
|
399
402
|
) -> Tuple[NDArray, NDArray]:
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
translation=
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
|
403
|
+
if out is None:
|
404
|
+
out = self.zeros_like(arr)
|
405
|
+
|
406
|
+
# Check whether rotation_matrix is already a rigid transform matrix
|
407
|
+
matrix = rotation_matrix
|
408
|
+
if matrix.shape[-1] == (arr.ndim - int(batched)):
|
409
|
+
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
|
410
|
+
if not use_geometric_center:
|
411
|
+
center = self.center_of_mass(arr, cutoff=0)
|
412
|
+
|
413
|
+
offset = int(arr.ndim - rotation_matrix.shape[0])
|
414
|
+
center = center[offset:]
|
415
|
+
translation = (
|
416
|
+
self.zeros(center.size) if translation is None else translation
|
417
|
+
)
|
418
|
+
matrix = self._rigid_transform_matrix(
|
419
|
+
rotation_matrix=rotation_matrix,
|
420
|
+
translation=translation,
|
421
|
+
center=center,
|
420
422
|
)
|
421
423
|
|
422
424
|
self._rigid_transform(
|
423
|
-
data=arr
|
425
|
+
data=arr,
|
424
426
|
matrix=matrix,
|
425
|
-
output=out
|
427
|
+
output=out,
|
426
428
|
order=order,
|
427
429
|
prefilter=True,
|
428
430
|
cache=cache,
|
@@ -431,11 +433,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
431
433
|
|
432
434
|
# Applying the prefilter leads to artifacts in the mask.
|
433
435
|
if arr_mask is not None:
|
434
|
-
|
436
|
+
if out_mask is None:
|
437
|
+
out_mask = self.zeros_like(arr_mask)
|
438
|
+
|
435
439
|
self._rigid_transform(
|
436
|
-
data=arr_mask
|
440
|
+
data=arr_mask,
|
437
441
|
matrix=matrix,
|
438
|
-
output=out_mask
|
442
|
+
output=out_mask,
|
439
443
|
order=order,
|
440
444
|
prefilter=False,
|
441
445
|
cache=cache,
|
tme/backends/pytorch_backend.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
|
-
"""
|
2
|
-
|
1
|
+
"""
|
2
|
+
Backend using pytorch and optionally GPU acceleration for
|
3
|
+
template matching.
|
3
4
|
|
4
|
-
|
5
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
6
|
|
6
|
-
|
7
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
8
|
"""
|
8
9
|
|
9
10
|
from typing import Tuple, Callable
|
@@ -134,6 +135,15 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
134
135
|
def astype(self, arr: TorchTensor, dtype: type) -> TorchTensor:
|
135
136
|
return arr.to(dtype)
|
136
137
|
|
138
|
+
@staticmethod
|
139
|
+
def at(arr, idx, value) -> NDArray:
|
140
|
+
arr[idx] = value
|
141
|
+
return arr
|
142
|
+
|
143
|
+
@staticmethod
|
144
|
+
def addat(arr, indices, *args, **kwargs) -> NDArray:
|
145
|
+
return arr.index_put_(indices, *args, accumulate=True, **kwargs)
|
146
|
+
|
137
147
|
def flip(self, a, axis, **kwargs):
|
138
148
|
return self._array_backend.flip(input=a, dims=axis, **kwargs)
|
139
149
|
|
@@ -296,6 +306,9 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
296
306
|
kwargs["dim"] = kwargs.pop("axes", None)
|
297
307
|
return self._array_backend.fft.irfftn(arr, **kwargs)
|
298
308
|
|
309
|
+
def _rigid_transform_matrix(self, rotation_matrix, *args, **kwargs):
|
310
|
+
return rotation_matrix
|
311
|
+
|
299
312
|
def rigid_transform(
|
300
313
|
self,
|
301
314
|
arr: TorchTensor,
|
@@ -307,6 +320,7 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
307
320
|
out_mask: TorchTensor = None,
|
308
321
|
order: int = 1,
|
309
322
|
cache: bool = False,
|
323
|
+
**kwargs,
|
310
324
|
):
|
311
325
|
_mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
|
312
326
|
mode = _mode_mapping.get(order, None)
|