pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
- pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
- scripts/estimate_ram_usage.py +97 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +30 -41
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +96 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +158 -390
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_orientations.py +0 -12
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +64 -18
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +79 -40
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +58 -5
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +24 -13
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,270 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
#
|
3
|
+
# Henry Gomersall
|
4
|
+
# heng@kedevelopments.co.uk
|
5
|
+
#
|
6
|
+
# All rights reserved.
|
7
|
+
#
|
8
|
+
# Redistribution and use in source and binary forms, with or without
|
9
|
+
# modification, are permitted provided that the following conditions are met:
|
10
|
+
#
|
11
|
+
# * Redistributions of source code must retain the above copyright notice, this
|
12
|
+
# list of conditions and the following disclaimer.
|
13
|
+
#
|
14
|
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
15
|
+
# this list of conditions and the following disclaimer in the documentation
|
16
|
+
# and/or other materials provided with the distribution.
|
17
|
+
#
|
18
|
+
# * Neither the name of the copyright holder nor the names of its contributors
|
19
|
+
# may be used to endorse or promote products derived from this software without
|
20
|
+
# specific prior written permission.
|
21
|
+
#
|
22
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
23
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
24
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
25
|
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
26
|
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
27
|
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
28
|
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
29
|
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
30
|
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
31
|
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
32
|
+
# POSSIBILITY OF SUCH DAMAGE.
|
33
|
+
#
|
34
|
+
|
35
|
+
# This code has been adapted to add support for the out argument in rfftn, irfftn
|
36
|
+
# to allow for reusing existing array buffers
|
37
|
+
# Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
38
|
+
|
39
|
+
import threading
|
40
|
+
|
41
|
+
import pyfftw
|
42
|
+
import numpy as np
|
43
|
+
import pyfftw.builders as builders
|
44
|
+
from pyfftw.interfaces import cache
|
45
|
+
from pyfftw.builders._utils import _norm_args, _default_effort, _default_threads
|
46
|
+
|
47
|
+
|
48
|
+
def _Xfftn(
    a,
    s,
    axes,
    overwrite_input,
    planner_effort,
    threads,
    auto_align_input,
    auto_contiguous,
    calling_func,
    normalise_idft=True,
    ortho=False,
    real_direction_flag=None,
    output_array=None,
):
    """
    Plan (or fetch from the pyfftw interface cache) and execute a transform.

    Parameters
    ----------
    a : array_like
        Input data; converted with ``np.asanyarray``.
    s, axes : iterable or None
        Forwarded to the builder; converted to tuples when iterable.
    overwrite_input : bool
        Forwarded to the builder, except for ``irfft2``/``irfftn`` which do
        not take it. For those two, a falsy value makes this function work
        on an aligned copy of ``a``.
    planner_effort, threads, auto_align_input, auto_contiguous
        Planner options forwarded to the builder.
    calling_func : str
        Name of the function looked up in ``pyfftw.builders``.
    normalise_idft, ortho : bool
        Passed to the FFTW object when it is executed.
    real_direction_flag : optional
        Only forwarded for ``"dct"`` and ``"dst"``.
    output_array : ndarray, optional
        Buffer that receives the result.

    Returns
    -------
    ndarray
        ``output_array`` if it was given, otherwise the array produced by
        the transform.

    Raises
    ------
    ValueError
        If ``overwrite_input`` is True while ``a`` is not writeable.
    """
    work_with_copy = False

    a = np.asanyarray(a)

    try:
        s = tuple(s)
    except TypeError:
        pass

    try:
        axes = tuple(axes)
    except TypeError:
        pass

    if calling_func in ("dct", "dst"):
        # real-to-real transforms require passing an additional flag argument
        avoid_copy = False
        args = (
            overwrite_input,
            planner_effort,
            threads,
            auto_align_input,
            auto_contiguous,
            avoid_copy,
            real_direction_flag,
        )
    elif calling_func in ("irfft2", "irfftn"):
        # overwrite_input is not an argument to irfft2 or irfftn
        args = (planner_effort, threads, auto_align_input, auto_contiguous)

        if not overwrite_input:
            # Only irfft2 and irfftn have overwriting the input
            # as the default (and so require the input array to
            # be reloaded).
            work_with_copy = True
    else:
        args = (
            overwrite_input,
            planner_effort,
            threads,
            auto_align_input,
            auto_contiguous,
        )

    if not a.flags.writeable:
        # Special case of a locked array - always work with a
        # copy. See issue #92.
        work_with_copy = True

        if overwrite_input:
            raise ValueError(
                "overwrite_input cannot be True when the "
                + "input array flags.writeable is False"
            )

    if work_with_copy:
        # We make the copy before registering the key so that the
        # copy's stride information will be cached since this will be
        # used for planning. Make sure the copy is byte aligned to
        # prevent further copying
        a_original = a
        a = pyfftw.empty_aligned(shape=a.shape, dtype=a.dtype)
        a[...] = a_original

    # Query the cache state once so that `key` and `FFTW_object` are always
    # bound consistently, even if the cache is toggled during this call.
    use_cache = cache.is_enabled()
    key = None
    FFTW_object = None

    if use_cache:
        alignment = a.ctypes.data % pyfftw.simd_alignment

        key = (
            calling_func,
            a.shape,
            a.strides,
            a.dtype,
            s.__hash__(),
            axes.__hash__(),
            alignment,
            args,
            threading.get_ident(),
        )

        try:
            if key in cache._fftw_cache:
                FFTW_object = cache._fftw_cache.lookup(key)
        except KeyError:
            # This occurs if the object has fallen out of the cache between
            # the check and the lookup
            FFTW_object = None

    if FFTW_object is None:
        # If we're going to create a new FFTW object and are not
        # working with a copy, then we need to copy the input array to
        # preserve it, otherwise we can't actually take the transform
        # of the input array! (in general, we have to assume that the
        # input array will be destroyed during planning).
        if not work_with_copy:
            a_copy = a.copy()

        planner_args = (a, s, axes) + args

        FFTW_object = getattr(builders, calling_func)(*planner_args)

        # Only copy if the input array is what was actually used
        # (otherwise it shouldn't be overwritten)
        if not work_with_copy and FFTW_object.input_array is a:
            a[:] = a_copy

        if use_cache:
            cache._fftw_cache.insert(FFTW_object, key)

        result = FFTW_object(normalise_idft=normalise_idft, ortho=ortho)

        if output_array is None:
            output_array = result
        elif output_array is not result:
            # Honour the caller's buffer on a freshly planned transform too;
            # previously it was discarded here and only filled on cache hits.
            output_array[...] = result

    else:
        orig_output_array = FFTW_object.output_array
        output_shape = orig_output_array.shape
        output_dtype = orig_output_array.dtype
        output_alignment = FFTW_object.output_alignment

        if output_array is None:
            output_array = pyfftw.empty_aligned(
                output_shape, output_dtype, n=output_alignment
            )

        FFTW_object(
            input_array=a,
            output_array=output_array,
            normalise_idft=normalise_idft,
            ortho=ortho,
        )

    return output_array
|
197
|
+
|
198
|
+
|
199
|
+
def rfftn(
    a,
    s=None,
    axes=None,
    norm=None,
    overwrite_input=False,
    planner_effort=None,
    threads=None,
    auto_align_input=True,
    auto_contiguous=True,
    out=None,
):
    """Perform an n-D real FFT.

    The first four arguments are as per :func:`numpy.fft.rfftn`;
    the rest of the arguments are documented
    in the :ref:`additional arguments docs<interfaces_additional_args>`.

    ``out`` is handed to :func:`_Xfftn` as ``output_array``.
    """
    # Resolve the planner defaults inline, in the same order as before.
    return _Xfftn(
        a,
        s,
        axes,
        overwrite_input,
        _default_effort(planner_effort),
        _default_threads(threads),
        auto_align_input,
        auto_contiguous,
        "rfftn",
        output_array=out,
        **_norm_args(norm),
    )
|
234
|
+
|
235
|
+
|
236
|
+
def irfftn(
    a,
    s=None,
    axes=None,
    norm=None,
    overwrite_input=False,
    planner_effort=None,
    threads=None,
    auto_align_input=True,
    auto_contiguous=True,
    out=None,
):
    """Perform an n-D real inverse FFT.

    The first four arguments are as per :func:`numpy.fft.irfftn`;
    the rest of the arguments are documented
    in the :ref:`additional arguments docs<interfaces_additional_args>`.

    ``out`` is handed to :func:`_Xfftn` as ``output_array``.
    """
    # Resolve the planner defaults inline, in the same order as before.
    return _Xfftn(
        a,
        s,
        axes,
        overwrite_input,
        _default_effort(planner_effort),
        _default_threads(threads),
        auto_align_input,
        auto_contiguous,
        "irfftn",
        output_array=out,
        **_norm_args(norm),
    )
|
tme/backends/cupy_backend.py
CHANGED
@@ -6,9 +6,9 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
+
from typing import Tuple, List
|
9
10
|
from importlib.util import find_spec
|
10
11
|
from contextlib import contextmanager
|
11
|
-
from typing import Tuple, Callable, List
|
12
12
|
|
13
13
|
from .npfftw_backend import NumpyFFTWBackend
|
14
14
|
from ..types import CupyArray, NDArray, shm_type
|
@@ -81,6 +81,17 @@ class CupyBackend(NumpyFFTWBackend):
|
|
81
81
|
""",
|
82
82
|
"norm_scores",
|
83
83
|
)
|
84
|
+
|
85
|
+
# Sum of square computation similar to the demeaned variance in pytom
|
86
|
+
self.ssum = cp.ReductionKernel(
|
87
|
+
f"{ftype} arr",
|
88
|
+
f"{ftype} ret",
|
89
|
+
"arr * arr",
|
90
|
+
"a + b",
|
91
|
+
"ret = a",
|
92
|
+
"0",
|
93
|
+
f"ssum_{ftype}",
|
94
|
+
)
|
84
95
|
self.texture_available = find_spec("voltools") is not None
|
85
96
|
|
86
97
|
def to_backend_array(self, arr: NDArray) -> CupyArray:
|
@@ -111,53 +122,14 @@ class CupyBackend(NumpyFFTWBackend):
|
|
111
122
|
def unravel_index(self, indices, shape):
    """Delegate to the array backend's ``unravel_index``, passing ``shape`` as ``dims``."""
    unravel = self._array_backend.unravel_index
    return unravel(indices=indices, dims=shape)
|
113
124
|
|
114
|
-
def
|
115
|
-
self
|
116
|
-
fwd_shape: Tuple[int],
|
117
|
-
inv_shape: Tuple[int],
|
118
|
-
inv_output_shape: Tuple[int] = None,
|
119
|
-
fwd_axes: Tuple[int] = None,
|
120
|
-
inv_axes: Tuple[int] = None,
|
121
|
-
**kwargs,
|
122
|
-
) -> Tuple[Callable, Callable]:
|
123
|
-
cache = self._array_backend.fft.config.get_plan_cache()
|
124
|
-
current_device = self._array_backend.cuda.device.get_device_id()
|
125
|
-
|
126
|
-
previous_transform = [fwd_shape, inv_shape]
|
127
|
-
if current_device in PLAN_CACHE:
|
128
|
-
previous_transform = PLAN_CACHE[current_device]
|
129
|
-
|
130
|
-
real_diff, cmplx_diff = True, True
|
131
|
-
if len(fwd_shape) == len(previous_transform[0]):
|
132
|
-
real_diff = fwd_shape == previous_transform[0]
|
133
|
-
if len(inv_shape) == len(previous_transform[1]):
|
134
|
-
cmplx_diff = inv_shape == previous_transform[1]
|
135
|
-
|
136
|
-
if real_diff or cmplx_diff:
|
137
|
-
cache.clear()
|
138
|
-
|
139
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
140
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
141
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
142
|
-
|
143
|
-
def rfftn(
|
144
|
-
arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
|
145
|
-
) -> CupyArray:
|
146
|
-
return self.rfftn(arr, s=s, axes=fwd_axes, overwrite_x=True)
|
147
|
-
|
148
|
-
def irfftn(
|
149
|
-
arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
|
150
|
-
) -> CupyArray:
|
151
|
-
return self.irfftn(arr, s=s, axes=inv_axes, overwrite_x=True)
|
152
|
-
|
153
|
-
PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
|
154
|
-
return rfftn, irfftn
|
125
|
+
def free_cache(self):
    """Clear the plan cache returned by the array backend's FFT config."""
    plan_cache = self._array_backend.fft.config.get_plan_cache()
    plan_cache.clear()
|
155
127
|
|
156
128
|
def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
    """Run ``self._cufft.rfftn`` on ``arr``; ``out`` is accepted but not forwarded."""
    forward_transform = self._cufft.rfftn
    return forward_transform(arr, **kwargs)
|
158
130
|
|
159
131
|
def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
|
160
|
-
return self._cufft.irfftn(arr, **kwargs)
|
132
|
+
return self._cufft.irfftn(arr, **kwargs).astype(self._float_dtype)
|
161
133
|
|
162
134
|
def compute_convolution_shapes(
|
163
135
|
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
@@ -178,17 +150,6 @@ class CupyBackend(NumpyFFTWBackend):
|
|
178
150
|
peaks = self._array_backend.array(self._array_backend.nonzero(max_filter)).T
|
179
151
|
return peaks
|
180
152
|
|
181
|
-
# The default methods in Cupy were oddly slow
|
182
|
-
def var(self, a, *args, **kwargs):
|
183
|
-
out = a - self._array_backend.mean(a, *args, **kwargs)
|
184
|
-
self._array_backend.square(out, out)
|
185
|
-
out = self._array_backend.mean(out, *args, **kwargs)
|
186
|
-
return out
|
187
|
-
|
188
|
-
def std(self, a, *args, **kwargs):
|
189
|
-
out = self.var(a, *args, **kwargs)
|
190
|
-
return self._array_backend.sqrt(out)
|
191
|
-
|
192
153
|
def _get_texture(self, arr: CupyArray, order: int = 3, prefilter: bool = False):
|
193
154
|
key = id(arr)
|
194
155
|
if key in TEXTURE_CACHE:
|
@@ -235,7 +196,7 @@ class CupyBackend(NumpyFFTWBackend):
|
|
235
196
|
)
|
236
197
|
return None
|
237
198
|
|
238
|
-
if data.ndim == 3 and cache and self.texture_available:
|
199
|
+
if data.ndim == 3 and cache and self.texture_available and not batched:
|
239
200
|
# Device memory pool (should) come to rescue performance
|
240
201
|
temp = self.zeros(data.shape, data.dtype)
|
241
202
|
texture = self._get_texture(data, order=order, prefilter=prefilter)
|
tme/backends/jax_backend.py
CHANGED
@@ -7,7 +7,9 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
from functools import wraps
|
10
|
-
from typing import Tuple, List,
|
10
|
+
from typing import Tuple, List, Dict, Any
|
11
|
+
|
12
|
+
import numpy as np
|
11
13
|
|
12
14
|
from ..types import BackendArray
|
13
15
|
from .npfftw_backend import NumpyFFTWBackend, shm_type
|
@@ -64,6 +66,10 @@ class JaxBackend(NumpyFFTWBackend):
|
|
64
66
|
arr = arr.at[idx].set(value)
|
65
67
|
return arr
|
66
68
|
|
69
|
+
def addat(self, arr, indices, values):
    """Return the result of ``arr.at[indices].add(values)``."""
    return arr.at[indices].add(values)
|
72
|
+
|
67
73
|
def topleft_pad(
|
68
74
|
self, arr: BackendArray, shape: Tuple[int], padval: int = 0
|
69
75
|
) -> BackendArray:
|
@@ -88,6 +94,7 @@ class JaxBackend(NumpyFFTWBackend):
|
|
88
94
|
"sqrt",
|
89
95
|
"maximum",
|
90
96
|
"exp",
|
97
|
+
"mod",
|
91
98
|
]
|
92
99
|
for ufunc in ufuncs:
|
93
100
|
backend_method = emulate_out(getattr(self._array_backend, ufunc))
|
@@ -103,27 +110,6 @@ class JaxBackend(NumpyFFTWBackend):
|
|
103
110
|
shape=arr.shape, dtype=arr.dtype, fill_value=value
|
104
111
|
)
|
105
112
|
|
106
|
-
def build_fft(
|
107
|
-
self,
|
108
|
-
fwd_shape: Tuple[int],
|
109
|
-
inv_shape: Tuple[int] = None,
|
110
|
-
inv_output_shape: Tuple[int] = None,
|
111
|
-
fwd_axes: Tuple[int] = None,
|
112
|
-
inv_axes: Tuple[int] = None,
|
113
|
-
**kwargs,
|
114
|
-
) -> Tuple[Callable, Callable]:
|
115
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
116
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
117
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
118
|
-
|
119
|
-
def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
|
120
|
-
return self._array_backend.fft.rfftn(arr, s=s, axes=axes)
|
121
|
-
|
122
|
-
def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
|
123
|
-
return self._array_backend.fft.irfftn(arr, s=s, axes=axes)
|
124
|
-
|
125
|
-
return rfftn, irfftn
|
126
|
-
|
127
113
|
def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
    """Run the array backend's ``fft.rfftn`` on ``arr``; positional extras are ignored."""
    forward_transform = self._array_backend.fft.rfftn
    return forward_transform(arr, **kwargs)
|
129
115
|
|
@@ -194,12 +180,37 @@ class JaxBackend(NumpyFFTWBackend):
|
|
194
180
|
|
195
181
|
return convolution_shape, fast_shape, fast_ft_shape
|
196
182
|
|
183
|
+
def _to_hashable(self, obj: Any) -> Tuple[str, Tuple]:
|
184
|
+
if isinstance(obj, np.ndarray):
|
185
|
+
return ("array", (tuple(obj.flatten().tolist()), obj.shape))
|
186
|
+
return ("other", obj)
|
187
|
+
|
188
|
+
def _from_hashable(self, type_info: str, data: Any) -> Any:
|
189
|
+
if type_info == "array":
|
190
|
+
data, shape = data
|
191
|
+
return self.array(data).reshape(shape)
|
192
|
+
return data
|
193
|
+
|
194
|
+
def _dict_to_tuple(self, data: Dict) -> Tuple:
|
195
|
+
return tuple((k, self._to_hashable(v)) for k, v in data.items())
|
196
|
+
|
197
|
+
def _tuple_to_dict(self, data: Tuple) -> Dict:
|
198
|
+
return {x[0]: self._from_hashable(*x[1]) for x in data}
|
199
|
+
|
200
|
+
def _unbatch(self, data, target_ndim, index):
|
201
|
+
if not isinstance(data, type(self.zeros(1))):
|
202
|
+
return data
|
203
|
+
elif data.ndim <= target_ndim:
|
204
|
+
return data
|
205
|
+
return data[index]
|
206
|
+
|
197
207
|
def scan(
|
198
208
|
self,
|
199
209
|
matching_data: type,
|
200
210
|
splits: Tuple[Tuple[slice, slice]],
|
201
211
|
n_jobs: int,
|
202
|
-
callback_class,
|
212
|
+
callback_class: object,
|
213
|
+
callback_class_args: Dict,
|
203
214
|
rotate_mask: bool = False,
|
204
215
|
**kwargs,
|
205
216
|
) -> List:
|
@@ -207,12 +218,14 @@ class JaxBackend(NumpyFFTWBackend):
|
|
207
218
|
Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
|
208
219
|
:py:class:`tme.analyzer.MaxScoreOverRotations`.
|
209
220
|
"""
|
210
|
-
from ._jax_utils import
|
221
|
+
from ._jax_utils import setup_scan
|
222
|
+
from ..analyzer import MaxScoreOverRotations
|
211
223
|
|
212
224
|
pad_target = True if len(splits) > 1 else False
|
213
225
|
convolution_mode = "valid" if pad_target else "same"
|
214
226
|
target_pad = matching_data.target_padding(pad_target=pad_target)
|
215
227
|
|
228
|
+
score_mask = self.full((1,), fill_value=1, dtype=bool)
|
216
229
|
target_shape = tuple(
|
217
230
|
(x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
|
218
231
|
)
|
@@ -220,16 +233,20 @@ class JaxBackend(NumpyFFTWBackend):
|
|
220
233
|
target_shape=self.to_numpy_array(target_shape),
|
221
234
|
template_shape=self.to_numpy_array(matching_data._template.shape),
|
222
235
|
batch_mask=self.to_numpy_array(matching_data._batch_mask),
|
223
|
-
pad_target=pad_target,
|
224
236
|
)
|
225
237
|
analyzer_args = {
|
226
|
-
"
|
238
|
+
"shape": fast_shape,
|
227
239
|
"fourier_shift": shift,
|
240
|
+
"fast_shape": fast_shape,
|
228
241
|
"targetshape": target_shape,
|
229
242
|
"templateshape": matching_data.template.shape,
|
230
243
|
"convolution_shape": conv_shape,
|
244
|
+
"convolution_mode": convolution_mode,
|
245
|
+
"thread_safe": False,
|
246
|
+
"aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
|
247
|
+
"n_rotations": matching_data.rotations.shape[0],
|
248
|
+
"jax_mode": True,
|
231
249
|
}
|
232
|
-
|
233
250
|
create_target_filter = matching_data.target_filter is not None
|
234
251
|
create_template_filter = matching_data.template_filter is not None
|
235
252
|
create_filter = create_target_filter or create_template_filter
|
@@ -245,6 +262,9 @@ class JaxBackend(NumpyFFTWBackend):
|
|
245
262
|
for i in range(matching_data.rotations.shape[0])
|
246
263
|
}
|
247
264
|
for split_start in range(0, len(splits), n_jobs):
|
265
|
+
|
266
|
+
analyzer_kwargs = []
|
267
|
+
|
248
268
|
split_subset = splits[split_start : (split_start + n_jobs)]
|
249
269
|
if not len(split_subset):
|
250
270
|
continue
|
@@ -256,8 +276,17 @@ class JaxBackend(NumpyFFTWBackend):
|
|
256
276
|
target_pad=target_pad,
|
257
277
|
template_slice=template_split,
|
258
278
|
)
|
279
|
+
cur_args = analyzer_args.copy()
|
280
|
+
cur_args["offset"] = translation_offset
|
281
|
+
cur_args.update(callback_class_args)
|
282
|
+
analyzer_kwargs.append(cur_args)
|
283
|
+
|
284
|
+
if pad_target:
|
285
|
+
score_mask = base._score_mask(fast_shape, shift)
|
286
|
+
|
287
|
+
_target = self.astype(base._target, self._float_dtype)
|
259
288
|
translation_offsets.append(translation_offset)
|
260
|
-
targets.append(self.topleft_pad(
|
289
|
+
targets.append(self.topleft_pad(_target, fast_shape))
|
261
290
|
|
262
291
|
if create_filter:
|
263
292
|
filter_args = {
|
@@ -279,24 +308,34 @@ class JaxBackend(NumpyFFTWBackend):
|
|
279
308
|
|
280
309
|
create_filter, create_template_filter, create_target_filter = (False,) * 3
|
281
310
|
base, targets = None, self._array_backend.stack(targets)
|
282
|
-
|
311
|
+
|
312
|
+
scan_inner = setup_scan(
|
313
|
+
analyzer_kwargs=analyzer_kwargs,
|
314
|
+
callback_class=callback_class,
|
315
|
+
fast_shape=fast_shape,
|
316
|
+
rotate_mask=rotate_mask,
|
317
|
+
)
|
318
|
+
|
319
|
+
states = scan_inner(
|
283
320
|
self.astype(targets, self._float_dtype),
|
284
|
-
matching_data.template,
|
285
|
-
matching_data.template_mask,
|
321
|
+
self.astype(matching_data.template, self._float_dtype),
|
322
|
+
self.astype(matching_data.template_mask, self._float_dtype),
|
286
323
|
matching_data.rotations,
|
287
324
|
template_filter,
|
288
325
|
target_filter,
|
289
|
-
|
290
|
-
rotate_mask,
|
326
|
+
score_mask,
|
291
327
|
)
|
292
328
|
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
)
|
298
|
-
|
299
|
-
|
329
|
+
ndim = targets.ndim - 1
|
330
|
+
for index in range(targets.shape[0]):
|
331
|
+
kwargs = analyzer_kwargs[index]
|
332
|
+
analyzer = callback_class(**kwargs)
|
333
|
+
state = [self._unbatch(x, ndim, index) for x in states]
|
334
|
+
|
335
|
+
if isinstance(analyzer, MaxScoreOverRotations):
|
336
|
+
state[2] = rotation_mapping
|
337
|
+
|
338
|
+
ret.append(analyzer.result(state, **kwargs))
|
300
339
|
return ret
|
301
340
|
|
302
341
|
def get_available_memory(self) -> int:
|
tme/backends/matching_backend.py
CHANGED
@@ -8,7 +8,7 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
|
8
8
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
10
|
from multiprocessing import shared_memory
|
11
|
-
from typing import Tuple, Callable, List, Any, Union, Optional, Generator
|
11
|
+
from typing import Tuple, Callable, List, Any, Union, Optional, Generator
|
12
12
|
|
13
13
|
from ..types import BackendArray, NDArray, Scalar, shm_type
|
14
14
|
|
@@ -863,6 +863,17 @@ class MatchingBackend(ABC):
|
|
863
863
|
Indices of ``k`` largest elements in ``arr``.
|
864
864
|
"""
|
865
865
|
|
866
|
+
@abstractmethod
def ssum(self, arr, *args, **kwargs) -> BackendArray:
    """
    Compute the sum of squares of ``arr``.

    Parameters
    ----------
    arr
        Input array whose squared elements are summed.
    *args, **kwargs
        Additional arguments; their meaning is not specified by this
        abstract declaration and is left to the implementing backend.

    Returns
    -------
    BackendArray
        Sum of squares with shape ().
    """
|
876
|
+
|
866
877
|
def indices(self, *args, **kwargs) -> BackendArray:
|
867
878
|
"""
|
868
879
|
Creates an array representing the index grid of an input.
|
@@ -1087,57 +1098,12 @@ class MatchingBackend(ABC):
|
|
1087
1098
|
"""
|
1088
1099
|
|
1089
1100
|
@abstractmethod
|
1090
|
-
def
|
1091
|
-
|
1092
|
-
fwd_shape: Tuple[int],
|
1093
|
-
inv_shape: Tuple[int],
|
1094
|
-
real_dtype: type,
|
1095
|
-
cmpl_dtype: type,
|
1096
|
-
inv_output_shape: Tuple[int] = None,
|
1097
|
-
temp_fwd: NDArray = None,
|
1098
|
-
temp_inv: NDArray = None,
|
1099
|
-
fwd_axes: Tuple[int] = None,
|
1100
|
-
inv_axes: Tuple[int] = None,
|
1101
|
-
fftargs: Dict = {},
|
1102
|
-
) -> Tuple[Callable, Callable]:
|
1103
|
-
"""
|
1104
|
-
Build forward and inverse real fourier transform functions. The returned
|
1105
|
-
callables have two parameters ``arr`` and ``out`` which correspond to the
|
1106
|
-
input and output of the Fourier transform. The methods return the output
|
1107
|
-
of the respective function call, regardless of ``out`` being provided or not,
|
1108
|
-
analogous to most numpy functions.
|
1101
|
+
def rfftn(self, **kwargs):
    """
    Perform an n-D real FFT.

    Parameters
    ----------
    **kwargs
        Keyword arguments interpreted by the implementing backend.
    """
    # NOTE(review): only ``**kwargs`` is declared here; confirm this matches
    # how implementing backends receive the input array.
|
1109
1103
|
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
Input shape for the forward Fourier transform.
|
1114
|
-
(see `compute_convolution_shapes`).
|
1115
|
-
inv_shape : tuple
|
1116
|
-
Input shape for the inverse Fourier transform.
|
1117
|
-
real_dtype : dtype
|
1118
|
-
Data type of the forward Fourier transform.
|
1119
|
-
complex_dtype : dtype
|
1120
|
-
Data type of the inverse Fourier transform.
|
1121
|
-
inv_output_shape : tuple, optional
|
1122
|
-
Output shape of the inverse Fourier transform. By default fast_shape.
|
1123
|
-
fftargs : dict, optional
|
1124
|
-
Dictionary passed to pyFFTW builders.
|
1125
|
-
temp_fwd : NDArray, optional
|
1126
|
-
Temporary array to build the forward transform. Superseeds shape defined by
|
1127
|
-
fwd_shape if provided.
|
1128
|
-
temp_inv : NDArray, optional
|
1129
|
-
Temporary array to build the inverse transform. Superseeds shape defined by
|
1130
|
-
inv_shape if provided.
|
1131
|
-
fwd_axes : tuple of int
|
1132
|
-
Axes to perform the forward Fourier transform over.
|
1133
|
-
inv_axes : tuple of int
|
1134
|
-
Axes to perform the inverse Fourier transform over.
|
1135
|
-
|
1136
|
-
Returns
|
1137
|
-
-------
|
1138
|
-
tuple
|
1139
|
-
Tuple of callables for forward and inverse real Fourier transform.
|
1140
|
-
"""
|
1104
|
+
@abstractmethod
def irfftn(self, **kwargs):
    """
    Perform an n-D real inverse FFT.

    Parameters
    ----------
    **kwargs
        Keyword arguments interpreted by the implementing backend.
    """
    # NOTE(review): only ``**kwargs`` is declared here; confirm this matches
    # how implementing backends receive the input array.
|
1141
1107
|
|
1142
1108
|
def extract_center(self, arr: BackendArray, newshape: Tuple[int]) -> BackendArray:
|
1143
1109
|
"""
|