pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +28 -39
- scripts/postprocess.py +23 -10
- scripts/preprocessor_gui.py +95 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +156 -386
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +13 -3
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +15 -13
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +5 -44
- tme/backends/jax_backend.py +58 -37
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +68 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +2 -6
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +7 -19
- tme/matching_exhaustive.py +34 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +206 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +4 -6
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tests/data/.DS_Store
ADDED
Binary file
|
Binary file
|
Binary file
|
tests/data/Raw/.DS_Store
ADDED
Binary file
|
Binary file
|
tests/preprocessing/test_utils.py
CHANGED
@@ -51,6 +51,24 @@ class TestPreprocessUtils:
|
|
51
51
|
assert fgrid.shape == tuple(tilt_shape)
|
52
52
|
assert fgrid.max() <= np.sqrt(1 / sampling_rate * len(shape))
|
53
53
|
|
54
|
+
@pytest.mark.parametrize("shape", ((15, 15, 15), (31, 31, 31), (64, 64, 64)))
|
55
|
+
@pytest.mark.parametrize("sampling_rate", (0.5, 1, 2))
|
56
|
+
@pytest.mark.parametrize("angle", (-5, 0, 5))
|
57
|
+
def test_freqgrid_comparison(self, shape, sampling_rate, angle):
|
58
|
+
grid = frequency_grid_at_angle(
|
59
|
+
shape=shape,
|
60
|
+
angle=angle,
|
61
|
+
sampling_rate=sampling_rate,
|
62
|
+
opening_axis=2,
|
63
|
+
tilt_axis=0,
|
64
|
+
)
|
65
|
+
grid2 = fftfreqn(
|
66
|
+
shape=shape[1:], sampling_rate=sampling_rate, compute_euclidean_norm=True
|
67
|
+
)
|
68
|
+
|
69
|
+
# These should be equal for cubical input shapes
|
70
|
+
assert np.allclose(grid, grid2)
|
71
|
+
|
54
72
|
@pytest.mark.parametrize("n", [10, 100, 1000])
|
55
73
|
@pytest.mark.parametrize("sampling_rate", range(1, 4))
|
56
74
|
def test_fftfreqn(self, n, sampling_rate):
|
tests/test_backends.py
CHANGED
@@ -292,16 +292,10 @@ class TestBackends:
|
|
292
292
|
|
293
293
|
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
294
294
|
@pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
|
295
|
-
def
|
295
|
+
def test_fft(self, backend, fast_shape):
|
296
296
|
_, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
|
297
297
|
fast_shape, (1 for _ in range(len(fast_shape)))
|
298
298
|
)
|
299
|
-
rfftn, irfftn = backend.build_fft(
|
300
|
-
fwd_shape=fast_shape,
|
301
|
-
inv_shape=fast_ft_shape,
|
302
|
-
real_dtype=backend._float_dtype,
|
303
|
-
cmpl_dtype=backend._complex_dtype,
|
304
|
-
)
|
305
299
|
arr = np.random.rand(*fast_shape)
|
306
300
|
out = np.zeros(fast_ft_shape)
|
307
301
|
|
@@ -310,11 +304,11 @@ class TestBackends:
|
|
310
304
|
backend.to_backend_array(out), backend._complex_dtype
|
311
305
|
)
|
312
306
|
|
313
|
-
rfftn(
|
307
|
+
backend.rfftn(
|
314
308
|
backend.astype(backend.to_backend_array(arr), backend._float_dtype),
|
315
309
|
complex_arr,
|
316
310
|
)
|
317
|
-
irfftn(complex_arr, real_arr)
|
311
|
+
backend.irfftn(complex_arr, real_arr)
|
318
312
|
assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)
|
319
313
|
|
320
314
|
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
tests/test_density.py
CHANGED
@@ -98,7 +98,6 @@ class TestDensity:
|
|
98
98
|
assert np.allclose(density.data, self.density.data)
|
99
99
|
assert np.allclose(density.sampling_rate, self.density.sampling_rate)
|
100
100
|
assert np.allclose(density.origin, self.density.origin)
|
101
|
-
assert density.metadata == self.density.metadata
|
102
101
|
|
103
102
|
def test_from_file_baseline(self):
|
104
103
|
self.test_to_file(gzip=False)
|
tests/test_matching_utils.py
CHANGED
@@ -10,9 +10,6 @@ from tme.backends import backend as be
|
|
10
10
|
from tme.memory import MATCHING_MEMORY_REGISTRY
|
11
11
|
from tme.matching_utils import (
|
12
12
|
compute_parallelization_schedule,
|
13
|
-
elliptical_mask,
|
14
|
-
box_mask,
|
15
|
-
tube_mask,
|
16
13
|
create_mask,
|
17
14
|
scramble_phases,
|
18
15
|
apply_convolution_mode,
|
@@ -50,73 +47,26 @@ class TestMatchingUtils:
|
|
50
47
|
max_splits=256,
|
51
48
|
)
|
52
49
|
|
53
|
-
|
50
|
+
@pytest.mark.parametrize("mask_type", ["ellipse", "box", "tube", "membrane"])
|
51
|
+
def test_create_mask(self, mask_type: str):
|
54
52
|
create_mask(
|
55
|
-
mask_type=
|
53
|
+
mask_type=mask_type,
|
56
54
|
shape=self.density.shape,
|
57
55
|
radius=5,
|
58
56
|
center=np.divide(self.density.shape, 2),
|
57
|
+
height=np.max(self.density.shape) // 2,
|
58
|
+
size=np.divide(self.density.shape, 2).astype(int),
|
59
|
+
thickness=2,
|
60
|
+
separation=2,
|
61
|
+
symmetry_axis=1,
|
62
|
+
inner_radius=5,
|
63
|
+
outer_radius=10,
|
59
64
|
)
|
60
65
|
|
61
66
|
def test_create_mask_error(self):
|
62
67
|
with pytest.raises(ValueError):
|
63
68
|
create_mask(mask_type=None)
|
64
69
|
|
65
|
-
def test_elliptical_mask(self):
|
66
|
-
elliptical_mask(
|
67
|
-
shape=self.density.shape,
|
68
|
-
radius=5,
|
69
|
-
center=np.divide(self.density.shape, 2),
|
70
|
-
)
|
71
|
-
|
72
|
-
def test_box_mask(self):
|
73
|
-
box_mask(
|
74
|
-
shape=self.density.shape,
|
75
|
-
height=[5, 10, 20],
|
76
|
-
center=np.divide(self.density.shape, 2),
|
77
|
-
)
|
78
|
-
|
79
|
-
def test_tube_mask(self):
|
80
|
-
tube_mask(
|
81
|
-
shape=self.density.shape,
|
82
|
-
outer_radius=10,
|
83
|
-
inner_radius=5,
|
84
|
-
height=5,
|
85
|
-
base_center=np.divide(self.density.shape, 2),
|
86
|
-
symmetry_axis=1,
|
87
|
-
)
|
88
|
-
|
89
|
-
def test_tube_mask_error(self):
|
90
|
-
with pytest.raises(ValueError):
|
91
|
-
tube_mask(
|
92
|
-
shape=self.density.shape,
|
93
|
-
outer_radius=5,
|
94
|
-
inner_radius=10,
|
95
|
-
height=5,
|
96
|
-
base_center=np.divide(self.density.shape, 2),
|
97
|
-
symmetry_axis=1,
|
98
|
-
)
|
99
|
-
|
100
|
-
with pytest.raises(ValueError):
|
101
|
-
tube_mask(
|
102
|
-
shape=self.density.shape,
|
103
|
-
outer_radius=5,
|
104
|
-
inner_radius=10,
|
105
|
-
height=10 * np.max(self.density.shape),
|
106
|
-
base_center=np.divide(self.density.shape, 2),
|
107
|
-
symmetry_axis=1,
|
108
|
-
)
|
109
|
-
|
110
|
-
with pytest.raises(ValueError):
|
111
|
-
tube_mask(
|
112
|
-
shape=self.density.shape,
|
113
|
-
outer_radius=5,
|
114
|
-
inner_radius=10,
|
115
|
-
height=10 * np.max(self.density.shape),
|
116
|
-
base_center=np.divide(self.density.shape, 2),
|
117
|
-
symmetry_axis=len(self.density.shape) + 1,
|
118
|
-
)
|
119
|
-
|
120
70
|
def test_scramble_phases(self):
|
121
71
|
scramble_phases(arr=self.density.data, noise_proportion=0.5)
|
122
72
|
|
tests/test_rotations.py
CHANGED
@@ -8,6 +8,7 @@ from tme import Density
|
|
8
8
|
from scipy.spatial.transform import Rotation
|
9
9
|
from scipy.signal import correlate
|
10
10
|
|
11
|
+
from tme.mask import elliptical_mask
|
11
12
|
from tme.rotations import (
|
12
13
|
euler_from_rotationmatrix,
|
13
14
|
euler_to_rotationmatrix,
|
@@ -16,7 +17,6 @@ from tme.rotations import (
|
|
16
17
|
get_rotation_matrices,
|
17
18
|
)
|
18
19
|
from tme.matching_utils import (
|
19
|
-
elliptical_mask,
|
20
20
|
split_shape,
|
21
21
|
compute_full_convolution_index,
|
22
22
|
)
|
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.3.
|
1
|
+
__version__ = "0.3.1"
|
tme/analyzer/_utils.py
CHANGED
@@ -93,22 +93,22 @@ def cart_to_score(
|
|
93
93
|
templateshape = be.to_backend_array(templateshape)
|
94
94
|
convolution_shape = be.to_backend_array(convolution_shape)
|
95
95
|
|
96
|
-
# Compute removed padding
|
97
96
|
output_shape = _convmode_to_shape(
|
98
97
|
convolution_mode=convolution_mode,
|
99
98
|
targetshape=targetshape,
|
100
99
|
templateshape=templateshape,
|
101
100
|
convolution_shape=convolution_shape,
|
102
101
|
)
|
103
|
-
valid_positions = be.multiply(positions >= 0, positions < output_shape)
|
104
|
-
valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
|
105
102
|
|
103
|
+
# Offset from padding the target
|
106
104
|
starts = be.astype(
|
107
105
|
be.divide(be.subtract(convolution_shape, output_shape), 2),
|
108
106
|
be._int_dtype,
|
109
107
|
)
|
110
|
-
|
111
108
|
positions = be.add(positions, starts)
|
109
|
+
|
110
|
+
valid_positions = be.multiply(positions >= 0, positions < fast_shape)
|
111
|
+
valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
|
112
112
|
if fourier_shift is not None:
|
113
113
|
fourier_shift = be.to_backend_array(fourier_shift)
|
114
114
|
positions = be.subtract(positions, fourier_shift)
|
tme/analyzer/aggregation.py
CHANGED
@@ -545,13 +545,19 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
545
545
|
)
|
546
546
|
|
547
547
|
def __call__(
|
548
|
-
self,
|
548
|
+
self,
|
549
|
+
state: Tuple,
|
550
|
+
scores: BackendArray,
|
551
|
+
rotation_matrix: BackendArray,
|
552
|
+
**kwargs,
|
549
553
|
) -> Tuple:
|
550
554
|
mask = self._get_constraint(rotation_matrix)
|
551
555
|
mask = self._get_score_mask(mask=mask, scores=scores)
|
552
556
|
|
553
557
|
scores = be.multiply(scores, mask, out=scores)
|
554
|
-
return super().__call__(
|
558
|
+
return super().__call__(
|
559
|
+
state, scores=scores, rotation_matrix=rotation_matrix, **kwargs
|
560
|
+
)
|
555
561
|
|
556
562
|
def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
|
557
563
|
"""
|
@@ -636,7 +642,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
636
642
|
return scores, rotations, {}
|
637
643
|
|
638
644
|
def __call__(
|
639
|
-
self,
|
645
|
+
self,
|
646
|
+
state,
|
647
|
+
scores: BackendArray,
|
648
|
+
rotation_matrix: BackendArray,
|
649
|
+
**kwargs,
|
640
650
|
) -> Tuple:
|
641
651
|
prev_scores, rotations, rotation_mapping = state
|
642
652
|
|
tme/analyzer/peaks.py
CHANGED
@@ -228,10 +228,15 @@ class PeakCaller(AbstractAnalyzer):
|
|
228
228
|
translations = be.full(
|
229
229
|
(self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
|
230
230
|
)
|
231
|
+
|
232
|
+
rdim = len(self.shape)
|
233
|
+
if self.batch_dims:
|
234
|
+
rdim -= len(self.batch_dims)
|
235
|
+
|
231
236
|
rotations = be.full(
|
232
|
-
(self.num_peaks,
|
237
|
+
(self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
|
233
238
|
)
|
234
|
-
for i in range(
|
239
|
+
for i in range(rdim):
|
235
240
|
rotations[:, i, i] = 1.0
|
236
241
|
|
237
242
|
scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
|
@@ -750,8 +755,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
750
755
|
Dictionary mapping values in rotations to Euler angles.
|
751
756
|
By default None
|
752
757
|
min_score : float
|
753
|
-
Minimum score value to consider.
|
754
|
-
by :py:attr:`PeakCaller.num_peaks`.
|
758
|
+
Minimum score value to consider.
|
755
759
|
|
756
760
|
Returns
|
757
761
|
-------
|
@@ -774,10 +778,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
774
778
|
mask = be.to_backend_array(mask)
|
775
779
|
mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
|
776
780
|
|
777
|
-
|
778
|
-
if min_score is not None:
|
779
|
-
peak_limit = be.size(scores)
|
780
|
-
else:
|
781
|
+
if min_score is None:
|
781
782
|
min_score = be.min(scores) - 1
|
782
783
|
|
783
784
|
_scores = be.zeros(scores.shape, dtype=scores.dtype)
|
@@ -815,7 +816,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
815
816
|
score_mask = mask_buffer[tmpl_slice] <= 0.1
|
816
817
|
|
817
818
|
_scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
|
818
|
-
if len(peaks) >=
|
819
|
+
if len(peaks) >= self.num_peaks:
|
819
820
|
break
|
820
821
|
|
821
822
|
return be.to_backend_array(peaks), None
|
@@ -851,7 +852,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
|
|
851
852
|
|
852
853
|
rotation = rotation_mapping[rotation_space[tuple(peak)]]
|
853
854
|
|
854
|
-
#
|
855
|
+
# Old versions of rotation mapping contained Euler angles
|
855
856
|
if rotation.ndim != 2:
|
856
857
|
rotation = be.to_backend_array(
|
857
858
|
euler_to_rotationmatrix(be.to_numpy_array(rotation))
|
tme/backends/_jax_utils.py
CHANGED
@@ -115,7 +115,8 @@ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
|
|
115
115
|
@partial(
|
116
116
|
pmap,
|
117
117
|
in_axes=(0,) + (None,) * 6,
|
118
|
-
static_broadcasted_argnums=[6, 7],
|
118
|
+
static_broadcasted_argnums=[6, 7, 8, 9],
|
119
|
+
axis_name="batch",
|
119
120
|
)
|
120
121
|
def scan(
|
121
122
|
target: BackendArray,
|
@@ -126,9 +127,17 @@ def scan(
|
|
126
127
|
target_filter: BackendArray,
|
127
128
|
fast_shape: Tuple[int],
|
128
129
|
rotate_mask: bool,
|
130
|
+
analyzer_class: object,
|
131
|
+
analyzer_kwargs: Tuple[Tuple],
|
129
132
|
) -> Tuple[BackendArray, BackendArray]:
|
130
133
|
eps = jnp.finfo(template.dtype).resolution
|
131
134
|
|
135
|
+
kwargs = lax.switch(
|
136
|
+
lax.axis_index("batch"),
|
137
|
+
[lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
|
138
|
+
)
|
139
|
+
analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
|
140
|
+
|
132
141
|
if hasattr(target_filter, "shape"):
|
133
142
|
target = _apply_fourier_filter(target, target_filter)
|
134
143
|
|
@@ -151,7 +160,7 @@ def scan(
|
|
151
160
|
_template_filter_func = _apply_fourier_filter
|
152
161
|
|
153
162
|
def _sample_transform(ret, rotation_matrix):
|
154
|
-
|
163
|
+
state, index = ret
|
155
164
|
template_rot, template_mask_rot = be.rigid_transform(
|
156
165
|
arr=template,
|
157
166
|
arr_mask=template_mask,
|
@@ -176,15 +185,8 @@ def scan(
|
|
176
185
|
n_observations=n_observations,
|
177
186
|
eps=eps,
|
178
187
|
)
|
179
|
-
|
180
|
-
|
181
|
-
)
|
182
|
-
return (max_scores, rotations, index + 1), None
|
183
|
-
|
184
|
-
score_space = jnp.zeros(fast_shape)
|
185
|
-
rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
|
186
|
-
(score_space, rotation_space, _), _ = lax.scan(
|
187
|
-
_sample_transform, (score_space, rotation_space, 0), rotations
|
188
|
-
)
|
188
|
+
state = analyzer(state, scores, rotation_matrix, rotation_index=index)
|
189
|
+
return (state, index + 1), None
|
189
190
|
|
190
|
-
|
191
|
+
(state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
|
192
|
+
return state
|
tme/backends/_numpyfftw_utils.py
ADDED
@@ -0,0 +1,270 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
#
|
3
|
+
# Henry Gomersall
|
4
|
+
# heng@kedevelopments.co.uk
|
5
|
+
#
|
6
|
+
# All rights reserved.
|
7
|
+
#
|
8
|
+
# Redistribution and use in source and binary forms, with or without
|
9
|
+
# modification, are permitted provided that the following conditions are met:
|
10
|
+
#
|
11
|
+
# * Redistributions of source code must retain the above copyright notice, this
|
12
|
+
# list of conditions and the following disclaimer.
|
13
|
+
#
|
14
|
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
15
|
+
# this list of conditions and the following disclaimer in the documentation
|
16
|
+
# and/or other materials provided with the distribution.
|
17
|
+
#
|
18
|
+
# * Neither the name of the copyright holder nor the names of its contributors
|
19
|
+
# may be used to endorse or promote products derived from this software without
|
20
|
+
# specific prior written permission.
|
21
|
+
#
|
22
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
23
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
24
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
25
|
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
26
|
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
27
|
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
28
|
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
29
|
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
30
|
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
31
|
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
32
|
+
# POSSIBILITY OF SUCH DAMAGE.
|
33
|
+
#
|
34
|
+
|
35
|
+
# This code has been adapted to add support for the out argument in rfftn, irfftn
|
36
|
+
# to allow for reusing existing array buffers
|
37
|
+
# Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
38
|
+
|
39
|
+
import threading
|
40
|
+
|
41
|
+
import pyfftw
|
42
|
+
import numpy as np
|
43
|
+
import pyfftw.builders as builders
|
44
|
+
from pyfftw.interfaces import cache
|
45
|
+
from pyfftw.builders._utils import _norm_args, _default_effort, _default_threads
|
46
|
+
|
47
|
+
|
48
|
+
def _Xfftn(
|
49
|
+
a,
|
50
|
+
s,
|
51
|
+
axes,
|
52
|
+
overwrite_input,
|
53
|
+
planner_effort,
|
54
|
+
threads,
|
55
|
+
auto_align_input,
|
56
|
+
auto_contiguous,
|
57
|
+
calling_func,
|
58
|
+
normalise_idft=True,
|
59
|
+
ortho=False,
|
60
|
+
real_direction_flag=None,
|
61
|
+
output_array=None,
|
62
|
+
):
|
63
|
+
|
64
|
+
work_with_copy = False
|
65
|
+
|
66
|
+
a = np.asanyarray(a)
|
67
|
+
|
68
|
+
try:
|
69
|
+
s = tuple(s)
|
70
|
+
except TypeError:
|
71
|
+
pass
|
72
|
+
|
73
|
+
try:
|
74
|
+
axes = tuple(axes)
|
75
|
+
except TypeError:
|
76
|
+
pass
|
77
|
+
|
78
|
+
if calling_func in ("dct", "dst"):
|
79
|
+
# real-to-real transforms require passing an additional flag argument
|
80
|
+
avoid_copy = False
|
81
|
+
args = (
|
82
|
+
overwrite_input,
|
83
|
+
planner_effort,
|
84
|
+
threads,
|
85
|
+
auto_align_input,
|
86
|
+
auto_contiguous,
|
87
|
+
avoid_copy,
|
88
|
+
real_direction_flag,
|
89
|
+
)
|
90
|
+
elif calling_func in ("irfft2", "irfftn"):
|
91
|
+
# overwrite_input is not an argument to irfft2 or irfftn
|
92
|
+
args = (planner_effort, threads, auto_align_input, auto_contiguous)
|
93
|
+
|
94
|
+
if not overwrite_input:
|
95
|
+
# Only irfft2 and irfftn have overwriting the input
|
96
|
+
# as the default (and so require the input array to
|
97
|
+
# be reloaded).
|
98
|
+
work_with_copy = True
|
99
|
+
else:
|
100
|
+
args = (
|
101
|
+
overwrite_input,
|
102
|
+
planner_effort,
|
103
|
+
threads,
|
104
|
+
auto_align_input,
|
105
|
+
auto_contiguous,
|
106
|
+
)
|
107
|
+
|
108
|
+
if not a.flags.writeable:
|
109
|
+
# Special case of a locked array - always work with a
|
110
|
+
# copy. See issue #92.
|
111
|
+
work_with_copy = True
|
112
|
+
|
113
|
+
if overwrite_input:
|
114
|
+
raise ValueError(
|
115
|
+
"overwrite_input cannot be True when the "
|
116
|
+
+ "input array flags.writeable is False"
|
117
|
+
)
|
118
|
+
|
119
|
+
if work_with_copy:
|
120
|
+
# We make the copy before registering the key so that the
|
121
|
+
# copy's stride information will be cached since this will be
|
122
|
+
# used for planning. Make sure the copy is byte aligned to
|
123
|
+
# prevent further copying
|
124
|
+
a_original = a
|
125
|
+
a = pyfftw.empty_aligned(shape=a.shape, dtype=a.dtype)
|
126
|
+
a[...] = a_original
|
127
|
+
|
128
|
+
if cache.is_enabled():
|
129
|
+
alignment = a.ctypes.data % pyfftw.simd_alignment
|
130
|
+
|
131
|
+
key = (
|
132
|
+
calling_func,
|
133
|
+
a.shape,
|
134
|
+
a.strides,
|
135
|
+
a.dtype,
|
136
|
+
s.__hash__(),
|
137
|
+
axes.__hash__(),
|
138
|
+
alignment,
|
139
|
+
args,
|
140
|
+
threading.get_ident(),
|
141
|
+
)
|
142
|
+
|
143
|
+
try:
|
144
|
+
if key in cache._fftw_cache:
|
145
|
+
FFTW_object = cache._fftw_cache.lookup(key)
|
146
|
+
else:
|
147
|
+
FFTW_object = None
|
148
|
+
|
149
|
+
except KeyError:
|
150
|
+
# This occurs if the object has fallen out of the cache between
|
151
|
+
# the check and the lookup
|
152
|
+
FFTW_object = None
|
153
|
+
|
154
|
+
if not cache.is_enabled() or FFTW_object is None:
|
155
|
+
|
156
|
+
# If we're going to create a new FFTW object and are not
|
157
|
+
# working with a copy, then we need to copy the input array to
|
158
|
+
# preserve it, otherwise we can't actually take the transform
|
159
|
+
# of the input array! (in general, we have to assume that the
|
160
|
+
# input array will be destroyed during planning).
|
161
|
+
if not work_with_copy:
|
162
|
+
a_copy = a.copy()
|
163
|
+
|
164
|
+
planner_args = (a, s, axes) + args
|
165
|
+
|
166
|
+
FFTW_object = getattr(builders, calling_func)(*planner_args)
|
167
|
+
|
168
|
+
# Only copy if the input array is what was actually used
|
169
|
+
# (otherwise it shouldn't be overwritten)
|
170
|
+
if not work_with_copy and FFTW_object.input_array is a:
|
171
|
+
a[:] = a_copy
|
172
|
+
|
173
|
+
if cache.is_enabled():
|
174
|
+
cache._fftw_cache.insert(FFTW_object, key)
|
175
|
+
|
176
|
+
output_array = FFTW_object(normalise_idft=normalise_idft, ortho=ortho)
|
177
|
+
|
178
|
+
else:
|
179
|
+
orig_output_array = FFTW_object.output_array
|
180
|
+
output_shape = orig_output_array.shape
|
181
|
+
output_dtype = orig_output_array.dtype
|
182
|
+
output_alignment = FFTW_object.output_alignment
|
183
|
+
|
184
|
+
if output_array is None:
|
185
|
+
output_array = pyfftw.empty_aligned(
|
186
|
+
output_shape, output_dtype, n=output_alignment
|
187
|
+
)
|
188
|
+
|
189
|
+
FFTW_object(
|
190
|
+
input_array=a,
|
191
|
+
output_array=output_array,
|
192
|
+
normalise_idft=normalise_idft,
|
193
|
+
ortho=ortho,
|
194
|
+
)
|
195
|
+
|
196
|
+
return output_array
|
197
|
+
|
198
|
+
|
199
|
+
def rfftn(
|
200
|
+
a,
|
201
|
+
s=None,
|
202
|
+
axes=None,
|
203
|
+
norm=None,
|
204
|
+
overwrite_input=False,
|
205
|
+
planner_effort=None,
|
206
|
+
threads=None,
|
207
|
+
auto_align_input=True,
|
208
|
+
auto_contiguous=True,
|
209
|
+
out=None,
|
210
|
+
):
|
211
|
+
"""Perform an n-D real FFT.
|
212
|
+
|
213
|
+
The first four arguments are as per :func:`numpy.fft.rfftn`;
|
214
|
+
the rest of the arguments are documented
|
215
|
+
in the :ref:`additional arguments docs<interfaces_additional_args>`.
|
216
|
+
"""
|
217
|
+
calling_func = "rfftn"
|
218
|
+
planner_effort = _default_effort(planner_effort)
|
219
|
+
threads = _default_threads(threads)
|
220
|
+
|
221
|
+
return _Xfftn(
|
222
|
+
a,
|
223
|
+
s,
|
224
|
+
axes,
|
225
|
+
overwrite_input,
|
226
|
+
planner_effort,
|
227
|
+
threads,
|
228
|
+
auto_align_input,
|
229
|
+
auto_contiguous,
|
230
|
+
calling_func,
|
231
|
+
**_norm_args(norm),
|
232
|
+
output_array=out,
|
233
|
+
)
|
234
|
+
|
235
|
+
|
236
|
+
def irfftn(
|
237
|
+
a,
|
238
|
+
s=None,
|
239
|
+
axes=None,
|
240
|
+
norm=None,
|
241
|
+
overwrite_input=False,
|
242
|
+
planner_effort=None,
|
243
|
+
threads=None,
|
244
|
+
auto_align_input=True,
|
245
|
+
auto_contiguous=True,
|
246
|
+
out=None,
|
247
|
+
):
|
248
|
+
"""Perform an n-D real inverse FFT.
|
249
|
+
|
250
|
+
The first four arguments are as per :func:`numpy.fft.rfftn`;
|
251
|
+
the rest of the arguments are documented
|
252
|
+
in the :ref:`additional arguments docs<interfaces_additional_args>`.
|
253
|
+
"""
|
254
|
+
calling_func = "irfftn"
|
255
|
+
planner_effort = _default_effort(planner_effort)
|
256
|
+
threads = _default_threads(threads)
|
257
|
+
|
258
|
+
return _Xfftn(
|
259
|
+
a,
|
260
|
+
s,
|
261
|
+
axes,
|
262
|
+
overwrite_input,
|
263
|
+
planner_effort,
|
264
|
+
threads,
|
265
|
+
auto_align_input,
|
266
|
+
auto_contiguous,
|
267
|
+
calling_func,
|
268
|
+
**_norm_args(norm),
|
269
|
+
output_array=out,
|
270
|
+
)
|