pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
- pytme-0.3.1.dist-info/RECORD +133 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
- pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +177 -226
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +69 -47
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +98 -28
- scripts/pytme_runner.py +1223 -0
- scripts/refine_matches.py +156 -387
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +4 -9
- tests/test_density.py +0 -1
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +11 -61
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +5 -8
- tme/analyzer/aggregation.py +28 -9
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +49 -122
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +31 -28
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +11 -54
- tme/backends/jax_backend.py +72 -48
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +95 -90
- tme/backends/pytorch_backend.py +5 -26
- tme/density.py +7 -10
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +138 -87
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +98 -112
- tme/filters/whitening.py +1 -6
- tme/mask.py +341 -0
- tme/matching_data.py +20 -44
- tme/matching_exhaustive.py +46 -56
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +216 -412
- tme/matching_utils.py +82 -424
- tme/memory.py +1 -1
- tme/orientations.py +16 -8
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- tme/rotations.py +1 -1
- pytme-0.3b0.dist-info/RECORD +0 -122
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/backends/_jax_utils.py
CHANGED
@@ -10,16 +10,19 @@ from typing import Tuple
|
|
10
10
|
from functools import partial
|
11
11
|
|
12
12
|
import jax.numpy as jnp
|
13
|
-
from jax import pmap, lax
|
13
|
+
from jax import pmap, lax, vmap
|
14
14
|
|
15
15
|
from ..types import BackendArray
|
16
16
|
from ..backends import backend as be
|
17
17
|
from ..matching_utils import normalize_template as _normalize_template
|
18
18
|
|
19
19
|
|
20
|
+
__all__ = ["scan"]
|
21
|
+
|
22
|
+
|
20
23
|
def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
21
24
|
"""
|
22
|
-
Computes :py:meth:`tme.
|
25
|
+
Computes :py:meth:`tme.matching_scores.cc_setup`.
|
23
26
|
"""
|
24
27
|
template_ft = jnp.fft.rfftn(template, s=template.shape)
|
25
28
|
template_ft = template_ft.at[:].multiply(ft_target)
|
@@ -28,18 +31,17 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
|
28
31
|
|
29
32
|
|
30
33
|
def _flc_scoring(
|
31
|
-
template: BackendArray,
|
32
|
-
template_mask: BackendArray,
|
33
34
|
ft_target: BackendArray,
|
34
35
|
ft_target2: BackendArray,
|
36
|
+
template: BackendArray,
|
37
|
+
template_mask: BackendArray,
|
35
38
|
n_observations: BackendArray,
|
36
39
|
eps: float,
|
37
40
|
**kwargs,
|
38
41
|
) -> BackendArray:
|
39
42
|
"""
|
40
|
-
Computes :py:meth:`tme.
|
43
|
+
Computes :py:meth:`tme.matching_scores.flc_scoring`.
|
41
44
|
"""
|
42
|
-
correlation = _correlate(template=template, ft_target=ft_target)
|
43
45
|
inv_denominator = _reciprocal_target_std(
|
44
46
|
ft_target=ft_target,
|
45
47
|
ft_target2=ft_target2,
|
@@ -47,18 +49,17 @@ def _flc_scoring(
|
|
47
49
|
eps=eps,
|
48
50
|
n_observations=n_observations,
|
49
51
|
)
|
50
|
-
|
51
|
-
return correlation
|
52
|
+
return _flcSphere_scoring(ft_target, template, inv_denominator)
|
52
53
|
|
53
54
|
|
54
55
|
def _flcSphere_scoring(
|
55
|
-
template: BackendArray,
|
56
56
|
ft_target: BackendArray,
|
57
|
+
template: BackendArray,
|
57
58
|
inv_denominator: BackendArray,
|
58
59
|
**kwargs,
|
59
60
|
) -> BackendArray:
|
60
61
|
"""
|
61
|
-
Computes :py:meth:`tme.
|
62
|
+
Computes :py:meth:`tme.matching_scores.corr_scoring`.
|
62
63
|
"""
|
63
64
|
correlation = _correlate(template=template, ft_target=ft_target)
|
64
65
|
correlation = correlation.at[:].multiply(inv_denominator)
|
@@ -77,7 +78,7 @@ def _reciprocal_target_std(
|
|
77
78
|
|
78
79
|
See Also
|
79
80
|
--------
|
80
|
-
:py:meth:`tme.
|
81
|
+
:py:meth:`tme.matching_scores.flc_scoring`.
|
81
82
|
"""
|
82
83
|
ft_shape = template_mask.shape
|
83
84
|
ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
|
@@ -114,7 +115,8 @@ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
|
|
114
115
|
@partial(
|
115
116
|
pmap,
|
116
117
|
in_axes=(0,) + (None,) * 6,
|
117
|
-
static_broadcasted_argnums=[6, 7],
|
118
|
+
static_broadcasted_argnums=[6, 7, 8, 9],
|
119
|
+
axis_name="batch",
|
118
120
|
)
|
119
121
|
def scan(
|
120
122
|
target: BackendArray,
|
@@ -125,9 +127,17 @@ def scan(
|
|
125
127
|
target_filter: BackendArray,
|
126
128
|
fast_shape: Tuple[int],
|
127
129
|
rotate_mask: bool,
|
130
|
+
analyzer_class: object,
|
131
|
+
analyzer_kwargs: Tuple[Tuple],
|
128
132
|
) -> Tuple[BackendArray, BackendArray]:
|
129
133
|
eps = jnp.finfo(template.dtype).resolution
|
130
134
|
|
135
|
+
kwargs = lax.switch(
|
136
|
+
lax.axis_index("batch"),
|
137
|
+
[lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
|
138
|
+
)
|
139
|
+
analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
|
140
|
+
|
131
141
|
if hasattr(target_filter, "shape"):
|
132
142
|
target = _apply_fourier_filter(target, target_filter)
|
133
143
|
|
@@ -150,7 +160,7 @@ def scan(
|
|
150
160
|
_template_filter_func = _apply_fourier_filter
|
151
161
|
|
152
162
|
def _sample_transform(ret, rotation_matrix):
|
153
|
-
|
163
|
+
state, index = ret
|
154
164
|
template_rot, template_mask_rot = be.rigid_transform(
|
155
165
|
arr=template,
|
156
166
|
arr_mask=template_mask,
|
@@ -163,27 +173,20 @@ def scan(
|
|
163
173
|
template_rot = _normalize_template(
|
164
174
|
template_rot, template_mask_rot, n_observations
|
165
175
|
)
|
166
|
-
|
167
|
-
|
176
|
+
rot_pad = be.topleft_pad(template_rot, fast_shape)
|
177
|
+
mask_rot_pad = be.topleft_pad(template_mask_rot, fast_shape)
|
168
178
|
|
169
179
|
scores = scoring_func(
|
170
|
-
template=
|
171
|
-
template_mask=
|
180
|
+
template=rot_pad,
|
181
|
+
template_mask=mask_rot_pad,
|
172
182
|
ft_target=ft_target,
|
173
183
|
ft_target2=ft_target2,
|
174
184
|
inv_denominator=inv_denominator,
|
175
185
|
n_observations=n_observations,
|
176
186
|
eps=eps,
|
177
187
|
)
|
178
|
-
|
179
|
-
|
180
|
-
)
|
181
|
-
return (max_scores, rotations, index + 1), None
|
182
|
-
|
183
|
-
score_space = jnp.zeros(fast_shape)
|
184
|
-
rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
|
185
|
-
(score_space, rotation_space, _), _ = lax.scan(
|
186
|
-
_sample_transform, (score_space, rotation_space, 0), rotations
|
187
|
-
)
|
188
|
+
state = analyzer(state, scores, rotation_matrix, rotation_index=index)
|
189
|
+
return (state, index + 1), None
|
188
190
|
|
189
|
-
|
191
|
+
(state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
|
192
|
+
return state
|
@@ -0,0 +1,270 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
#
|
3
|
+
# Henry Gomersall
|
4
|
+
# heng@kedevelopments.co.uk
|
5
|
+
#
|
6
|
+
# All rights reserved.
|
7
|
+
#
|
8
|
+
# Redistribution and use in source and binary forms, with or without
|
9
|
+
# modification, are permitted provided that the following conditions are met:
|
10
|
+
#
|
11
|
+
# * Redistributions of source code must retain the above copyright notice, this
|
12
|
+
# list of conditions and the following disclaimer.
|
13
|
+
#
|
14
|
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
15
|
+
# this list of conditions and the following disclaimer in the documentation
|
16
|
+
# and/or other materials provided with the distribution.
|
17
|
+
#
|
18
|
+
# * Neither the name of the copyright holder nor the names of its contributors
|
19
|
+
# may be used to endorse or promote products derived from this software without
|
20
|
+
# specific prior written permission.
|
21
|
+
#
|
22
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
23
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
24
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
25
|
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
26
|
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
27
|
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
28
|
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
29
|
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
30
|
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
31
|
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
32
|
+
# POSSIBILITY OF SUCH DAMAGE.
|
33
|
+
#
|
34
|
+
|
35
|
+
# This code has been adapted to add support for the out argument in rfftn, irfftn
|
36
|
+
# to allow for reusing existing array buffers
|
37
|
+
# Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
38
|
+
|
39
|
+
import threading
|
40
|
+
|
41
|
+
import pyfftw
|
42
|
+
import numpy as np
|
43
|
+
import pyfftw.builders as builders
|
44
|
+
from pyfftw.interfaces import cache
|
45
|
+
from pyfftw.builders._utils import _norm_args, _default_effort, _default_threads
|
46
|
+
|
47
|
+
|
48
|
+
def _Xfftn(
|
49
|
+
a,
|
50
|
+
s,
|
51
|
+
axes,
|
52
|
+
overwrite_input,
|
53
|
+
planner_effort,
|
54
|
+
threads,
|
55
|
+
auto_align_input,
|
56
|
+
auto_contiguous,
|
57
|
+
calling_func,
|
58
|
+
normalise_idft=True,
|
59
|
+
ortho=False,
|
60
|
+
real_direction_flag=None,
|
61
|
+
output_array=None,
|
62
|
+
):
|
63
|
+
|
64
|
+
work_with_copy = False
|
65
|
+
|
66
|
+
a = np.asanyarray(a)
|
67
|
+
|
68
|
+
try:
|
69
|
+
s = tuple(s)
|
70
|
+
except TypeError:
|
71
|
+
pass
|
72
|
+
|
73
|
+
try:
|
74
|
+
axes = tuple(axes)
|
75
|
+
except TypeError:
|
76
|
+
pass
|
77
|
+
|
78
|
+
if calling_func in ("dct", "dst"):
|
79
|
+
# real-to-real transforms require passing an additional flag argument
|
80
|
+
avoid_copy = False
|
81
|
+
args = (
|
82
|
+
overwrite_input,
|
83
|
+
planner_effort,
|
84
|
+
threads,
|
85
|
+
auto_align_input,
|
86
|
+
auto_contiguous,
|
87
|
+
avoid_copy,
|
88
|
+
real_direction_flag,
|
89
|
+
)
|
90
|
+
elif calling_func in ("irfft2", "irfftn"):
|
91
|
+
# overwrite_input is not an argument to irfft2 or irfftn
|
92
|
+
args = (planner_effort, threads, auto_align_input, auto_contiguous)
|
93
|
+
|
94
|
+
if not overwrite_input:
|
95
|
+
# Only irfft2 and irfftn have overwriting the input
|
96
|
+
# as the default (and so require the input array to
|
97
|
+
# be reloaded).
|
98
|
+
work_with_copy = True
|
99
|
+
else:
|
100
|
+
args = (
|
101
|
+
overwrite_input,
|
102
|
+
planner_effort,
|
103
|
+
threads,
|
104
|
+
auto_align_input,
|
105
|
+
auto_contiguous,
|
106
|
+
)
|
107
|
+
|
108
|
+
if not a.flags.writeable:
|
109
|
+
# Special case of a locked array - always work with a
|
110
|
+
# copy. See issue #92.
|
111
|
+
work_with_copy = True
|
112
|
+
|
113
|
+
if overwrite_input:
|
114
|
+
raise ValueError(
|
115
|
+
"overwrite_input cannot be True when the "
|
116
|
+
+ "input array flags.writeable is False"
|
117
|
+
)
|
118
|
+
|
119
|
+
if work_with_copy:
|
120
|
+
# We make the copy before registering the key so that the
|
121
|
+
# copy's stride information will be cached since this will be
|
122
|
+
# used for planning. Make sure the copy is byte aligned to
|
123
|
+
# prevent further copying
|
124
|
+
a_original = a
|
125
|
+
a = pyfftw.empty_aligned(shape=a.shape, dtype=a.dtype)
|
126
|
+
a[...] = a_original
|
127
|
+
|
128
|
+
if cache.is_enabled():
|
129
|
+
alignment = a.ctypes.data % pyfftw.simd_alignment
|
130
|
+
|
131
|
+
key = (
|
132
|
+
calling_func,
|
133
|
+
a.shape,
|
134
|
+
a.strides,
|
135
|
+
a.dtype,
|
136
|
+
s.__hash__(),
|
137
|
+
axes.__hash__(),
|
138
|
+
alignment,
|
139
|
+
args,
|
140
|
+
threading.get_ident(),
|
141
|
+
)
|
142
|
+
|
143
|
+
try:
|
144
|
+
if key in cache._fftw_cache:
|
145
|
+
FFTW_object = cache._fftw_cache.lookup(key)
|
146
|
+
else:
|
147
|
+
FFTW_object = None
|
148
|
+
|
149
|
+
except KeyError:
|
150
|
+
# This occurs if the object has fallen out of the cache between
|
151
|
+
# the check and the lookup
|
152
|
+
FFTW_object = None
|
153
|
+
|
154
|
+
if not cache.is_enabled() or FFTW_object is None:
|
155
|
+
|
156
|
+
# If we're going to create a new FFTW object and are not
|
157
|
+
# working with a copy, then we need to copy the input array to
|
158
|
+
# preserve it, otherwise we can't actually take the transform
|
159
|
+
# of the input array! (in general, we have to assume that the
|
160
|
+
# input array will be destroyed during planning).
|
161
|
+
if not work_with_copy:
|
162
|
+
a_copy = a.copy()
|
163
|
+
|
164
|
+
planner_args = (a, s, axes) + args
|
165
|
+
|
166
|
+
FFTW_object = getattr(builders, calling_func)(*planner_args)
|
167
|
+
|
168
|
+
# Only copy if the input array is what was actually used
|
169
|
+
# (otherwise it shouldn't be overwritten)
|
170
|
+
if not work_with_copy and FFTW_object.input_array is a:
|
171
|
+
a[:] = a_copy
|
172
|
+
|
173
|
+
if cache.is_enabled():
|
174
|
+
cache._fftw_cache.insert(FFTW_object, key)
|
175
|
+
|
176
|
+
output_array = FFTW_object(normalise_idft=normalise_idft, ortho=ortho)
|
177
|
+
|
178
|
+
else:
|
179
|
+
orig_output_array = FFTW_object.output_array
|
180
|
+
output_shape = orig_output_array.shape
|
181
|
+
output_dtype = orig_output_array.dtype
|
182
|
+
output_alignment = FFTW_object.output_alignment
|
183
|
+
|
184
|
+
if output_array is None:
|
185
|
+
output_array = pyfftw.empty_aligned(
|
186
|
+
output_shape, output_dtype, n=output_alignment
|
187
|
+
)
|
188
|
+
|
189
|
+
FFTW_object(
|
190
|
+
input_array=a,
|
191
|
+
output_array=output_array,
|
192
|
+
normalise_idft=normalise_idft,
|
193
|
+
ortho=ortho,
|
194
|
+
)
|
195
|
+
|
196
|
+
return output_array
|
197
|
+
|
198
|
+
|
199
|
+
def rfftn(
|
200
|
+
a,
|
201
|
+
s=None,
|
202
|
+
axes=None,
|
203
|
+
norm=None,
|
204
|
+
overwrite_input=False,
|
205
|
+
planner_effort=None,
|
206
|
+
threads=None,
|
207
|
+
auto_align_input=True,
|
208
|
+
auto_contiguous=True,
|
209
|
+
out=None,
|
210
|
+
):
|
211
|
+
"""Perform an n-D real FFT.
|
212
|
+
|
213
|
+
The first four arguments are as per :func:`numpy.fft.rfftn`;
|
214
|
+
the rest of the arguments are documented
|
215
|
+
in the :ref:`additional arguments docs<interfaces_additional_args>`.
|
216
|
+
"""
|
217
|
+
calling_func = "rfftn"
|
218
|
+
planner_effort = _default_effort(planner_effort)
|
219
|
+
threads = _default_threads(threads)
|
220
|
+
|
221
|
+
return _Xfftn(
|
222
|
+
a,
|
223
|
+
s,
|
224
|
+
axes,
|
225
|
+
overwrite_input,
|
226
|
+
planner_effort,
|
227
|
+
threads,
|
228
|
+
auto_align_input,
|
229
|
+
auto_contiguous,
|
230
|
+
calling_func,
|
231
|
+
**_norm_args(norm),
|
232
|
+
output_array=out,
|
233
|
+
)
|
234
|
+
|
235
|
+
|
236
|
+
def irfftn(
|
237
|
+
a,
|
238
|
+
s=None,
|
239
|
+
axes=None,
|
240
|
+
norm=None,
|
241
|
+
overwrite_input=False,
|
242
|
+
planner_effort=None,
|
243
|
+
threads=None,
|
244
|
+
auto_align_input=True,
|
245
|
+
auto_contiguous=True,
|
246
|
+
out=None,
|
247
|
+
):
|
248
|
+
"""Perform an n-D real inverse FFT.
|
249
|
+
|
250
|
+
The first four arguments are as per :func:`numpy.fft.rfftn`;
|
251
|
+
the rest of the arguments are documented
|
252
|
+
in the :ref:`additional arguments docs<interfaces_additional_args>`.
|
253
|
+
"""
|
254
|
+
calling_func = "irfftn"
|
255
|
+
planner_effort = _default_effort(planner_effort)
|
256
|
+
threads = _default_threads(threads)
|
257
|
+
|
258
|
+
return _Xfftn(
|
259
|
+
a,
|
260
|
+
s,
|
261
|
+
axes,
|
262
|
+
overwrite_input,
|
263
|
+
planner_effort,
|
264
|
+
threads,
|
265
|
+
auto_align_input,
|
266
|
+
auto_contiguous,
|
267
|
+
calling_func,
|
268
|
+
**_norm_args(norm),
|
269
|
+
output_array=out,
|
270
|
+
)
|
tme/backends/cupy_backend.py
CHANGED
@@ -6,12 +6,9 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
import
|
9
|
+
from typing import Tuple, List
|
10
10
|
from importlib.util import find_spec
|
11
11
|
from contextlib import contextmanager
|
12
|
-
from typing import Tuple, Callable, List
|
13
|
-
|
14
|
-
import numpy as np
|
15
12
|
|
16
13
|
from .npfftw_backend import NumpyFFTWBackend
|
17
14
|
from ..types import CupyArray, NDArray, shm_type
|
@@ -114,54 +111,14 @@ class CupyBackend(NumpyFFTWBackend):
|
|
114
111
|
def unravel_index(self, indices, shape):
|
115
112
|
return self._array_backend.unravel_index(indices=indices, dims=shape)
|
116
113
|
|
117
|
-
def
|
118
|
-
self
|
119
|
-
fwd_shape: Tuple[int],
|
120
|
-
inv_shape: Tuple[int],
|
121
|
-
inv_output_shape: Tuple[int] = None,
|
122
|
-
fwd_axes: Tuple[int] = None,
|
123
|
-
inv_axes: Tuple[int] = None,
|
124
|
-
**kwargs,
|
125
|
-
) -> Tuple[Callable, Callable]:
|
126
|
-
cache = self._array_backend.fft.config.get_plan_cache()
|
127
|
-
current_device = self._array_backend.cuda.device.get_device_id()
|
128
|
-
|
129
|
-
previous_transform = [fwd_shape, inv_shape]
|
130
|
-
if current_device in PLAN_CACHE:
|
131
|
-
previous_transform = PLAN_CACHE[current_device]
|
132
|
-
|
133
|
-
real_diff, cmplx_diff = True, True
|
134
|
-
if len(fwd_shape) == len(previous_transform[0]):
|
135
|
-
real_diff = fwd_shape == previous_transform[0]
|
136
|
-
if len(inv_shape) == len(previous_transform[1]):
|
137
|
-
cmplx_diff = inv_shape == previous_transform[1]
|
138
|
-
|
139
|
-
if real_diff or cmplx_diff:
|
140
|
-
cache.clear()
|
141
|
-
|
142
|
-
rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
|
143
|
-
irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
|
144
|
-
irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
|
145
|
-
|
146
|
-
def rfftn(
|
147
|
-
arr: CupyArray, out: CupyArray = None, s=rfft_shape, axes=fwd_axes
|
148
|
-
) -> CupyArray:
|
149
|
-
return self.rfftn(arr, s=s, axes=fwd_axes)
|
150
|
-
|
151
|
-
def irfftn(
|
152
|
-
arr: CupyArray, out: CupyArray = None, s=irfft_shape, axes=inv_axes
|
153
|
-
) -> CupyArray:
|
154
|
-
return self.irfftn(arr, s=s, axes=inv_axes)
|
155
|
-
|
156
|
-
PLAN_CACHE[current_device] = [fwd_shape, inv_shape]
|
157
|
-
|
158
|
-
return rfftn, irfftn
|
114
|
+
def free_cache(self):
|
115
|
+
self._array_backend.fft.config.get_plan_cache().clear()
|
159
116
|
|
160
117
|
def rfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
|
161
118
|
return self._cufft.rfftn(arr, **kwargs)
|
162
119
|
|
163
120
|
def irfftn(self, arr: CupyArray, out: CupyArray = None, **kwargs) -> CupyArray:
|
164
|
-
return self._cufft.irfftn(arr, **kwargs)
|
121
|
+
return self._cufft.irfftn(arr, **kwargs).astype(self._float_dtype)
|
165
122
|
|
166
123
|
def compute_convolution_shapes(
|
167
124
|
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
@@ -239,13 +196,13 @@ class CupyBackend(NumpyFFTWBackend):
|
|
239
196
|
)
|
240
197
|
return None
|
241
198
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
199
|
+
if data.ndim == 3 and cache and self.texture_available and not batched:
|
200
|
+
# Device memory pool (should) come to rescue performance
|
201
|
+
temp = self.zeros(data.shape, data.dtype)
|
202
|
+
texture = self._get_texture(data, order=order, prefilter=prefilter)
|
203
|
+
texture.affine(transform_m=matrix, profile=False, output=temp)
|
204
|
+
output[out_slice] = temp
|
205
|
+
return None
|
249
206
|
|
250
207
|
self.affine_transform(
|
251
208
|
input=data,
|