pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +23 -10
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_backends.py +3 -9
  20. tests/test_density.py +0 -1
  21. tests/test_matching_utils.py +10 -60
  22. tests/test_rotations.py +1 -1
  23. tme/__version__.py +1 -1
  24. tme/analyzer/_utils.py +4 -4
  25. tme/analyzer/aggregation.py +13 -3
  26. tme/analyzer/peaks.py +11 -10
  27. tme/backends/_jax_utils.py +15 -13
  28. tme/backends/_numpyfftw_utils.py +270 -0
  29. tme/backends/cupy_backend.py +5 -44
  30. tme/backends/jax_backend.py +58 -37
  31. tme/backends/matching_backend.py +6 -51
  32. tme/backends/mlx_backend.py +1 -27
  33. tme/backends/npfftw_backend.py +68 -65
  34. tme/backends/pytorch_backend.py +1 -26
  35. tme/density.py +2 -6
  36. tme/extensions.cpython-311-darwin.so +0 -0
  37. tme/filters/ctf.py +22 -21
  38. tme/filters/wedge.py +10 -7
  39. tme/mask.py +341 -0
  40. tme/matching_data.py +7 -19
  41. tme/matching_exhaustive.py +34 -47
  42. tme/matching_optimization.py +2 -1
  43. tme/matching_scores.py +206 -411
  44. tme/matching_utils.py +73 -422
  45. tme/memory.py +1 -1
  46. tme/orientations.py +4 -6
  47. tme/rotations.py +1 -1
  48. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  49. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
  50. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
  51. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  52. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tests/data/.DS_Store ADDED
Binary file
tests/data/Blurring/.DS_Store ADDED
Binary file
tests/data/Maps/.DS_Store ADDED
Binary file
tests/data/Raw/.DS_Store ADDED
Binary file
tests/data/Structures/.DS_Store ADDED
Binary file
tests/preprocessing/test_utils.py CHANGED
@@ -51,6 +51,24 @@ class TestPreprocessUtils:
         assert fgrid.shape == tuple(tilt_shape)
         assert fgrid.max() <= np.sqrt(1 / sampling_rate * len(shape))
 
+    @pytest.mark.parametrize("shape", ((15, 15, 15), (31, 31, 31), (64, 64, 64)))
+    @pytest.mark.parametrize("sampling_rate", (0.5, 1, 2))
+    @pytest.mark.parametrize("angle", (-5, 0, 5))
+    def test_freqgrid_comparison(self, shape, sampling_rate, angle):
+        grid = frequency_grid_at_angle(
+            shape=shape,
+            angle=angle,
+            sampling_rate=sampling_rate,
+            opening_axis=2,
+            tilt_axis=0,
+        )
+        grid2 = fftfreqn(
+            shape=shape[1:], sampling_rate=sampling_rate, compute_euclidean_norm=True
+        )
+
+        # These should be equal for cubical input shapes
+        assert np.allclose(grid, grid2)
+
     @pytest.mark.parametrize("n", [10, 100, 1000])
     @pytest.mark.parametrize("sampling_rate", range(1, 4))
     def test_fftfreqn(self, n, sampling_rate):
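
The new test encodes the expectation that, for cubical volumes, the tilted frequency grid matches the plain Euclidean frequency norm over the remaining axes. A minimal standalone sketch of that comparison; the import path is an assumption, since the hunk only shows the function names:

import numpy as np
from tme.filters._utils import frequency_grid_at_angle, fftfreqn  # assumed import location

shape, sampling_rate = (31, 31, 31), 1.0
grid = frequency_grid_at_angle(
    shape=shape, angle=5, sampling_rate=sampling_rate, opening_axis=2, tilt_axis=0
)
# Per the test comment, equality is expected for cubical input shapes.
grid2 = fftfreqn(shape=shape[1:], sampling_rate=sampling_rate, compute_euclidean_norm=True)
print(np.allclose(grid, grid2))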
tests/test_backends.py CHANGED
@@ -292,16 +292,10 @@ class TestBackends:
 
     @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
     @pytest.mark.parametrize("fast_shape", ((10, 15, 100), (55, 23, 17)))
-    def test_build_fft(self, backend, fast_shape):
+    def test_fft(self, backend, fast_shape):
         _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
             fast_shape, (1 for _ in range(len(fast_shape)))
         )
-        rfftn, irfftn = backend.build_fft(
-            fwd_shape=fast_shape,
-            inv_shape=fast_ft_shape,
-            real_dtype=backend._float_dtype,
-            cmpl_dtype=backend._complex_dtype,
-        )
         arr = np.random.rand(*fast_shape)
         out = np.zeros(fast_ft_shape)
 
@@ -310,11 +304,11 @@ class TestBackends:
             backend.to_backend_array(out), backend._complex_dtype
         )
 
-        rfftn(
+        backend.rfftn(
             backend.astype(backend.to_backend_array(arr), backend._float_dtype),
             complex_arr,
         )
-        irfftn(complex_arr, real_arr)
+        backend.irfftn(complex_arr, real_arr)
         assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)
 
     @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
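
The updated test reflects the removal of build_fft in favor of calling rfftn/irfftn directly on the backend with preallocated output buffers. A rough usage sketch along the lines of the test, with helper names taken from the hunks above; treat it as illustrative rather than canonical API documentation:

import numpy as np
from tme.backends import backend as be

_, fast_shape, fast_ft_shape = be.compute_convolution_shapes((10, 15, 100), (1, 1, 1))

arr_np = np.random.rand(*fast_shape)
arr = be.astype(be.to_backend_array(arr_np), be._float_dtype)
ft_buf = be.zeros(fast_ft_shape, dtype=be._complex_dtype)   # reused complex buffer
out_buf = be.zeros(fast_shape, dtype=be._float_dtype)       # reused real buffer

be.rfftn(arr, ft_buf)       # forward transform written into ft_buf
be.irfftn(ft_buf, out_buf)  # inverse transform written into out_buf
assert np.allclose(arr_np, be.to_numpy_array(out_buf), rtol=0.3)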
tests/test_density.py CHANGED
@@ -98,7 +98,6 @@ class TestDensity:
         assert np.allclose(density.data, self.density.data)
         assert np.allclose(density.sampling_rate, self.density.sampling_rate)
         assert np.allclose(density.origin, self.density.origin)
-        assert density.metadata == self.density.metadata
 
     def test_from_file_baseline(self):
         self.test_to_file(gzip=False)
tests/test_matching_utils.py CHANGED
@@ -10,9 +10,6 @@ from tme.backends import backend as be
 from tme.memory import MATCHING_MEMORY_REGISTRY
 from tme.matching_utils import (
     compute_parallelization_schedule,
-    elliptical_mask,
-    box_mask,
-    tube_mask,
     create_mask,
     scramble_phases,
     apply_convolution_mode,
@@ -50,73 +47,26 @@ class TestMatchingUtils:
             max_splits=256,
         )
 
-    def test_create_mask(self):
+    @pytest.mark.parametrize("mask_type", ["ellipse", "box", "tube", "membrane"])
+    def test_create_mask(self, mask_type: str):
         create_mask(
-            mask_type="ellipse",
+            mask_type=mask_type,
             shape=self.density.shape,
             radius=5,
             center=np.divide(self.density.shape, 2),
+            height=np.max(self.density.shape) // 2,
+            size=np.divide(self.density.shape, 2).astype(int),
+            thickness=2,
+            separation=2,
+            symmetry_axis=1,
+            inner_radius=5,
+            outer_radius=10,
         )
 
     def test_create_mask_error(self):
         with pytest.raises(ValueError):
             create_mask(mask_type=None)
 
-    def test_elliptical_mask(self):
-        elliptical_mask(
-            shape=self.density.shape,
-            radius=5,
-            center=np.divide(self.density.shape, 2),
-        )
-
-    def test_box_mask(self):
-        box_mask(
-            shape=self.density.shape,
-            height=[5, 10, 20],
-            center=np.divide(self.density.shape, 2),
-        )
-
-    def test_tube_mask(self):
-        tube_mask(
-            shape=self.density.shape,
-            outer_radius=10,
-            inner_radius=5,
-            height=5,
-            base_center=np.divide(self.density.shape, 2),
-            symmetry_axis=1,
-        )
-
-    def test_tube_mask_error(self):
-        with pytest.raises(ValueError):
-            tube_mask(
-                shape=self.density.shape,
-                outer_radius=5,
-                inner_radius=10,
-                height=5,
-                base_center=np.divide(self.density.shape, 2),
-                symmetry_axis=1,
-            )
-
-        with pytest.raises(ValueError):
-            tube_mask(
-                shape=self.density.shape,
-                outer_radius=5,
-                inner_radius=10,
-                height=10 * np.max(self.density.shape),
-                base_center=np.divide(self.density.shape, 2),
-                symmetry_axis=1,
-            )
-
-        with pytest.raises(ValueError):
-            tube_mask(
-                shape=self.density.shape,
-                outer_radius=5,
-                inner_radius=10,
-                height=10 * np.max(self.density.shape),
-                base_center=np.divide(self.density.shape, 2),
-                symmetry_axis=len(self.density.shape) + 1,
-            )
-
     def test_scramble_phases(self):
         scramble_phases(arr=self.density.data, noise_proportion=0.5)
 
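
The parametrized test passes a superset of keyword arguments and lets each mask type consume the ones it needs, with "membrane" now exercised alongside "ellipse", "box", and "tube". A hedged sketch of a single call outside pytest; argument names are copied from the test, and which subset each mask type actually uses is not shown in this diff:

import numpy as np
from tme.matching_utils import create_mask

shape = (50, 50, 50)
mask = create_mask(
    mask_type="membrane",   # any of the four types exercised by the test
    shape=shape,
    radius=5,
    center=np.divide(shape, 2),
    height=np.max(shape) // 2,
    size=np.divide(shape, 2).astype(int),
    thickness=2,
    separation=2,
    symmetry_axis=1,
    inner_radius=5,
    outer_radius=10,
)
print(mask.shape)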
tests/test_rotations.py CHANGED
@@ -8,6 +8,7 @@ from tme import Density
 from scipy.spatial.transform import Rotation
 from scipy.signal import correlate
 
+from tme.mask import elliptical_mask
 from tme.rotations import (
     euler_from_rotationmatrix,
     euler_to_rotationmatrix,
@@ -16,7 +17,6 @@ from tme.rotations import (
     get_rotation_matrices,
 )
 from tme.matching_utils import (
-    elliptical_mask,
     split_shape,
     compute_full_convolution_index,
 )
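
Together with the new tme/mask.py in the file list, this hunk shows the elliptical_mask import moving from tme.matching_utils to the new tme.mask module. A minimal sketch of the relocated import, with arguments as in the removed test:

import numpy as np
from tme.mask import elliptical_mask  # previously imported from tme.matching_utils

shape = (32, 32, 32)
mask = elliptical_mask(shape=shape, radius=5, center=np.divide(shape, 2))
print(mask.shape, mask.sum())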
tme/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.3.b0.post1"
+__version__ = "0.3.1"
tme/analyzer/_utils.py CHANGED
@@ -93,22 +93,22 @@ def cart_to_score(
     templateshape = be.to_backend_array(templateshape)
     convolution_shape = be.to_backend_array(convolution_shape)
 
-    # Compute removed padding
     output_shape = _convmode_to_shape(
         convolution_mode=convolution_mode,
         targetshape=targetshape,
         templateshape=templateshape,
         convolution_shape=convolution_shape,
     )
-    valid_positions = be.multiply(positions >= 0, positions < output_shape)
-    valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
 
+    # Offset from padding the target
    starts = be.astype(
         be.divide(be.subtract(convolution_shape, output_shape), 2),
         be._int_dtype,
     )
-
     positions = be.add(positions, starts)
+
+    valid_positions = be.multiply(positions >= 0, positions < fast_shape)
+    valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
     if fourier_shift is not None:
         fourier_shift = be.to_backend_array(fourier_shift)
         positions = be.subtract(positions, fourier_shift)
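
The reordering above first shifts Cartesian peak positions by the padding offset and only then checks validity against the padded (fast) score-space shape. A small NumPy illustration of that offset arithmetic, using toy shapes rather than the pytme API:

import numpy as np

convolution_shape = np.array([64, 64])   # padded score space (toy values)
output_shape = np.array([50, 50])        # shape after removing convolution padding
fast_shape = np.array([64, 64])

positions = np.array([[0, 0], [49, 49]])          # peaks in unpadded coordinates
starts = (convolution_shape - output_shape) // 2  # offset introduced by padding
positions = positions + starts                    # coordinates in padded space

valid = ((positions >= 0) & (positions < fast_shape)).sum(axis=1) == positions.shape[1]
print(positions, valid)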
tme/analyzer/aggregation.py CHANGED
@@ -545,13 +545,19 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
         )
 
     def __call__(
-        self, state: Tuple, scores: BackendArray, rotation_matrix: BackendArray
+        self,
+        state: Tuple,
+        scores: BackendArray,
+        rotation_matrix: BackendArray,
+        **kwargs,
     ) -> Tuple:
         mask = self._get_constraint(rotation_matrix)
         mask = self._get_score_mask(mask=mask, scores=scores)
 
         scores = be.multiply(scores, mask, out=scores)
-        return super().__call__(state, scores=scores, rotation_matrix=rotation_matrix)
+        return super().__call__(
+            state, scores=scores, rotation_matrix=rotation_matrix, **kwargs
+        )
 
     def _get_constraint(self, rotation_matrix: BackendArray) -> BackendArray:
         """
@@ -636,7 +642,11 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
         return scores, rotations, {}
 
     def __call__(
-        self, state, scores: BackendArray, rotation_matrix: BackendArray
+        self,
+        state,
+        scores: BackendArray,
+        rotation_matrix: BackendArray,
+        **kwargs,
     ) -> Tuple:
         prev_scores, rotations, rotation_mapping = state
 
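
Both analyzers now accept and forward arbitrary keyword arguments, so callers such as the JAX scan below can pass extras like rotation_index without breaking analyzers that do not use them. A toy sketch of that state-in/state-out calling convention; ToyAnalyzer is illustrative only and not part of pytme:

import numpy as np

class ToyAnalyzer:
    """Minimal stand-in mirroring the analyzer call convention."""

    def init_state(self):
        return (np.full((4, 4), -np.inf), np.full((4, 4), -1, dtype=int))

    def __call__(self, state, scores, rotation_matrix, **kwargs):
        # Extra keyword arguments (e.g. rotation_index) are accepted and may be ignored.
        max_scores, best_rotation = state
        index = kwargs.get("rotation_index", -1)
        update = scores > max_scores
        return (np.where(update, scores, max_scores), np.where(update, index, best_rotation))

analyzer = ToyAnalyzer()
state = analyzer.init_state()
state = analyzer(state, np.random.rand(4, 4), np.eye(3), rotation_index=0)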
tme/analyzer/peaks.py CHANGED
@@ -228,10 +228,15 @@ class PeakCaller(AbstractAnalyzer):
         translations = be.full(
             (self.num_peaks, ndim), fill_value=-1, dtype=be._int_dtype
         )
+
+        rdim = len(self.shape)
+        if self.batch_dims:
+            rdim -= len(self.batch_dims)
+
         rotations = be.full(
-            (self.num_peaks, ndim, ndim), fill_value=0, dtype=be._float_dtype
+            (self.num_peaks, rdim, rdim), fill_value=0, dtype=be._float_dtype
         )
-        for i in range(ndim):
+        for i in range(rdim):
             rotations[:, i, i] = 1.0
 
         scores = be.full((self.num_peaks,), fill_value=-1, dtype=be._float_dtype)
@@ -750,8 +755,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
             Dictionary mapping values in rotations to Euler angles.
            By default None
         min_score : float
-            Minimum score value to consider. If provided, supersedes limit given
-            by :py:attr:`PeakCaller.num_peaks`.
+            Minimum score value to consider.
 
         Returns
         -------
@@ -774,10 +778,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
         mask = be.to_backend_array(mask)
         mask_buffer = be.zeros(mask.shape, dtype=mask.dtype)
 
-        peak_limit = self.num_peaks
-        if min_score is not None:
-            peak_limit = be.size(scores)
-        else:
+        if min_score is None:
             min_score = be.min(scores) - 1
 
         _scores = be.zeros(scores.shape, dtype=scores.dtype)
@@ -815,7 +816,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
             score_mask = mask_buffer[tmpl_slice] <= 0.1
 
             _scores[score_slice] = be.multiply(_scores[score_slice], score_mask)
-            if len(peaks) >= peak_limit:
+            if len(peaks) >= self.num_peaks:
                 break
 
         return be.to_backend_array(peaks), None
@@ -851,7 +852,7 @@ class PeakCallerRecursiveMasking(PeakCaller):
 
         rotation = rotation_mapping[rotation_space[tuple(peak)]]
 
-        # TODO: Newer versions of rotation mapping contain rotation matrices not angles
+        # Old versions of rotation mapping contained Euler angles
         if rotation.ndim != 2:
             rotation = be.to_backend_array(
                 euler_to_rotationmatrix(be.to_numpy_array(rotation))
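
With batch dimensions excluded, the preallocated rotation matrices now match only the spatial axes. A toy illustration of that allocation in plain NumPy, standing in for the backend calls:

import numpy as np

shape, batch_dims, num_peaks = (8, 64, 64, 64), (0,), 5  # one batch axis, three spatial axes
rdim = len(shape) - len(batch_dims)                       # 3

rotations = np.zeros((num_peaks, rdim, rdim), dtype=np.float32)
for i in range(rdim):
    rotations[:, i, i] = 1.0   # identity placeholder per peak
print(rotations.shape)         # (5, 3, 3)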
tme/backends/_jax_utils.py CHANGED
@@ -115,7 +115,8 @@ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
 @partial(
     pmap,
     in_axes=(0,) + (None,) * 6,
-    static_broadcasted_argnums=[6, 7],
+    static_broadcasted_argnums=[6, 7, 8, 9],
+    axis_name="batch",
 )
 def scan(
     target: BackendArray,
@@ -126,9 +127,17 @@ def scan(
     target_filter: BackendArray,
     fast_shape: Tuple[int],
     rotate_mask: bool,
+    analyzer_class: object,
+    analyzer_kwargs: Tuple[Tuple],
 ) -> Tuple[BackendArray, BackendArray]:
     eps = jnp.finfo(template.dtype).resolution
 
+    kwargs = lax.switch(
+        lax.axis_index("batch"),
+        [lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
+    )
+    analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
+
     if hasattr(target_filter, "shape"):
         target = _apply_fourier_filter(target, target_filter)
 
@@ -151,7 +160,7 @@ def scan(
         _template_filter_func = _apply_fourier_filter
 
     def _sample_transform(ret, rotation_matrix):
-        max_scores, rotations, index = ret
+        state, index = ret
         template_rot, template_mask_rot = be.rigid_transform(
            arr=template,
             arr_mask=template_mask,
@@ -176,15 +185,8 @@ def scan(
             n_observations=n_observations,
             eps=eps,
         )
-        max_scores, rotations = be.max_score_over_rotations(
-            scores, max_scores, rotations, index
-        )
-        return (max_scores, rotations, index + 1), None
-
-    score_space = jnp.zeros(fast_shape)
-    rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
-    (score_space, rotation_space, _), _ = lax.scan(
-        _sample_transform, (score_space, rotation_space, 0), rotations
-    )
+        state = analyzer(state, scores, rotation_matrix, rotation_index=index)
+        return (state, index + 1), None
 
-    return score_space, rotation_space
+    (state, _), _ = lax.scan(_sample_transform, (analyzer.init_state(), 0), rotations)
+    return state
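
The per-device analyzer construction above uses lax.axis_index inside a pmapped function to pick that device's static keyword tuple via lax.switch. A self-contained JAX sketch of the same dispatch pattern with illustrative names; the default-argument binding keeps each branch pointing at its own index:

from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

# One static parameter tuple per device; tuples are hashable and can be static.
per_device_offsets = tuple((float(i) * 10.0,) for i in range(jax.local_device_count()))

@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(1,))
def add_device_offset(x, offsets):
    # lax.axis_index("batch") is this device's position in the pmap; lax.switch
    # selects the matching branch. Binding i as a default argument avoids the
    # late-binding pitfall of closures created in a comprehension.
    params = lax.switch(
        lax.axis_index("batch"),
        [lambda i=i: jnp.asarray(offsets[i]) for i in range(len(offsets))],
    )
    return x + params[0]

x = jnp.ones((jax.local_device_count(), 4))
print(add_device_offset(x, per_device_offsets))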
tme/backends/_numpyfftw_utils.py ADDED
@@ -0,0 +1,270 @@
+#!/usr/bin/env python
+#
+# Henry Gomersall
+# heng@kedevelopments.co.uk
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software without
+# specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+
+# This code has been adapted to add support for the out argument in rfftn, irfftn
+# to allow for reusing existing array buffers
+# Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+
+import threading
+
+import pyfftw
+import numpy as np
+import pyfftw.builders as builders
+from pyfftw.interfaces import cache
+from pyfftw.builders._utils import _norm_args, _default_effort, _default_threads
+
+
+def _Xfftn(
+    a,
+    s,
+    axes,
+    overwrite_input,
+    planner_effort,
+    threads,
+    auto_align_input,
+    auto_contiguous,
+    calling_func,
+    normalise_idft=True,
+    ortho=False,
+    real_direction_flag=None,
+    output_array=None,
+):
+
+    work_with_copy = False
+
+    a = np.asanyarray(a)
+
+    try:
+        s = tuple(s)
+    except TypeError:
+        pass
+
+    try:
+        axes = tuple(axes)
+    except TypeError:
+        pass
+
+    if calling_func in ("dct", "dst"):
+        # real-to-real transforms require passing an additional flag argument
+        avoid_copy = False
+        args = (
+            overwrite_input,
+            planner_effort,
+            threads,
+            auto_align_input,
+            auto_contiguous,
+            avoid_copy,
+            real_direction_flag,
+        )
+    elif calling_func in ("irfft2", "irfftn"):
+        # overwrite_input is not an argument to irfft2 or irfftn
+        args = (planner_effort, threads, auto_align_input, auto_contiguous)
+
+        if not overwrite_input:
+            # Only irfft2 and irfftn have overwriting the input
+            # as the default (and so require the input array to
+            # be reloaded).
+            work_with_copy = True
+    else:
+        args = (
+            overwrite_input,
+            planner_effort,
+            threads,
+            auto_align_input,
+            auto_contiguous,
+        )
+
+    if not a.flags.writeable:
+        # Special case of a locked array - always work with a
+        # copy. See issue #92.
+        work_with_copy = True
+
+        if overwrite_input:
+            raise ValueError(
+                "overwrite_input cannot be True when the "
+                + "input array flags.writeable is False"
+            )
+
+    if work_with_copy:
+        # We make the copy before registering the key so that the
+        # copy's stride information will be cached since this will be
+        # used for planning. Make sure the copy is byte aligned to
+        # prevent further copying
+        a_original = a
+        a = pyfftw.empty_aligned(shape=a.shape, dtype=a.dtype)
+        a[...] = a_original
+
+    if cache.is_enabled():
+        alignment = a.ctypes.data % pyfftw.simd_alignment
+
+        key = (
+            calling_func,
+            a.shape,
+            a.strides,
+            a.dtype,
+            s.__hash__(),
+            axes.__hash__(),
+            alignment,
+            args,
+            threading.get_ident(),
+        )
+
+        try:
+            if key in cache._fftw_cache:
+                FFTW_object = cache._fftw_cache.lookup(key)
+            else:
+                FFTW_object = None
+
+        except KeyError:
+            # This occurs if the object has fallen out of the cache between
+            # the check and the lookup
+            FFTW_object = None
+
+    if not cache.is_enabled() or FFTW_object is None:
+
+        # If we're going to create a new FFTW object and are not
+        # working with a copy, then we need to copy the input array to
+        # preserve it, otherwise we can't actually take the transform
+        # of the input array! (in general, we have to assume that the
+        # input array will be destroyed during planning).
+        if not work_with_copy:
+            a_copy = a.copy()
+
+        planner_args = (a, s, axes) + args
+
+        FFTW_object = getattr(builders, calling_func)(*planner_args)
+
+        # Only copy if the input array is what was actually used
+        # (otherwise it shouldn't be overwritten)
+        if not work_with_copy and FFTW_object.input_array is a:
+            a[:] = a_copy
+
+        if cache.is_enabled():
+            cache._fftw_cache.insert(FFTW_object, key)
+
+        output_array = FFTW_object(normalise_idft=normalise_idft, ortho=ortho)
+
+    else:
+        orig_output_array = FFTW_object.output_array
+        output_shape = orig_output_array.shape
+        output_dtype = orig_output_array.dtype
+        output_alignment = FFTW_object.output_alignment
+
+        if output_array is None:
+            output_array = pyfftw.empty_aligned(
+                output_shape, output_dtype, n=output_alignment
+            )
+
+        FFTW_object(
+            input_array=a,
+            output_array=output_array,
+            normalise_idft=normalise_idft,
+            ortho=ortho,
+        )
+
+    return output_array
+
+
+def rfftn(
+    a,
+    s=None,
+    axes=None,
+    norm=None,
+    overwrite_input=False,
+    planner_effort=None,
+    threads=None,
+    auto_align_input=True,
+    auto_contiguous=True,
+    out=None,
+):
+    """Perform an n-D real FFT.
+
+    The first four arguments are as per :func:`numpy.fft.rfftn`;
+    the rest of the arguments are documented
+    in the :ref:`additional arguments docs<interfaces_additional_args>`.
+    """
+    calling_func = "rfftn"
+    planner_effort = _default_effort(planner_effort)
+    threads = _default_threads(threads)
+
+    return _Xfftn(
+        a,
+        s,
+        axes,
+        overwrite_input,
+        planner_effort,
+        threads,
+        auto_align_input,
+        auto_contiguous,
+        calling_func,
+        **_norm_args(norm),
+        output_array=out,
+    )
+
+
+def irfftn(
+    a,
+    s=None,
+    axes=None,
+    norm=None,
+    overwrite_input=False,
+    planner_effort=None,
+    threads=None,
+    auto_align_input=True,
+    auto_contiguous=True,
+    out=None,
+):
+    """Perform an n-D real inverse FFT.
+
+    The first four arguments are as per :func:`numpy.fft.rfftn`;
+    the rest of the arguments are documented
+    in the :ref:`additional arguments docs<interfaces_additional_args>`.
+    """
+    calling_func = "irfftn"
+    planner_effort = _default_effort(planner_effort)
+    threads = _default_threads(threads)
+
+    return _Xfftn(
+        a,
+        s,
+        axes,
+        overwrite_input,
+        planner_effort,
+        threads,
+        auto_align_input,
+        auto_contiguous,
+        calling_func,
+        **_norm_args(norm),
+        output_array=out,
+    )
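
A hedged usage sketch of the added out argument (assumes pytme 0.3.1 and pyFFTW are installed). Note that the passed buffer is only reused on a warm pyFFTW cache hit; on a cold cache the functions return a freshly planned output array, so the return value is the reliable handle on the result:

import numpy as np
import pyfftw
from pyfftw.interfaces import cache
from tme.backends._numpyfftw_utils import rfftn, irfftn

cache.enable()  # cached FFTW objects are what make the out= buffers get reused

a = np.random.rand(32, 32, 32).astype(np.float32)
ft = pyfftw.empty_aligned((32, 32, 17), dtype="complex64")  # rfftn halves the last axis: n // 2 + 1
rec = pyfftw.empty_aligned((32, 32, 32), dtype="float32")

ft = rfftn(a, out=ft)                 # forward transform, possibly written into ft
rec = irfftn(ft, s=a.shape, out=rec)  # inverse transform back to a real array
assert np.allclose(a, rec, atol=1e-4)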