mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
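For context, a wheel is just a zip archive, so a file listing like the one below can be reproduced locally from a downloaded copy of the package. A minimal sketch using only the Python standard library, assuming the wheel has already been downloaded (the local path below is hypothetical, e.g. saved by `pip download mlx-cpu==0.30.1`):

import zipfile

# Hypothetical local path to the downloaded wheel; adjust as needed.
wheel_path = "mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl"

with zipfile.ZipFile(wheel_path) as wheel:
    for info in wheel.infolist():
        # Each archive entry corresponds to one row of the "Files changed" list below.
        print(f"{info.file_size:>10}  {info.filename}")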
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,3581 @@
1
+ /*
2
+ This file is part of pocketfft.
3
+
4
+ Copyright (C) 2010-2022 Max-Planck-Society
5
+ Copyright (C) 2019-2020 Peter Bell
6
+
7
+ For the odd-sized DCT-IV transforms:
8
+ Copyright (C) 2003, 2007-14 Matteo Frigo
9
+ Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology
10
+
11
+ Authors: Martin Reinecke, Peter Bell
12
+
13
+ All rights reserved.
14
+
15
+ Redistribution and use in source and binary forms, with or without modification,
16
+ are permitted provided that the following conditions are met:
17
+
18
+ * Redistributions of source code must retain the above copyright notice, this
19
+ list of conditions and the following disclaimer.
20
+ * Redistributions in binary form must reproduce the above copyright notice, this
21
+ list of conditions and the following disclaimer in the documentation and/or
22
+ other materials provided with the distribution.
23
+ * Neither the name of the copyright holder nor the names of its contributors may
24
+ be used to endorse or promote products derived from this software without
25
+ specific prior written permission.
26
+
27
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
28
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
29
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
31
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
32
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
33
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
34
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
35
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
36
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37
+ */
38
+
39
+ #ifndef POCKETFFT_HDRONLY_H
40
+ #define POCKETFFT_HDRONLY_H
41
+
42
+ #ifndef __cplusplus
43
+ #error This file is C++ and requires a C++ compiler.
44
+ #endif
45
+
46
+ #if !(__cplusplus >= 201103L || _MSVC_LANG+0L >= 201103L)
47
+ #error This file requires at least C++11 support.
48
+ #endif
49
+
50
+ #ifndef POCKETFFT_CACHE_SIZE
51
+ #define POCKETFFT_CACHE_SIZE 0
52
+ #endif
53
+
54
+ #include <cmath>
55
+ #include <cstdlib>
56
+ #include <stdexcept>
57
+ #include <memory>
58
+ #include <vector>
59
+ #include <complex>
60
+ #include <algorithm>
61
+ #if POCKETFFT_CACHE_SIZE!=0
62
+ #include <array>
63
+ #include <mutex>
64
+ #endif
65
+
66
+ #ifndef POCKETFFT_NO_MULTITHREADING
67
+ #include <mutex>
68
+ #include <condition_variable>
69
+ #include <thread>
70
+ #include <queue>
71
+ #include <atomic>
72
+ #include <functional>
73
+ #include <new>
74
+
75
+ #ifdef POCKETFFT_PTHREADS
76
+ # include <pthread.h>
77
+ #endif
78
+ #endif
79
+
80
+ #if defined(__GNUC__)
81
+ #define POCKETFFT_NOINLINE __attribute__((noinline))
82
+ #define POCKETFFT_RESTRICT __restrict__
83
+ #elif defined(_MSC_VER)
84
+ #define POCKETFFT_NOINLINE __declspec(noinline)
85
+ #define POCKETFFT_RESTRICT __restrict
86
+ #else
87
+ #define POCKETFFT_NOINLINE
88
+ #define POCKETFFT_RESTRICT
89
+ #endif
90
+
91
+ namespace pocketfft {
92
+
93
+ namespace detail {
94
+ using std::size_t;
95
+ using std::ptrdiff_t;
96
+
97
+ // Always use std:: for <cmath> functions
98
+ template <typename T> T cos(T) = delete;
99
+ template <typename T> T sin(T) = delete;
100
+ template <typename T> T sqrt(T) = delete;
101
+
102
+ using shape_t = std::vector<size_t>;
103
+ using stride_t = std::vector<ptrdiff_t>;
104
+
105
+ constexpr bool FORWARD = true,
106
+ BACKWARD = false;
107
+
108
+ // only enable vector support for gcc>=5.0 and clang>=5.0
109
+ #ifndef POCKETFFT_NO_VECTORS
110
+ #define POCKETFFT_NO_VECTORS
111
+ #if defined(__INTEL_COMPILER)
112
+ // do nothing. This is necessary because this compiler also sets __GNUC__.
113
+ #elif defined(__clang__)
114
+ // AppleClang has their own version numbering
115
+ #ifdef __apple_build_version__
116
+ # if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1)
117
+ # undef POCKETFFT_NO_VECTORS
118
+ # endif
119
+ #elif __clang_major__ >= 5
120
+ # undef POCKETFFT_NO_VECTORS
121
+ #endif
122
+ #elif defined(__GNUC__)
123
+ #if __GNUC__>=5
124
+ #undef POCKETFFT_NO_VECTORS
125
+ #endif
126
+ #endif
127
+ #endif
128
+
129
+ template<typename T> struct VLEN { static constexpr size_t val=1; };
130
+
131
+ #ifndef POCKETFFT_NO_VECTORS
132
+ #if (defined(__AVX512F__))
133
+ template<> struct VLEN<float> { static constexpr size_t val=16; };
134
+ template<> struct VLEN<double> { static constexpr size_t val=8; };
135
+ #elif (defined(__AVX__))
136
+ template<> struct VLEN<float> { static constexpr size_t val=8; };
137
+ template<> struct VLEN<double> { static constexpr size_t val=4; };
138
+ #elif (defined(__SSE2__))
139
+ template<> struct VLEN<float> { static constexpr size_t val=4; };
140
+ template<> struct VLEN<double> { static constexpr size_t val=2; };
141
+ #elif (defined(__VSX__))
142
+ template<> struct VLEN<float> { static constexpr size_t val=4; };
143
+ template<> struct VLEN<double> { static constexpr size_t val=2; };
144
+ #elif (defined(__ARM_NEON__) || defined(__ARM_NEON))
145
+ template<> struct VLEN<float> { static constexpr size_t val=4; };
146
+ template<> struct VLEN<double> { static constexpr size_t val=2; };
147
+ #else
148
+ #define POCKETFFT_NO_VECTORS
149
+ #endif
150
+ #endif
151
+
152
+ // the __MINGW32__ part in the conditional below works around the problem that
153
+ // the standard C++ library on Windows does not provide aligned_alloc() even
154
+ // though the MinGW compiler and MSVC may advertise C++17 compliance.
155
+ #if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER))
156
+ inline void *aligned_alloc(size_t align, size_t size)
157
+ {
158
+ // aligned_alloc() requires that the requested size is a multiple of "align"
159
+ void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1)));
160
+ if (!ptr) throw std::bad_alloc();
161
+ return ptr;
162
+ }
163
+ inline void aligned_dealloc(void *ptr)
164
+ { free(ptr); }
165
+ #else // portable emulation
166
+ inline void *aligned_alloc(size_t align, size_t size)
167
+ {
168
+ align = std::max(align, alignof(max_align_t));
169
+ void *ptr = malloc(size+align);
170
+ if (!ptr) throw std::bad_alloc();
171
+ void *res = reinterpret_cast<void *>
172
+ ((reinterpret_cast<uintptr_t>(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align));
173
+ (reinterpret_cast<void**>(res))[-1] = ptr;
174
+ return res;
175
+ }
176
+ inline void aligned_dealloc(void *ptr)
177
+ { if (ptr) free((reinterpret_cast<void**>(ptr))[-1]); }
178
+ #endif
179
+
180
+ template<typename T> class arr
181
+ {
182
+ private:
183
+ T *p;
184
+ size_t sz;
185
+
186
+ #if defined(POCKETFFT_NO_VECTORS)
187
+ static T *ralloc(size_t num)
188
+ {
189
+ if (num==0) return nullptr;
190
+ void *res = malloc(num*sizeof(T));
191
+ if (!res) throw std::bad_alloc();
192
+ return reinterpret_cast<T *>(res);
193
+ }
194
+ static void dealloc(T *ptr)
195
+ { free(ptr); }
196
+ #else
197
+ static T *ralloc(size_t num)
198
+ {
199
+ if (num==0) return nullptr;
200
+ void *ptr = aligned_alloc(64, num*sizeof(T));
201
+ return static_cast<T*>(ptr);
202
+ }
203
+ static void dealloc(T *ptr)
204
+ { aligned_dealloc(ptr); }
205
+ #endif
206
+
207
+ public:
208
+ arr() : p(0), sz(0) {}
209
+ arr(size_t n) : p(ralloc(n)), sz(n) {}
210
+ arr(arr &&other)
211
+ : p(other.p), sz(other.sz)
212
+ { other.p=nullptr; other.sz=0; }
213
+ ~arr() { dealloc(p); }
214
+
215
+ void resize(size_t n)
216
+ {
217
+ if (n==sz) return;
218
+ dealloc(p);
219
+ p = ralloc(n);
220
+ sz = n;
221
+ }
222
+
223
+ T &operator[](size_t idx) { return p[idx]; }
224
+ const T &operator[](size_t idx) const { return p[idx]; }
225
+
226
+ T *data() { return p; }
227
+ const T *data() const { return p; }
228
+
229
+ size_t size() const { return sz; }
230
+ };
231
+
232
+ template<typename T> struct cmplx {
233
+ T r, i;
234
+ cmplx() {}
235
+ cmplx(T r_, T i_) : r(r_), i(i_) {}
236
+ void Set(T r_, T i_) { r=r_; i=i_; }
237
+ void Set(T r_) { r=r_; i=T(0); }
238
+ cmplx &operator+= (const cmplx &other)
239
+ { r+=other.r; i+=other.i; return *this; }
240
+ template<typename T2>cmplx &operator*= (T2 other)
241
+ { r*=other; i*=other; return *this; }
242
+ template<typename T2>cmplx &operator*= (const cmplx<T2> &other)
243
+ {
244
+ T tmp = r*other.r - i*other.i;
245
+ i = r*other.i + i*other.r;
246
+ r = tmp;
247
+ return *this;
248
+ }
249
+ template<typename T2>cmplx &operator+= (const cmplx<T2> &other)
250
+ { r+=other.r; i+=other.i; return *this; }
251
+ template<typename T2>cmplx &operator-= (const cmplx<T2> &other)
252
+ { r-=other.r; i-=other.i; return *this; }
253
+ template<typename T2> auto operator* (const T2 &other) const
254
+ -> cmplx<decltype(r*other)>
255
+ { return {r*other, i*other}; }
256
+ template<typename T2> auto operator+ (const cmplx<T2> &other) const
257
+ -> cmplx<decltype(r+other.r)>
258
+ { return {r+other.r, i+other.i}; }
259
+ template<typename T2> auto operator- (const cmplx<T2> &other) const
260
+ -> cmplx<decltype(r+other.r)>
261
+ { return {r-other.r, i-other.i}; }
262
+ template<typename T2> auto operator* (const cmplx<T2> &other) const
263
+ -> cmplx<decltype(r+other.r)>
264
+ { return {r*other.r-i*other.i, r*other.i + i*other.r}; }
265
+ template<bool fwd, typename T2> auto special_mul (const cmplx<T2> &other) const
266
+ -> cmplx<decltype(r+other.r)>
267
+ {
268
+ using Tres = cmplx<decltype(r+other.r)>;
269
+ return fwd ? Tres(r*other.r+i*other.i, i*other.r-r*other.i)
270
+ : Tres(r*other.r-i*other.i, r*other.i+i*other.r);
271
+ }
272
+ };
273
+ template<typename T> inline void PM(T &a, T &b, T c, T d)
274
+ { a=c+d; b=c-d; }
275
+ template<typename T> inline void PMINPLACE(T &a, T &b)
276
+ { T t = a; a+=b; b=t-b; }
277
+ template<typename T> inline void MPINPLACE(T &a, T &b)
278
+ { T t = a; a-=b; b=t+b; }
279
+ template<typename T> cmplx<T> conj(const cmplx<T> &a)
280
+ { return {a.r, -a.i}; }
281
+ template<bool fwd, typename T, typename T2> void special_mul (const cmplx<T> &v1, const cmplx<T2> &v2, cmplx<T> &res)
282
+ {
283
+ res = fwd ? cmplx<T>(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i)
284
+ : cmplx<T>(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r);
285
+ }
286
+
287
+ template<typename T> void ROT90(cmplx<T> &a)
288
+ { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; }
289
+ template<bool fwd, typename T> void ROTX90(cmplx<T> &a)
290
+ { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? a.i : -a.i; a.i=tmp_; }
291
+
292
+ //
293
+ // twiddle factor section
294
+ //
295
+ template<typename T> class sincos_2pibyn
296
+ {
297
+ private:
298
+ using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type;
299
+ size_t N, mask, shift;
300
+ arr<cmplx<Thigh>> v1, v2;
301
+
302
+ static cmplx<Thigh> calc(size_t x, size_t n, Thigh ang)
303
+ {
304
+ x<<=3;
305
+ if (x<4*n) // first half
306
+ {
307
+ if (x<2*n) // first quadrant
308
+ {
309
+ if (x<n) return cmplx<Thigh>(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang));
310
+ return cmplx<Thigh>(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang));
311
+ }
312
+ else // second quadrant
313
+ {
314
+ x-=2*n;
315
+ if (x<n) return cmplx<Thigh>(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang));
316
+ return cmplx<Thigh>(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang));
317
+ }
318
+ }
319
+ else
320
+ {
321
+ x=8*n-x;
322
+ if (x<2*n) // third quadrant
323
+ {
324
+ if (x<n) return cmplx<Thigh>(std::cos(Thigh(x)*ang), -std::sin(Thigh(x)*ang));
325
+ return cmplx<Thigh>(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang));
326
+ }
327
+ else // fourth quadrant
328
+ {
329
+ x-=2*n;
330
+ if (x<n) return cmplx<Thigh>(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang));
331
+ return cmplx<Thigh>(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang));
332
+ }
333
+ }
334
+ }
335
+
336
+ public:
337
+ POCKETFFT_NOINLINE sincos_2pibyn(size_t n)
338
+ : N(n)
339
+ {
340
+ constexpr auto pi = 3.141592653589793238462643383279502884197L;
341
+ Thigh ang = Thigh(0.25L*pi/n);
342
+ size_t nval = (n+2)/2;
343
+ shift = 1;
344
+ while((size_t(1)<<shift)*(size_t(1)<<shift) < nval) ++shift;
345
+ mask = (size_t(1)<<shift)-1;
346
+ v1.resize(mask+1);
347
+ v1[0].Set(Thigh(1), Thigh(0));
348
+ for (size_t i=1; i<v1.size(); ++i)
349
+ v1[i]=calc(i,n,ang);
350
+ v2.resize((nval+mask)/(mask+1));
351
+ v2[0].Set(Thigh(1), Thigh(0));
352
+ for (size_t i=1; i<v2.size(); ++i)
353
+ v2[i]=calc(i*(mask+1),n,ang);
354
+ }
355
+
356
+ cmplx<T> operator[](size_t idx) const
357
+ {
358
+ if (2*idx<=N)
359
+ {
360
+ auto x1=v1[idx&mask], x2=v2[idx>>shift];
361
+ return cmplx<T>(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r));
362
+ }
363
+ idx = N-idx;
364
+ auto x1=v1[idx&mask], x2=v2[idx>>shift];
365
+ return cmplx<T>(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r));
366
+ }
367
+ };
368
+
369
+ struct util // hack to avoid duplicate symbols
370
+ {
371
+ static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n)
372
+ {
373
+ size_t res=1;
374
+ while ((n&1)==0)
375
+ { res=2; n>>=1; }
376
+ for (size_t x=3; x*x<=n; x+=2)
377
+ while ((n%x)==0)
378
+ { res=x; n/=x; }
379
+ if (n>1) res=n;
380
+ return res;
381
+ }
382
+
383
+ static POCKETFFT_NOINLINE double cost_guess (size_t n)
384
+ {
385
+ constexpr double lfp=1.1; // penalty for non-hardcoded larger factors
386
+ size_t ni=n;
387
+ double result=0.;
388
+ while ((n&1)==0)
389
+ { result+=2; n>>=1; }
390
+ for (size_t x=3; x*x<=n; x+=2)
391
+ while ((n%x)==0)
392
+ {
393
+ result+= (x<=5) ? double(x) : lfp*double(x); // penalize larger prime factors
394
+ n/=x;
395
+ }
396
+ if (n>1) result+=(n<=5) ? double(n) : lfp*double(n);
397
+ return result*double(ni);
398
+ }
399
+
400
+ /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */
401
+ static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n)
402
+ {
403
+ if (n<=12) return n;
404
+
405
+ size_t bestfac=2*n;
406
+ for (size_t f11=1; f11<bestfac; f11*=11)
407
+ for (size_t f117=f11; f117<bestfac; f117*=7)
408
+ for (size_t f1175=f117; f1175<bestfac; f1175*=5)
409
+ {
410
+ size_t x=f1175;
411
+ while (x<n) x*=2;
412
+ for (;;)
413
+ {
414
+ if (x<n)
415
+ x*=3;
416
+ else if (x>n)
417
+ {
418
+ if (x<bestfac) bestfac=x;
419
+ if (x&1) break;
420
+ x>>=1;
421
+ }
422
+ else
423
+ return n;
424
+ }
425
+ }
426
+ return bestfac;
427
+ }
428
+
429
+ /* returns the smallest composite of 2, 3, 5 which is >= n */
430
+ static POCKETFFT_NOINLINE size_t good_size_real(size_t n)
431
+ {
432
+ if (n<=6) return n;
433
+
434
+ size_t bestfac=2*n;
435
+ for (size_t f5=1; f5<bestfac; f5*=5)
436
+ {
437
+ size_t x = f5;
438
+ while (x<n) x *= 2;
439
+ for (;;)
440
+ {
441
+ if (x<n)
442
+ x*=3;
443
+ else if (x>n)
444
+ {
445
+ if (x<bestfac) bestfac=x;
446
+ if (x&1) break;
447
+ x>>=1;
448
+ }
449
+ else
450
+ return n;
451
+ }
452
+ }
453
+ return bestfac;
454
+ }
455
+
456
+ static size_t prod(const shape_t &shape)
457
+ {
458
+ size_t res=1;
459
+ for (auto sz: shape)
460
+ res*=sz;
461
+ return res;
462
+ }
463
+
464
+ static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,
465
+ const stride_t &stride_in, const stride_t &stride_out, bool inplace)
466
+ {
467
+ auto ndim = shape.size();
468
+ if (ndim<1) throw std::runtime_error("ndim must be >= 1");
469
+ if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim))
470
+ throw std::runtime_error("stride dimension mismatch");
471
+ if (inplace && (stride_in!=stride_out))
472
+ throw std::runtime_error("stride mismatch");
473
+ }
474
+
475
+ static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,
476
+ const stride_t &stride_in, const stride_t &stride_out, bool inplace,
477
+ const shape_t &axes)
478
+ {
479
+ sanity_check(shape, stride_in, stride_out, inplace);
480
+ auto ndim = shape.size();
481
+ shape_t tmp(ndim,0);
482
+ for (auto ax : axes)
483
+ {
484
+ if (ax>=ndim) throw std::invalid_argument("bad axis number");
485
+ if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly");
486
+ }
487
+ }
488
+
489
+ static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape,
490
+ const stride_t &stride_in, const stride_t &stride_out, bool inplace,
491
+ size_t axis)
492
+ {
493
+ sanity_check(shape, stride_in, stride_out, inplace);
494
+ if (axis>=shape.size()) throw std::invalid_argument("bad axis number");
495
+ }
496
+
497
+ #ifdef POCKETFFT_NO_MULTITHREADING
498
+ static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/,
499
+ size_t /*axis*/, size_t /*vlen*/)
500
+ { return 1; }
501
+ #else
502
+ static size_t thread_count (size_t nthreads, const shape_t &shape,
503
+ size_t axis, size_t vlen)
504
+ {
505
+ if (nthreads==1) return 1;
506
+ size_t size = prod(shape);
507
+ size_t parallel = size / (shape[axis] * vlen);
508
+ if (shape[axis] < 1000)
509
+ parallel /= 4;
510
+ size_t max_threads = nthreads == 0 ?
511
+ std::thread::hardware_concurrency() : nthreads;
512
+ return std::max(size_t(1), std::min(parallel, max_threads));
513
+ }
514
+ #endif
515
+ };
516
+
517
+ namespace threading {
518
+
519
+ #ifdef POCKETFFT_NO_MULTITHREADING
520
+
521
+ constexpr inline size_t thread_id() { return 0; }
522
+ constexpr inline size_t num_threads() { return 1; }
523
+
524
+ template <typename Func>
525
+ void thread_map(size_t /* nthreads */, Func f)
526
+ { f(); }
527
+
528
+ #else
529
+
530
+ inline size_t &thread_id()
531
+ {
532
+ static thread_local size_t thread_id_=0;
533
+ return thread_id_;
534
+ }
535
+ inline size_t &num_threads()
536
+ {
537
+ static thread_local size_t num_threads_=1;
538
+ return num_threads_;
539
+ }
540
+ static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency());
541
+
542
+ class latch
543
+ {
544
+ std::atomic<size_t> num_left_;
545
+ std::mutex mut_;
546
+ std::condition_variable completed_;
547
+ using lock_t = std::unique_lock<std::mutex>;
548
+
549
+ public:
550
+ latch(size_t n): num_left_(n) {}
551
+
552
+ void count_down()
553
+ {
554
+ lock_t lock(mut_);
555
+ if (--num_left_)
556
+ return;
557
+ completed_.notify_all();
558
+ }
559
+
560
+ void wait()
561
+ {
562
+ lock_t lock(mut_);
563
+ completed_.wait(lock, [this]{ return is_ready(); });
564
+ }
565
+ bool is_ready() { return num_left_ == 0; }
566
+ };
567
+
568
+ template <typename T> class concurrent_queue
569
+ {
570
+ std::queue<T> q_;
571
+ std::mutex mut_;
572
+ std::atomic<size_t> size_;
573
+ using lock_t = std::lock_guard<std::mutex>;
574
+
575
+ public:
576
+
577
+ void push(T val)
578
+ {
579
+ lock_t lock(mut_);
580
+ ++size_;
581
+ q_.push(std::move(val));
582
+ }
583
+
584
+ bool try_pop(T &val)
585
+ {
586
+ if (size_ == 0) return false;
587
+ lock_t lock(mut_);
588
+ // Queue might have been emptied while we acquired the lock
589
+ if (q_.empty()) return false;
590
+
591
+ val = std::move(q_.front());
592
+ --size_;
593
+ q_.pop();
594
+ return true;
595
+ }
596
+
597
+ bool empty() const { return size_==0; }
598
+ };
599
+
600
+ // C++ allocator with support for over-aligned types
601
+ template <typename T> struct aligned_allocator
602
+ {
603
+ using value_type = T;
604
+ template <class U>
605
+ aligned_allocator(const aligned_allocator<U>&) {}
606
+ aligned_allocator() = default;
607
+
608
+ T *allocate(size_t n)
609
+ {
610
+ void* mem = aligned_alloc(alignof(T), n*sizeof(T));
611
+ return static_cast<T*>(mem);
612
+ }
613
+
614
+ void deallocate(T *p, size_t /*n*/)
615
+ { aligned_dealloc(p); }
616
+ };
617
+
618
+ class thread_pool
619
+ {
620
+ // A reasonable guess, probably close enough for most hardware
621
+ static constexpr size_t cache_line_size = 64;
622
+ struct alignas(cache_line_size) worker
623
+ {
624
+ std::thread thread;
625
+ std::condition_variable work_ready;
626
+ std::mutex mut;
627
+ std::atomic_flag busy_flag = ATOMIC_FLAG_INIT;
628
+ std::function<void()> work;
629
+
630
+ void worker_main(
631
+ std::atomic<bool> &shutdown_flag,
632
+ std::atomic<size_t> &unscheduled_tasks,
633
+ concurrent_queue<std::function<void()>> &overflow_work)
634
+ {
635
+ using lock_t = std::unique_lock<std::mutex>;
636
+ bool expect_work = true;
637
+ while (!shutdown_flag || expect_work)
638
+ {
639
+ std::function<void()> local_work;
640
+ if (expect_work || unscheduled_tasks == 0)
641
+ {
642
+ lock_t lock(mut);
643
+ // Wait until there is work to be executed
644
+ work_ready.wait(lock, [&]{ return (work || shutdown_flag); });
645
+ local_work.swap(work);
646
+ expect_work = false;
647
+ }
648
+
649
+ bool marked_busy = false;
650
+ if (local_work)
651
+ {
652
+ marked_busy = true;
653
+ local_work();
654
+ }
655
+
656
+ if (!overflow_work.empty())
657
+ {
658
+ if (!marked_busy && busy_flag.test_and_set())
659
+ {
660
+ expect_work = true;
661
+ continue;
662
+ }
663
+ marked_busy = true;
664
+
665
+ while (overflow_work.try_pop(local_work))
666
+ {
667
+ --unscheduled_tasks;
668
+ local_work();
669
+ }
670
+ }
671
+
672
+ if (marked_busy) busy_flag.clear();
673
+ }
674
+ }
675
+ };
676
+
677
+ concurrent_queue<std::function<void()>> overflow_work_;
678
+ std::mutex mut_;
679
+ std::vector<worker, aligned_allocator<worker>> workers_;
680
+ std::atomic<bool> shutdown_;
681
+ std::atomic<size_t> unscheduled_tasks_;
682
+ using lock_t = std::lock_guard<std::mutex>;
683
+
684
+ void create_threads()
685
+ {
686
+ lock_t lock(mut_);
687
+ size_t nthreads=workers_.size();
688
+ for (size_t i=0; i<nthreads; ++i)
689
+ {
690
+ try
691
+ {
692
+ auto *worker = &workers_[i];
693
+ worker->busy_flag.clear();
694
+ worker->work = nullptr;
695
+ worker->thread = std::thread([worker, this]
696
+ {
697
+ worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_);
698
+ });
699
+ }
700
+ catch (...)
701
+ {
702
+ shutdown_locked();
703
+ throw;
704
+ }
705
+ }
706
+ }
707
+
708
+ void shutdown_locked()
709
+ {
710
+ shutdown_ = true;
711
+ for (auto &worker : workers_)
712
+ worker.work_ready.notify_all();
713
+
714
+ for (auto &worker : workers_)
715
+ if (worker.thread.joinable())
716
+ worker.thread.join();
717
+ }
718
+
719
+ public:
720
+ explicit thread_pool(size_t nthreads):
721
+ workers_(nthreads)
722
+ { create_threads(); }
723
+
724
+ thread_pool(): thread_pool(max_threads) {}
725
+
726
+ ~thread_pool() { shutdown(); }
727
+
728
+ void submit(std::function<void()> work)
729
+ {
730
+ lock_t lock(mut_);
731
+ if (shutdown_)
732
+ throw std::runtime_error("Work item submitted after shutdown");
733
+
734
+ ++unscheduled_tasks_;
735
+
736
+ // First check for any idle workers and wake those
737
+ for (auto &worker : workers_)
738
+ if (!worker.busy_flag.test_and_set())
739
+ {
740
+ --unscheduled_tasks_;
741
+ {
742
+ lock_t lock(worker.mut);
743
+ worker.work = std::move(work);
744
+ }
745
+ worker.work_ready.notify_one();
746
+ return;
747
+ }
748
+
749
+ // If no workers were idle, push onto the overflow queue for later
750
+ overflow_work_.push(std::move(work));
751
+ }
752
+
753
+ void shutdown()
754
+ {
755
+ lock_t lock(mut_);
756
+ shutdown_locked();
757
+ }
758
+
759
+ void restart()
760
+ {
761
+ shutdown_ = false;
762
+ create_threads();
763
+ }
764
+ };
765
+
766
+ inline thread_pool & get_pool()
767
+ {
768
+ static thread_pool pool;
769
+ #ifdef POCKETFFT_PTHREADS
770
+ static std::once_flag f;
771
+ std::call_once(f,
772
+ []{
773
+ pthread_atfork(
774
+ +[]{ get_pool().shutdown(); }, // prepare
775
+ +[]{ get_pool().restart(); }, // parent
776
+ +[]{ get_pool().restart(); } // child
777
+ );
778
+ });
779
+ #endif
780
+
781
+ return pool;
782
+ }
783
+
784
+ /** Map a function f over nthreads */
785
+ template <typename Func>
786
+ void thread_map(size_t nthreads, Func f)
787
+ {
788
+ if (nthreads == 0)
789
+ nthreads = max_threads;
790
+
791
+ if (nthreads == 1)
792
+ { f(); return; }
793
+
794
+ auto & pool = get_pool();
795
+ latch counter(nthreads);
796
+ std::exception_ptr ex;
797
+ std::mutex ex_mut;
798
+ for (size_t i=0; i<nthreads; ++i)
799
+ {
800
+ pool.submit(
801
+ [&f, &counter, &ex, &ex_mut, i, nthreads] {
802
+ thread_id() = i;
803
+ num_threads() = nthreads;
804
+ try { f(); }
805
+ catch (...)
806
+ {
807
+ std::lock_guard<std::mutex> lock(ex_mut);
808
+ ex = std::current_exception();
809
+ }
810
+ counter.count_down();
811
+ });
812
+ }
813
+ counter.wait();
814
+ if (ex)
815
+ std::rethrow_exception(ex);
816
+ }
817
+
818
+ #endif
819
+
820
+ }
821
+
822
+ //
823
+ // complex FFTPACK transforms
824
+ //
825
+
826
+ template<typename T0> class cfftp
827
+ {
828
+ private:
829
+ struct fctdata
830
+ {
831
+ size_t fct;
832
+ cmplx<T0> *tw, *tws;
833
+ };
834
+
835
+ size_t length;
836
+ arr<cmplx<T0>> mem;
837
+ std::vector<fctdata> fact;
838
+
839
+ void add_factor(size_t factor)
840
+ { fact.push_back({factor, nullptr, nullptr}); }
841
+
842
+ template<bool fwd, typename T> void pass2 (size_t ido, size_t l1,
843
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
844
+ const cmplx<T0> * POCKETFFT_RESTRICT wa) const
845
+ {
846
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
847
+ { return ch[a+ido*(b+l1*c)]; };
848
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
849
+ { return cc[a+ido*(b+2*c)]; };
850
+ auto WA = [wa, ido](size_t x, size_t i)
851
+ { return wa[i-1+x*(ido-1)]; };
852
+
853
+ if (ido==1)
854
+ for (size_t k=0; k<l1; ++k)
855
+ {
856
+ CH(0,k,0) = CC(0,0,k)+CC(0,1,k);
857
+ CH(0,k,1) = CC(0,0,k)-CC(0,1,k);
858
+ }
859
+ else
860
+ for (size_t k=0; k<l1; ++k)
861
+ {
862
+ CH(0,k,0) = CC(0,0,k)+CC(0,1,k);
863
+ CH(0,k,1) = CC(0,0,k)-CC(0,1,k);
864
+ for (size_t i=1; i<ido; ++i)
865
+ {
866
+ CH(i,k,0) = CC(i,0,k)+CC(i,1,k);
867
+ special_mul<fwd>(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1));
868
+ }
869
+ }
870
+ }
871
+
872
+ #define POCKETFFT_PREP3(idx) \
873
+ T t0 = CC(idx,0,k), t1, t2; \
874
+ PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \
875
+ CH(idx,k,0)=t0+t1;
876
+ #define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \
877
+ { \
878
+ T ca=t0+t1*twr; \
879
+ T cb{-t2.i*twi, t2.r*twi}; \
880
+ PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\
881
+ }
882
+ #define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \
883
+ { \
884
+ T ca=t0+t1*twr; \
885
+ T cb{-t2.i*twi, t2.r*twi}; \
886
+ special_mul<fwd>(ca+cb,WA(u1-1,i),CH(i,k,u1)); \
887
+ special_mul<fwd>(ca-cb,WA(u2-1,i),CH(i,k,u2)); \
888
+ }
889
+ template<bool fwd, typename T> void pass3 (size_t ido, size_t l1,
890
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
891
+ const cmplx<T0> * POCKETFFT_RESTRICT wa) const
892
+ {
893
+ constexpr T0 tw1r=-0.5,
894
+ tw1i= (fwd ? -1: 1) * T0(0.8660254037844386467637231707529362L);
895
+
896
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
897
+ { return ch[a+ido*(b+l1*c)]; };
898
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
899
+ { return cc[a+ido*(b+3*c)]; };
900
+ auto WA = [wa, ido](size_t x, size_t i)
901
+ { return wa[i-1+x*(ido-1)]; };
902
+
903
+ if (ido==1)
904
+ for (size_t k=0; k<l1; ++k)
905
+ {
906
+ POCKETFFT_PREP3(0)
907
+ POCKETFFT_PARTSTEP3a(1,2,tw1r,tw1i)
908
+ }
909
+ else
910
+ for (size_t k=0; k<l1; ++k)
911
+ {
912
+ {
913
+ POCKETFFT_PREP3(0)
914
+ POCKETFFT_PARTSTEP3a(1,2,tw1r,tw1i)
915
+ }
916
+ for (size_t i=1; i<ido; ++i)
917
+ {
918
+ POCKETFFT_PREP3(i)
919
+ POCKETFFT_PARTSTEP3b(1,2,tw1r,tw1i)
920
+ }
921
+ }
922
+ }
923
+
924
+ #undef POCKETFFT_PARTSTEP3b
925
+ #undef POCKETFFT_PARTSTEP3a
926
+ #undef POCKETFFT_PREP3
927
+
928
+ template<bool fwd, typename T> void pass4 (size_t ido, size_t l1,
929
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
930
+ const cmplx<T0> * POCKETFFT_RESTRICT wa) const
931
+ {
932
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
933
+ { return ch[a+ido*(b+l1*c)]; };
934
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
935
+ { return cc[a+ido*(b+4*c)]; };
936
+ auto WA = [wa, ido](size_t x, size_t i)
937
+ { return wa[i-1+x*(ido-1)]; };
938
+
939
+ if (ido==1)
940
+ for (size_t k=0; k<l1; ++k)
941
+ {
942
+ T t1, t2, t3, t4;
943
+ PM(t2,t1,CC(0,0,k),CC(0,2,k));
944
+ PM(t3,t4,CC(0,1,k),CC(0,3,k));
945
+ ROTX90<fwd>(t4);
946
+ PM(CH(0,k,0),CH(0,k,2),t2,t3);
947
+ PM(CH(0,k,1),CH(0,k,3),t1,t4);
948
+ }
949
+ else
950
+ for (size_t k=0; k<l1; ++k)
951
+ {
952
+ {
953
+ T t1, t2, t3, t4;
954
+ PM(t2,t1,CC(0,0,k),CC(0,2,k));
955
+ PM(t3,t4,CC(0,1,k),CC(0,3,k));
956
+ ROTX90<fwd>(t4);
957
+ PM(CH(0,k,0),CH(0,k,2),t2,t3);
958
+ PM(CH(0,k,1),CH(0,k,3),t1,t4);
959
+ }
960
+ for (size_t i=1; i<ido; ++i)
961
+ {
962
+ T t1, t2, t3, t4;
963
+ T cc0=CC(i,0,k), cc1=CC(i,1,k),cc2=CC(i,2,k),cc3=CC(i,3,k);
964
+ PM(t2,t1,cc0,cc2);
965
+ PM(t3,t4,cc1,cc3);
966
+ ROTX90<fwd>(t4);
967
+ CH(i,k,0) = t2+t3;
968
+ special_mul<fwd>(t1+t4,WA(0,i),CH(i,k,1));
969
+ special_mul<fwd>(t2-t3,WA(1,i),CH(i,k,2));
970
+ special_mul<fwd>(t1-t4,WA(2,i),CH(i,k,3));
971
+ }
972
+ }
973
+ }
974
+
975
+ #define POCKETFFT_PREP5(idx) \
976
+ T t0 = CC(idx,0,k), t1, t2, t3, t4; \
977
+ PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \
978
+ PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \
979
+ CH(idx,k,0).r=t0.r+t1.r+t2.r; \
980
+ CH(idx,k,0).i=t0.i+t1.i+t2.i;
981
+
982
+ #define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \
983
+ { \
984
+ T ca,cb; \
985
+ ca.r=t0.r+twar*t1.r+twbr*t2.r; \
986
+ ca.i=t0.i+twar*t1.i+twbr*t2.i; \
987
+ cb.i=twai*t4.r twbi*t3.r; \
988
+ cb.r=-(twai*t4.i twbi*t3.i); \
989
+ PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \
990
+ }
991
+
992
+ #define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \
993
+ { \
994
+ T ca,cb,da,db; \
995
+ ca.r=t0.r+twar*t1.r+twbr*t2.r; \
996
+ ca.i=t0.i+twar*t1.i+twbr*t2.i; \
997
+ cb.i=twai*t4.r twbi*t3.r; \
998
+ cb.r=-(twai*t4.i twbi*t3.i); \
999
+ special_mul<fwd>(ca+cb,WA(u1-1,i),CH(i,k,u1)); \
1000
+ special_mul<fwd>(ca-cb,WA(u2-1,i),CH(i,k,u2)); \
1001
+ }
1002
+ template<bool fwd, typename T> void pass5 (size_t ido, size_t l1,
1003
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1004
+ const cmplx<T0> * POCKETFFT_RESTRICT wa) const
1005
+ {
1006
+ constexpr T0 tw1r= T0(0.3090169943749474241022934171828191L),
1007
+ tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L),
1008
+ tw2r= T0(-0.8090169943749474241022934171828191L),
1009
+ tw2i= (fwd ? -1: 1) * T0(0.5877852522924731291687059546390728L);
1010
+
1011
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1012
+ { return ch[a+ido*(b+l1*c)]; };
1013
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
1014
+ { return cc[a+ido*(b+5*c)]; };
1015
+ auto WA = [wa, ido](size_t x, size_t i)
1016
+ { return wa[i-1+x*(ido-1)]; };
1017
+
1018
+ if (ido==1)
1019
+ for (size_t k=0; k<l1; ++k)
1020
+ {
1021
+ POCKETFFT_PREP5(0)
1022
+ POCKETFFT_PARTSTEP5a(1,4,tw1r,tw2r,+tw1i,+tw2i)
1023
+ POCKETFFT_PARTSTEP5a(2,3,tw2r,tw1r,+tw2i,-tw1i)
1024
+ }
1025
+ else
1026
+ for (size_t k=0; k<l1; ++k)
1027
+ {
1028
+ {
1029
+ POCKETFFT_PREP5(0)
1030
+ POCKETFFT_PARTSTEP5a(1,4,tw1r,tw2r,+tw1i,+tw2i)
1031
+ POCKETFFT_PARTSTEP5a(2,3,tw2r,tw1r,+tw2i,-tw1i)
1032
+ }
1033
+ for (size_t i=1; i<ido; ++i)
1034
+ {
1035
+ POCKETFFT_PREP5(i)
1036
+ POCKETFFT_PARTSTEP5b(1,4,tw1r,tw2r,+tw1i,+tw2i)
1037
+ POCKETFFT_PARTSTEP5b(2,3,tw2r,tw1r,+tw2i,-tw1i)
1038
+ }
1039
+ }
1040
+ }
1041
+
1042
+ #undef POCKETFFT_PARTSTEP5b
1043
+ #undef POCKETFFT_PARTSTEP5a
1044
+ #undef POCKETFFT_PREP5
1045
+
1046
+ #define POCKETFFT_PREP7(idx) \
1047
+ T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7; \
1048
+ PM (t2,t7,CC(idx,1,k),CC(idx,6,k)); \
1049
+ PM (t3,t6,CC(idx,2,k),CC(idx,5,k)); \
1050
+ PM (t4,t5,CC(idx,3,k),CC(idx,4,k)); \
1051
+ CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r; \
1052
+ CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i;
1053
+
1054
+ #define POCKETFFT_PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,out1,out2) \
1055
+ { \
1056
+ T ca,cb; \
1057
+ ca.r=t1.r+x1*t2.r+x2*t3.r+x3*t4.r; \
1058
+ ca.i=t1.i+x1*t2.i+x2*t3.i+x3*t4.i; \
1059
+ cb.i=y1*t7.r y2*t6.r y3*t5.r; \
1060
+ cb.r=-(y1*t7.i y2*t6.i y3*t5.i); \
1061
+ PM(out1,out2,ca,cb); \
1062
+ }
1063
+ #define POCKETFFT_PARTSTEP7a(u1,u2,x1,x2,x3,y1,y2,y3) \
1064
+ POCKETFFT_PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,CH(0,k,u1),CH(0,k,u2))
1065
+ #define POCKETFFT_PARTSTEP7(u1,u2,x1,x2,x3,y1,y2,y3) \
1066
+ { \
1067
+ T da,db; \
1068
+ POCKETFFT_PARTSTEP7a0(u1,u2,x1,x2,x3,y1,y2,y3,da,db) \
1069
+ special_mul<fwd>(da,WA(u1-1,i),CH(i,k,u1)); \
1070
+ special_mul<fwd>(db,WA(u2-1,i),CH(i,k,u2)); \
1071
+ }
1072
+
1073
+ template<bool fwd, typename T> void pass7(size_t ido, size_t l1,
1074
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1075
+ const cmplx<T0> * POCKETFFT_RESTRICT wa) const
1076
+ {
1077
+ constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L),
1078
+ tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L),
1079
+ tw2r= T0(-0.2225209339563144042889025644967948L),
1080
+ tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L),
1081
+ tw3r= T0(-0.9009688679024191262361023195074451L),
1082
+ tw3i= (fwd ? -1 : 1) * T0(0.433883739117558120475768332848359L);
1083
+
1084
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1085
+ { return ch[a+ido*(b+l1*c)]; };
1086
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
1087
+ { return cc[a+ido*(b+7*c)]; };
1088
+ auto WA = [wa, ido](size_t x, size_t i)
1089
+ { return wa[i-1+x*(ido-1)]; };
1090
+
1091
+ if (ido==1)
1092
+ for (size_t k=0; k<l1; ++k)
1093
+ {
1094
+ POCKETFFT_PREP7(0)
1095
+ POCKETFFT_PARTSTEP7a(1,6,tw1r,tw2r,tw3r,+tw1i,+tw2i,+tw3i)
1096
+ POCKETFFT_PARTSTEP7a(2,5,tw2r,tw3r,tw1r,+tw2i,-tw3i,-tw1i)
1097
+ POCKETFFT_PARTSTEP7a(3,4,tw3r,tw1r,tw2r,+tw3i,-tw1i,+tw2i)
1098
+ }
1099
+ else
1100
+ for (size_t k=0; k<l1; ++k)
1101
+ {
1102
+ {
1103
+ POCKETFFT_PREP7(0)
1104
+ POCKETFFT_PARTSTEP7a(1,6,tw1r,tw2r,tw3r,+tw1i,+tw2i,+tw3i)
1105
+ POCKETFFT_PARTSTEP7a(2,5,tw2r,tw3r,tw1r,+tw2i,-tw3i,-tw1i)
1106
+ POCKETFFT_PARTSTEP7a(3,4,tw3r,tw1r,tw2r,+tw3i,-tw1i,+tw2i)
1107
+ }
1108
+ for (size_t i=1; i<ido; ++i)
1109
+ {
1110
+ POCKETFFT_PREP7(i)
1111
+ POCKETFFT_PARTSTEP7(1,6,tw1r,tw2r,tw3r,+tw1i,+tw2i,+tw3i)
1112
+ POCKETFFT_PARTSTEP7(2,5,tw2r,tw3r,tw1r,+tw2i,-tw3i,-tw1i)
1113
+ POCKETFFT_PARTSTEP7(3,4,tw3r,tw1r,tw2r,+tw3i,-tw1i,+tw2i)
1114
+ }
1115
+ }
1116
+ }
1117
+
1118
+ #undef POCKETFFT_PARTSTEP7
1119
+ #undef POCKETFFT_PARTSTEP7a0
1120
+ #undef POCKETFFT_PARTSTEP7a
1121
+ #undef POCKETFFT_PREP7
1122
+
1123
+ template <bool fwd, typename T> void ROTX45(T &a) const
1124
+ {
1125
+ constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L);
1126
+ if (fwd)
1127
+ { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); }
1128
+ else
1129
+ { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); }
1130
+ }
1131
+ template <bool fwd, typename T> void ROTX135(T &a) const
1132
+ {
1133
+ constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L);
1134
+ if (fwd)
1135
+ { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); }
1136
+ else
1137
+ { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); }
1138
+ }
1139
+
1140
+ template<bool fwd, typename T> void pass8 (size_t ido, size_t l1,
1141
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1142
+ const cmplx<T0> * POCKETFFT_RESTRICT wa) const
1143
+ {
1144
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1145
+ { return ch[a+ido*(b+l1*c)]; };
1146
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
1147
+ { return cc[a+ido*(b+8*c)]; };
1148
+ auto WA = [wa, ido](size_t x, size_t i)
1149
+ { return wa[i-1+x*(ido-1)]; };
1150
+
1151
+ if (ido==1)
1152
+ for (size_t k=0; k<l1; ++k)
1153
+ {
1154
+ T a0, a1, a2, a3, a4, a5, a6, a7;
1155
+ PM(a1,a5,CC(0,1,k),CC(0,5,k));
1156
+ PM(a3,a7,CC(0,3,k),CC(0,7,k));
1157
+ PMINPLACE(a1,a3);
1158
+ ROTX90<fwd>(a3);
1159
+
1160
+ ROTX90<fwd>(a7);
1161
+ PMINPLACE(a5,a7);
1162
+ ROTX45<fwd>(a5);
1163
+ ROTX135<fwd>(a7);
1164
+
1165
+ PM(a0,a4,CC(0,0,k),CC(0,4,k));
1166
+ PM(a2,a6,CC(0,2,k),CC(0,6,k));
1167
+ PM(CH(0,k,0),CH(0,k,4),a0+a2,a1);
1168
+ PM(CH(0,k,2),CH(0,k,6),a0-a2,a3);
1169
+ ROTX90<fwd>(a6);
1170
+ PM(CH(0,k,1),CH(0,k,5),a4+a6,a5);
1171
+ PM(CH(0,k,3),CH(0,k,7),a4-a6,a7);
1172
+ }
1173
+ else
1174
+ for (size_t k=0; k<l1; ++k)
1175
+ {
1176
+ {
1177
+ T a0, a1, a2, a3, a4, a5, a6, a7;
1178
+ PM(a1,a5,CC(0,1,k),CC(0,5,k));
1179
+ PM(a3,a7,CC(0,3,k),CC(0,7,k));
1180
+ PMINPLACE(a1,a3);
1181
+ ROTX90<fwd>(a3);
1182
+
1183
+ ROTX90<fwd>(a7);
1184
+ PMINPLACE(a5,a7);
1185
+ ROTX45<fwd>(a5);
1186
+ ROTX135<fwd>(a7);
1187
+
1188
+ PM(a0,a4,CC(0,0,k),CC(0,4,k));
1189
+ PM(a2,a6,CC(0,2,k),CC(0,6,k));
1190
+ PM(CH(0,k,0),CH(0,k,4),a0+a2,a1);
1191
+ PM(CH(0,k,2),CH(0,k,6),a0-a2,a3);
1192
+ ROTX90<fwd>(a6);
1193
+ PM(CH(0,k,1),CH(0,k,5),a4+a6,a5);
1194
+ PM(CH(0,k,3),CH(0,k,7),a4-a6,a7);
1195
+ }
1196
+ for (size_t i=1; i<ido; ++i)
1197
+ {
1198
+ T a0, a1, a2, a3, a4, a5, a6, a7;
1199
+ PM(a1,a5,CC(i,1,k),CC(i,5,k));
1200
+ PM(a3,a7,CC(i,3,k),CC(i,7,k));
1201
+ ROTX90<fwd>(a7);
1202
+ PMINPLACE(a1,a3);
1203
+ ROTX90<fwd>(a3);
1204
+ PMINPLACE(a5,a7);
1205
+ ROTX45<fwd>(a5);
1206
+ ROTX135<fwd>(a7);
1207
+ PM(a0,a4,CC(i,0,k),CC(i,4,k));
1208
+ PM(a2,a6,CC(i,2,k),CC(i,6,k));
1209
+ PMINPLACE(a0,a2);
1210
+ CH(i,k,0) = a0+a1;
1211
+ special_mul<fwd>(a0-a1,WA(3,i),CH(i,k,4));
1212
+ special_mul<fwd>(a2+a3,WA(1,i),CH(i,k,2));
1213
+ special_mul<fwd>(a2-a3,WA(5,i),CH(i,k,6));
1214
+ ROTX90<fwd>(a6);
1215
+ PMINPLACE(a4,a6);
1216
+ special_mul<fwd>(a4+a5,WA(0,i),CH(i,k,1));
1217
+ special_mul<fwd>(a4-a5,WA(4,i),CH(i,k,5));
1218
+ special_mul<fwd>(a6+a7,WA(2,i),CH(i,k,3));
1219
+ special_mul<fwd>(a6-a7,WA(6,i),CH(i,k,7));
1220
+ }
1221
+ }
1222
+ }
1223
+
1224
+
1225
+ #define POCKETFFT_PREP11(idx) \
1226
+ T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \
1227
+ PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \
1228
+ PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \
1229
+ PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \
1230
+ PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \
1231
+ PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \
1232
+ CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \
1233
+ CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i;
1234
+
1235
+ #define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \
1236
+ { \
1237
+ T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \
1238
+ cb; \
1239
+ cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \
1240
+ cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \
1241
+ PM(out1,out2,ca,cb); \
1242
+ }
1243
+ #define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \
1244
+ POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2))
1245
+ #define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \
1246
+ { \
1247
+ T da,db; \
1248
+ POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \
1249
+ special_mul<fwd>(da,WA(u1-1,i),CH(i,k,u1)); \
1250
+ special_mul<fwd>(db,WA(u2-1,i),CH(i,k,u2)); \
1251
+ }
1252
+
1253
+ template<bool fwd, typename T> void pass11 (size_t ido, size_t l1,
1254
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1255
+ const cmplx<T0> * POCKETFFT_RESTRICT wa) const
1256
+ {
1257
+ constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L),
1258
+ tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L),
1259
+ tw2r= T0(0.4154150130018864255292741492296232L),
1260
+ tw2i= (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L),
1261
+ tw3r= T0(-0.1423148382732851404437926686163697L),
1262
+ tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L),
1263
+ tw4r= T0(-0.6548607339452850640569250724662936L),
1264
+ tw4i= (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L),
1265
+ tw5r= T0(-0.9594929736144973898903680570663277L),
1266
+ tw5i= (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L);
1267
+
1268
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1269
+ { return ch[a+ido*(b+l1*c)]; };
1270
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
1271
+ { return cc[a+ido*(b+11*c)]; };
1272
+ auto WA = [wa, ido](size_t x, size_t i)
1273
+ { return wa[i-1+x*(ido-1)]; };
1274
+
1275
+ if (ido==1)
1276
+ for (size_t k=0; k<l1; ++k)
1277
+ {
1278
+ POCKETFFT_PREP11(0)
1279
+ POCKETFFT_PARTSTEP11a(1,10,tw1r,tw2r,tw3r,tw4r,tw5r,+tw1i,+tw2i,+tw3i,+tw4i,+tw5i)
1280
+ POCKETFFT_PARTSTEP11a(2, 9,tw2r,tw4r,tw5r,tw3r,tw1r,+tw2i,+tw4i,-tw5i,-tw3i,-tw1i)
1281
+ POCKETFFT_PARTSTEP11a(3, 8,tw3r,tw5r,tw2r,tw1r,tw4r,+tw3i,-tw5i,-tw2i,+tw1i,+tw4i)
1282
+ POCKETFFT_PARTSTEP11a(4, 7,tw4r,tw3r,tw1r,tw5r,tw2r,+tw4i,-tw3i,+tw1i,+tw5i,-tw2i)
1283
+ POCKETFFT_PARTSTEP11a(5, 6,tw5r,tw1r,tw4r,tw2r,tw3r,+tw5i,-tw1i,+tw4i,-tw2i,+tw3i)
1284
+ }
1285
+ else
1286
+ for (size_t k=0; k<l1; ++k)
1287
+ {
1288
+ {
1289
+ POCKETFFT_PREP11(0)
1290
+ POCKETFFT_PARTSTEP11a(1,10,tw1r,tw2r,tw3r,tw4r,tw5r,+tw1i,+tw2i,+tw3i,+tw4i,+tw5i)
1291
+ POCKETFFT_PARTSTEP11a(2, 9,tw2r,tw4r,tw5r,tw3r,tw1r,+tw2i,+tw4i,-tw5i,-tw3i,-tw1i)
1292
+ POCKETFFT_PARTSTEP11a(3, 8,tw3r,tw5r,tw2r,tw1r,tw4r,+tw3i,-tw5i,-tw2i,+tw1i,+tw4i)
1293
+ POCKETFFT_PARTSTEP11a(4, 7,tw4r,tw3r,tw1r,tw5r,tw2r,+tw4i,-tw3i,+tw1i,+tw5i,-tw2i)
1294
+ POCKETFFT_PARTSTEP11a(5, 6,tw5r,tw1r,tw4r,tw2r,tw3r,+tw5i,-tw1i,+tw4i,-tw2i,+tw3i)
1295
+ }
1296
+ for (size_t i=1; i<ido; ++i)
1297
+ {
1298
+ POCKETFFT_PREP11(i)
1299
+ POCKETFFT_PARTSTEP11(1,10,tw1r,tw2r,tw3r,tw4r,tw5r,+tw1i,+tw2i,+tw3i,+tw4i,+tw5i)
1300
+ POCKETFFT_PARTSTEP11(2, 9,tw2r,tw4r,tw5r,tw3r,tw1r,+tw2i,+tw4i,-tw5i,-tw3i,-tw1i)
1301
+ POCKETFFT_PARTSTEP11(3, 8,tw3r,tw5r,tw2r,tw1r,tw4r,+tw3i,-tw5i,-tw2i,+tw1i,+tw4i)
1302
+ POCKETFFT_PARTSTEP11(4, 7,tw4r,tw3r,tw1r,tw5r,tw2r,+tw4i,-tw3i,+tw1i,+tw5i,-tw2i)
1303
+ POCKETFFT_PARTSTEP11(5, 6,tw5r,tw1r,tw4r,tw2r,tw3r,+tw5i,-tw1i,+tw4i,-tw2i,+tw3i)
1304
+ }
1305
+ }
1306
+ }
1307
+
1308
+ #undef POCKETFFT_PARTSTEP11
1309
+ #undef POCKETFFT_PARTSTEP11a0
1310
+ #undef POCKETFFT_PARTSTEP11a
1311
+ #undef POCKETFFT_PREP11
1312
+
1313
+ template<bool fwd, typename T> void passg (size_t ido, size_t ip,
1314
+ size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1315
+ const cmplx<T0> * POCKETFFT_RESTRICT wa,
1316
+ const cmplx<T0> * POCKETFFT_RESTRICT csarr) const
1317
+ {
1318
+ const size_t cdim=ip;
1319
+ size_t ipph = (ip+1)/2;
1320
+ size_t idl1 = ido*l1;
1321
+
1322
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1323
+ { return ch[a+ido*(b+l1*c)]; };
1324
+ auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T&
1325
+ { return cc[a+ido*(b+cdim*c)]; };
1326
+ auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T&
1327
+ { return cc[a+ido*(b+l1*c)]; };
1328
+ auto CX2 = [cc, idl1](size_t a, size_t b) -> T&
1329
+ { return cc[a+idl1*b]; };
1330
+ auto CH2 = [ch, idl1](size_t a, size_t b) -> const T&
1331
+ { return ch[a+idl1*b]; };
1332
+
1333
+ arr<cmplx<T0>> wal(ip);
1334
+ wal[0] = cmplx<T0>(1., 0.);
1335
+ for (size_t i=1; i<ip; ++i)
1336
+ wal[i]=cmplx<T0>(csarr[i].r,fwd ? -csarr[i].i : csarr[i].i);
1337
+
1338
+ for (size_t k=0; k<l1; ++k)
1339
+ for (size_t i=0; i<ido; ++i)
1340
+ CH(i,k,0) = CC(i,0,k);
1341
+ for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)
1342
+ for (size_t k=0; k<l1; ++k)
1343
+ for (size_t i=0; i<ido; ++i)
1344
+ PM(CH(i,k,j),CH(i,k,jc),CC(i,j,k),CC(i,jc,k));
1345
+ for (size_t k=0; k<l1; ++k)
1346
+ for (size_t i=0; i<ido; ++i)
1347
+ {
1348
+ T tmp = CH(i,k,0);
1349
+ for (size_t j=1; j<ipph; ++j)
1350
+ tmp+=CH(i,k,j);
1351
+ CX(i,k,0) = tmp;
1352
+ }
1353
+ for (size_t l=1, lc=ip-1; l<ipph; ++l, --lc)
1354
+ {
1355
+ // j=0
1356
+ for (size_t ik=0; ik<idl1; ++ik)
1357
+ {
1358
+ CX2(ik,l).r = CH2(ik,0).r+wal[l].r*CH2(ik,1).r+wal[2*l].r*CH2(ik,2).r;
1359
+ CX2(ik,l).i = CH2(ik,0).i+wal[l].r*CH2(ik,1).i+wal[2*l].r*CH2(ik,2).i;
1360
+ CX2(ik,lc).r=-wal[l].i*CH2(ik,ip-1).i-wal[2*l].i*CH2(ik,ip-2).i;
1361
+ CX2(ik,lc).i=wal[l].i*CH2(ik,ip-1).r+wal[2*l].i*CH2(ik,ip-2).r;
1362
+ }
1363
+
1364
+ size_t iwal=2*l;
1365
+ size_t j=3, jc=ip-3;
1366
+ for (; j<ipph-1; j+=2, jc-=2)
1367
+ {
1368
+ iwal+=l; if (iwal>ip) iwal-=ip;
1369
+ cmplx<T0> xwal=wal[iwal];
1370
+ iwal+=l; if (iwal>ip) iwal-=ip;
1371
+ cmplx<T0> xwal2=wal[iwal];
1372
+ for (size_t ik=0; ik<idl1; ++ik)
1373
+ {
1374
+ CX2(ik,l).r += CH2(ik,j).r*xwal.r+CH2(ik,j+1).r*xwal2.r;
1375
+ CX2(ik,l).i += CH2(ik,j).i*xwal.r+CH2(ik,j+1).i*xwal2.r;
1376
+ CX2(ik,lc).r -= CH2(ik,jc).i*xwal.i+CH2(ik,jc-1).i*xwal2.i;
1377
+ CX2(ik,lc).i += CH2(ik,jc).r*xwal.i+CH2(ik,jc-1).r*xwal2.i;
1378
+ }
1379
+ }
1380
+ for (; j<ipph; ++j, --jc)
1381
+ {
1382
+ iwal+=l; if (iwal>ip) iwal-=ip;
1383
+ cmplx<T0> xwal=wal[iwal];
1384
+ for (size_t ik=0; ik<idl1; ++ik)
1385
+ {
1386
+ CX2(ik,l).r += CH2(ik,j).r*xwal.r;
1387
+ CX2(ik,l).i += CH2(ik,j).i*xwal.r;
1388
+ CX2(ik,lc).r -= CH2(ik,jc).i*xwal.i;
1389
+ CX2(ik,lc).i += CH2(ik,jc).r*xwal.i;
1390
+ }
1391
+ }
1392
+ }
1393
+
1394
+ // shuffling and twiddling
1395
+ if (ido==1)
1396
+ for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc)
1397
+ for (size_t ik=0; ik<idl1; ++ik)
1398
+ {
1399
+ T t1=CX2(ik,j), t2=CX2(ik,jc);
1400
+ PM(CX2(ik,j),CX2(ik,jc),t1,t2);
1401
+ }
1402
+ else
1403
+ {
1404
+ for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc)
1405
+ for (size_t k=0; k<l1; ++k)
1406
+ {
1407
+ T t1=CX(0,k,j), t2=CX(0,k,jc);
1408
+ PM(CX(0,k,j),CX(0,k,jc),t1,t2);
1409
+ for (size_t i=1; i<ido; ++i)
1410
+ {
1411
+ T x1, x2;
1412
+ PM(x1,x2,CX(i,k,j),CX(i,k,jc));
1413
+ size_t idij=(j-1)*(ido-1)+i-1;
1414
+ special_mul<fwd>(x1,wa[idij],CX(i,k,j));
1415
+ idij=(jc-1)*(ido-1)+i-1;
1416
+ special_mul<fwd>(x2,wa[idij],CX(i,k,jc));
1417
+ }
1418
+ }
1419
+ }
1420
+ }
1421
+
1422
+ template<bool fwd, typename T> void pass_all(T c[], T0 fct) const
1423
+ {
1424
+ if (length==1) { c[0]*=fct; return; }
1425
+ size_t l1=1;
1426
+ arr<T> ch(length);
1427
+ T *p1=c, *p2=ch.data();
1428
+
1429
+ for(size_t k1=0; k1<fact.size(); k1++)
1430
+ {
1431
+ size_t ip=fact[k1].fct;
1432
+ size_t l2=ip*l1;
1433
+ size_t ido = length/l2;
1434
+ if (ip==4)
1435
+ pass4<fwd> (ido, l1, p1, p2, fact[k1].tw);
1436
+ else if(ip==8)
1437
+ pass8<fwd>(ido, l1, p1, p2, fact[k1].tw);
1438
+ else if(ip==2)
1439
+ pass2<fwd>(ido, l1, p1, p2, fact[k1].tw);
1440
+ else if(ip==3)
1441
+ pass3<fwd> (ido, l1, p1, p2, fact[k1].tw);
1442
+ else if(ip==5)
1443
+ pass5<fwd> (ido, l1, p1, p2, fact[k1].tw);
1444
+ else if(ip==7)
1445
+ pass7<fwd> (ido, l1, p1, p2, fact[k1].tw);
1446
+ else if(ip==11)
1447
+ pass11<fwd> (ido, l1, p1, p2, fact[k1].tw);
1448
+ else
1449
+ {
1450
+ passg<fwd>(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws);
1451
+ std::swap(p1,p2);
1452
+ }
1453
+ std::swap(p1,p2);
1454
+ l1=l2;
1455
+ }
1456
+ if (p1!=c)
1457
+ {
1458
+ if (fct!=1.)
1459
+ for (size_t i=0; i<length; ++i)
1460
+ c[i] = ch[i]*fct;
1461
+ else
1462
+ std::copy_n (p1, length, c);
1463
+ }
1464
+ else
1465
+ if (fct!=1.)
1466
+ for (size_t i=0; i<length; ++i)
1467
+ c[i] *= fct;
1468
+ }
1469
+
1470
+ public:
1471
+ template<typename T> void exec(T c[], T0 fct, bool fwd) const
1472
+ { fwd ? pass_all<true>(c, fct) : pass_all<false>(c, fct); }
1473
+
1474
+ private:
1475
+ POCKETFFT_NOINLINE void factorize()
1476
+ {
1477
+ size_t len=length;
1478
+ while ((len&7)==0)
1479
+ { add_factor(8); len>>=3; }
1480
+ while ((len&3)==0)
1481
+ { add_factor(4); len>>=2; }
1482
+ if ((len&1)==0)
1483
+ {
1484
+ len>>=1;
1485
+ // factor 2 should be at the front of the factor list
1486
+ add_factor(2);
1487
+ std::swap(fact[0].fct, fact.back().fct);
1488
+ }
1489
+ for (size_t divisor=3; divisor*divisor<=len; divisor+=2)
1490
+ while ((len%divisor)==0)
1491
+ {
1492
+ add_factor(divisor);
1493
+ len/=divisor;
1494
+ }
1495
+ if (len>1) add_factor(len);
1496
+ }
1497
+
1498
+ size_t twsize() const
1499
+ {
1500
+ size_t twsize=0, l1=1;
1501
+ for (size_t k=0; k<fact.size(); ++k)
1502
+ {
1503
+ size_t ip=fact[k].fct, ido= length/(l1*ip);
1504
+ twsize+=(ip-1)*(ido-1);
1505
+ if (ip>11)
1506
+ twsize+=ip;
1507
+ l1*=ip;
1508
+ }
1509
+ return twsize;
1510
+ }
1511
+
1512
+ void comp_twiddle()
1513
+ {
1514
+ sincos_2pibyn<T0> twiddle(length);
1515
+ size_t l1=1;
1516
+ size_t memofs=0;
1517
+ for (size_t k=0; k<fact.size(); ++k)
1518
+ {
1519
+ size_t ip=fact[k].fct, ido=length/(l1*ip);
1520
+ fact[k].tw=mem.data()+memofs;
1521
+ memofs+=(ip-1)*(ido-1);
1522
+ for (size_t j=1; j<ip; ++j)
1523
+ for (size_t i=1; i<ido; ++i)
1524
+ fact[k].tw[(j-1)*(ido-1)+i-1] = twiddle[j*l1*i];
1525
+ if (ip>11)
1526
+ {
1527
+ fact[k].tws=mem.data()+memofs;
1528
+ memofs+=ip;
1529
+ for (size_t j=0; j<ip; ++j)
1530
+ fact[k].tws[j] = twiddle[j*l1*ido];
1531
+ }
1532
+ l1*=ip;
1533
+ }
1534
+ }
1535
+
1536
+ public:
1537
+ POCKETFFT_NOINLINE cfftp(size_t length_)
1538
+ : length(length_)
1539
+ {
1540
+ if (length==0) throw std::runtime_error("zero-length FFT requested");
1541
+ if (length==1) return;
1542
+ factorize();
1543
+ mem.resize(twsize());
1544
+ comp_twiddle();
1545
+ }
1546
+ };
1547
+
1548
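cfftp::factorize() above peels off radix-8 and radix-4 factors first, keeps at most one remaining factor of 2 and swaps it to the front of the factor list, and finishes with odd trial division; pass_all() then dispatches factors 2, 3, 4, 5, 7, 8 and 11 to the specialised butterflies and everything else to passg. The following standalone sketch reproduces just that peeling order (hypothetical names, plain std::vector, no twiddle bookkeeping; an illustration, not part of this header):

#include <cstdio>
#include <cstddef>
#include <utility>
#include <vector>

std::vector<std::size_t> factorize_sketch(std::size_t len)
  {
  std::vector<std::size_t> fct;
  while ((len&7)==0) { fct.push_back(8); len>>=3; }   // radix-8 passes first
  while ((len&3)==0) { fct.push_back(4); len>>=2; }   // then radix-4
  if ((len&1)==0)                                     // at most one factor of 2 can remain
    {
    len>>=1;
    fct.push_back(2);
    std::swap(fct.front(), fct.back());               // keep the 2 at the front
    }
  for (std::size_t d=3; d*d<=len; d+=2)               // odd trial division
    while ((len%d)==0) { fct.push_back(d); len/=d; }
  if (len>1) fct.push_back(len);                      // leftover prime
  return fct;
  }

int main()
  {
  for (std::size_t n : {std::size_t(60), std::size_t(1024), std::size_t(350)})
    {
    std::printf("%zu:", n);
    for (auto f : factorize_sketch(n)) std::printf(" %zu", f);
    std::printf("\n");   // 60: 4 3 5   1024: 2 8 8 8   350: 2 5 5 7
    }
  return 0;
  }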
+ //
1549
+ // real-valued FFTPACK transforms
1550
+ //
1551
+
1552
+ template<typename T0> class rfftp
1553
+ {
1554
+ private:
1555
+ struct fctdata
1556
+ {
1557
+ size_t fct;
1558
+ T0 *tw, *tws;
1559
+ };
1560
+
1561
+ size_t length;
1562
+ arr<T0> mem;
1563
+ std::vector<fctdata> fact;
1564
+
1565
+ void add_factor(size_t factor)
1566
+ { fact.push_back({factor, nullptr, nullptr}); }
1567
+
1568
+ /* (a+ib) = conj(c+id) * (e+if) */
1569
+ template<typename T1, typename T2, typename T3> inline void MULPM
1570
+ (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const
1571
+ { a=c*e+d*f; b=c*f-d*e; }
1572
+
1573
+ template<typename T> void radf2 (size_t ido, size_t l1,
1574
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1575
+ const T0 * POCKETFFT_RESTRICT wa) const
1576
+ {
1577
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
1578
+ auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&
1579
+ { return cc[a+ido*(b+l1*c)]; };
1580
+ auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&
1581
+ { return ch[a+ido*(b+2*c)]; };
1582
+
1583
+ for (size_t k=0; k<l1; k++)
1584
+ PM (CH(0,0,k),CH(ido-1,1,k),CC(0,k,0),CC(0,k,1));
1585
+ if ((ido&1)==0)
1586
+ for (size_t k=0; k<l1; k++)
1587
+ {
1588
+ CH( 0,1,k) = -CC(ido-1,k,1);
1589
+ CH(ido-1,0,k) = CC(ido-1,k,0);
1590
+ }
1591
+ if (ido<=2) return;
1592
+ for (size_t k=0; k<l1; k++)
1593
+ for (size_t i=2; i<ido; i+=2)
1594
+ {
1595
+ size_t ic=ido-i;
1596
+ T tr2, ti2;
1597
+ MULPM (tr2,ti2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1));
1598
+ PM (CH(i-1,0,k),CH(ic-1,1,k),CC(i-1,k,0),tr2);
1599
+ PM (CH(i ,0,k),CH(ic ,1,k),ti2,CC(i ,k,0));
1600
+ }
1601
+ }
1602
+
1603
+ // a2=a+b; b2=i*(b-a);
1604
+ #define POCKETFFT_REARRANGE(rx, ix, ry, iy) \
1605
+ {\
1606
+ auto t1=rx+ry, t2=ry-rx, t3=ix+iy, t4=ix-iy; \
1607
+ rx=t1; ix=t3; ry=t4; iy=t2; \
1608
+ }
1609
+
1610
+ template<typename T> void radf3(size_t ido, size_t l1,
1611
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1612
+ const T0 * POCKETFFT_RESTRICT wa) const
1613
+ {
1614
+ constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L);
1615
+
1616
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
1617
+ auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&
1618
+ { return cc[a+ido*(b+l1*c)]; };
1619
+ auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&
1620
+ { return ch[a+ido*(b+3*c)]; };
1621
+
1622
+ for (size_t k=0; k<l1; k++)
1623
+ {
1624
+ T cr2=CC(0,k,1)+CC(0,k,2);
1625
+ CH(0,0,k) = CC(0,k,0)+cr2;
1626
+ CH(0,2,k) = taui*(CC(0,k,2)-CC(0,k,1));
1627
+ CH(ido-1,1,k) = CC(0,k,0)+taur*cr2;
1628
+ }
1629
+ if (ido==1) return;
1630
+ for (size_t k=0; k<l1; k++)
1631
+ for (size_t i=2; i<ido; i+=2)
1632
+ {
1633
+ size_t ic=ido-i;
1634
+ T di2, di3, dr2, dr3;
1635
+ MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1)); // d2=conj(WA0)*CC1
1636
+ MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2)); // d3=conj(WA1)*CC2
1637
+ POCKETFFT_REARRANGE(dr2, di2, dr3, di3);
1638
+ CH(i-1,0,k) = CC(i-1,k,0)+dr2; // c add
1639
+ CH(i ,0,k) = CC(i ,k,0)+di2;
1640
+ T tr2 = CC(i-1,k,0)+taur*dr2; // c add
1641
+ T ti2 = CC(i ,k,0)+taur*di2;
1642
+ T tr3 = taui*dr3; // t3 = taui*i*(d3-d2)?
1643
+ T ti3 = taui*di3;
1644
+ PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr3); // PM(i) = t2+t3
1645
+ PM(CH(i ,2,k),CH(ic ,1,k),ti3,ti2); // PM(ic) = conj(t2-t3)
1646
+ }
1647
+ }
1648
+
1649
+ template<typename T> void radf4(size_t ido, size_t l1,
1650
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1651
+ const T0 * POCKETFFT_RESTRICT wa) const
1652
+ {
1653
+ constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L);
1654
+
1655
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
1656
+ auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&
1657
+ { return cc[a+ido*(b+l1*c)]; };
1658
+ auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&
1659
+ { return ch[a+ido*(b+4*c)]; };
1660
+
1661
+ for (size_t k=0; k<l1; k++)
1662
+ {
1663
+ T tr1,tr2;
1664
+ PM (tr1,CH(0,2,k),CC(0,k,3),CC(0,k,1));
1665
+ PM (tr2,CH(ido-1,1,k),CC(0,k,0),CC(0,k,2));
1666
+ PM (CH(0,0,k),CH(ido-1,3,k),tr2,tr1);
1667
+ }
1668
+ if ((ido&1)==0)
1669
+ for (size_t k=0; k<l1; k++)
1670
+ {
1671
+ T ti1=-hsqt2*(CC(ido-1,k,1)+CC(ido-1,k,3));
1672
+ T tr1= hsqt2*(CC(ido-1,k,1)-CC(ido-1,k,3));
1673
+ PM (CH(ido-1,0,k),CH(ido-1,2,k),CC(ido-1,k,0),tr1);
1674
+ PM (CH( 0,3,k),CH( 0,1,k),ti1,CC(ido-1,k,2));
1675
+ }
1676
+ if (ido<=2) return;
1677
+ for (size_t k=0; k<l1; k++)
1678
+ for (size_t i=2; i<ido; i+=2)
1679
+ {
1680
+ size_t ic=ido-i;
1681
+ T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;
1682
+ MULPM(cr2,ci2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1));
1683
+ MULPM(cr3,ci3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2));
1684
+ MULPM(cr4,ci4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3));
1685
+ PM(tr1,tr4,cr4,cr2);
1686
+ PM(ti1,ti4,ci2,ci4);
1687
+ PM(tr2,tr3,CC(i-1,k,0),cr3);
1688
+ PM(ti2,ti3,CC(i ,k,0),ci3);
1689
+ PM(CH(i-1,0,k),CH(ic-1,3,k),tr2,tr1);
1690
+ PM(CH(i ,0,k),CH(ic ,3,k),ti1,ti2);
1691
+ PM(CH(i-1,2,k),CH(ic-1,1,k),tr3,ti4);
1692
+ PM(CH(i ,2,k),CH(ic ,1,k),tr4,ti3);
1693
+ }
1694
+ }
1695
+
1696
+ template<typename T> void radf5(size_t ido, size_t l1,
1697
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1698
+ const T0 * POCKETFFT_RESTRICT wa) const
1699
+ {
1700
+ constexpr T0 tr11= T0(0.3090169943749474241022934171828191L),
1701
+ ti11= T0(0.9510565162951535721164393333793821L),
1702
+ tr12= T0(-0.8090169943749474241022934171828191L),
1703
+ ti12= T0(0.5877852522924731291687059546390728L);
1704
+
1705
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
1706
+ auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&
1707
+ { return cc[a+ido*(b+l1*c)]; };
1708
+ auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T&
1709
+ { return ch[a+ido*(b+5*c)]; };
1710
+
1711
+ for (size_t k=0; k<l1; k++)
1712
+ {
1713
+ T cr2, cr3, ci4, ci5;
1714
+ PM (cr2,ci5,CC(0,k,4),CC(0,k,1));
1715
+ PM (cr3,ci4,CC(0,k,3),CC(0,k,2));
1716
+ CH(0,0,k)=CC(0,k,0)+cr2+cr3;
1717
+ CH(ido-1,1,k)=CC(0,k,0)+tr11*cr2+tr12*cr3;
1718
+ CH(0,2,k)=ti11*ci5+ti12*ci4;
1719
+ CH(ido-1,3,k)=CC(0,k,0)+tr12*cr2+tr11*cr3;
1720
+ CH(0,4,k)=ti12*ci5-ti11*ci4;
1721
+ }
1722
+ if (ido==1) return;
1723
+ for (size_t k=0; k<l1;++k)
1724
+ for (size_t i=2, ic=ido-2; i<ido; i+=2, ic-=2)
1725
+ {
1726
+ T di2, di3, di4, di5, dr2, dr3, dr4, dr5;
1727
+ MULPM (dr2,di2,WA(0,i-2),WA(0,i-1),CC(i-1,k,1),CC(i,k,1));
1728
+ MULPM (dr3,di3,WA(1,i-2),WA(1,i-1),CC(i-1,k,2),CC(i,k,2));
1729
+ MULPM (dr4,di4,WA(2,i-2),WA(2,i-1),CC(i-1,k,3),CC(i,k,3));
1730
+ MULPM (dr5,di5,WA(3,i-2),WA(3,i-1),CC(i-1,k,4),CC(i,k,4));
1731
+ POCKETFFT_REARRANGE(dr2, di2, dr5, di5);
1732
+ POCKETFFT_REARRANGE(dr3, di3, dr4, di4);
1733
+ CH(i-1,0,k)=CC(i-1,k,0)+dr2+dr3;
1734
+ CH(i ,0,k)=CC(i ,k,0)+di2+di3;
1735
+ T tr2=CC(i-1,k,0)+tr11*dr2+tr12*dr3;
1736
+ T ti2=CC(i ,k,0)+tr11*di2+tr12*di3;
1737
+ T tr3=CC(i-1,k,0)+tr12*dr2+tr11*dr3;
1738
+ T ti3=CC(i ,k,0)+tr12*di2+tr11*di3;
1739
+ T tr5 = ti11*dr5 + ti12*dr4;
1740
+ T ti5 = ti11*di5 + ti12*di4;
1741
+ T tr4 = ti12*dr5 - ti11*dr4;
1742
+ T ti4 = ti12*di5 - ti11*di4;
1743
+ PM(CH(i-1,2,k),CH(ic-1,1,k),tr2,tr5);
1744
+ PM(CH(i ,2,k),CH(ic ,1,k),ti5,ti2);
1745
+ PM(CH(i-1,4,k),CH(ic-1,3,k),tr3,tr4);
1746
+ PM(CH(i ,4,k),CH(ic ,3,k),ti4,ti3);
1747
+ }
1748
+ }
1749
+
1750
+ #undef POCKETFFT_REARRANGE
1751
+
1752
+ template<typename T> void radfg(size_t ido, size_t ip, size_t l1,
1753
+ T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1754
+ const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const
1755
+ {
1756
+ const size_t cdim=ip;
1757
+ size_t ipph=(ip+1)/2;
1758
+ size_t idl1 = ido*l1;
1759
+
1760
+ auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T&
1761
+ { return cc[a+ido*(b+cdim*c)]; };
1762
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T&
1763
+ { return ch[a+ido*(b+l1*c)]; };
1764
+ auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T&
1765
+ { return cc[a+ido*(b+l1*c)]; };
1766
+ auto C2 = [cc,idl1] (size_t a, size_t b) -> T&
1767
+ { return cc[a+idl1*b]; };
1768
+ auto CH2 = [ch,idl1] (size_t a, size_t b) -> T&
1769
+ { return ch[a+idl1*b]; };
1770
+
1771
+ if (ido>1)
1772
+ {
1773
+ for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 114
1774
+ {
1775
+ size_t is=(j-1)*(ido-1),
1776
+ is2=(jc-1)*(ido-1);
1777
+ for (size_t k=0; k<l1; ++k) // 113
1778
+ {
1779
+ size_t idij=is;
1780
+ size_t idij2=is2;
1781
+ for (size_t i=1; i<=ido-2; i+=2) // 112
1782
+ {
1783
+ T t1=C1(i,k,j ), t2=C1(i+1,k,j ),
1784
+ t3=C1(i,k,jc), t4=C1(i+1,k,jc);
1785
+ T x1=wa[idij]*t1 + wa[idij+1]*t2,
1786
+ x2=wa[idij]*t2 - wa[idij+1]*t1,
1787
+ x3=wa[idij2]*t3 + wa[idij2+1]*t4,
1788
+ x4=wa[idij2]*t4 - wa[idij2+1]*t3;
1789
+ PM(C1(i,k,j),C1(i+1,k,jc),x3,x1);
1790
+ PM(C1(i+1,k,j),C1(i,k,jc),x2,x4);
1791
+ idij+=2;
1792
+ idij2+=2;
1793
+ }
1794
+ }
1795
+ }
1796
+ }
1797
+
1798
+ for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 123
1799
+ for (size_t k=0; k<l1; ++k) // 122
1800
+ MPINPLACE(C1(0,k,jc), C1(0,k,j));
1801
+
1802
+ //everything in C
1803
+ //memset(ch,0,ip*l1*ido*sizeof(double));
1804
+
1805
+ for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc) // 127
1806
+ {
1807
+ for (size_t ik=0; ik<idl1; ++ik) // 124
1808
+ {
1809
+ CH2(ik,l ) = C2(ik,0)+csarr[2*l]*C2(ik,1)+csarr[4*l]*C2(ik,2);
1810
+ CH2(ik,lc) = csarr[2*l+1]*C2(ik,ip-1)+csarr[4*l+1]*C2(ik,ip-2);
1811
+ }
1812
+ size_t iang = 2*l;
1813
+ size_t j=3, jc=ip-3;
1814
+ for (; j<ipph-3; j+=4,jc-=4) // 126
1815
+ {
1816
+ iang+=l; if (iang>=ip) iang-=ip;
1817
+ T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
1818
+ iang+=l; if (iang>=ip) iang-=ip;
1819
+ T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
1820
+ iang+=l; if (iang>=ip) iang-=ip;
1821
+ T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1];
1822
+ iang+=l; if (iang>=ip) iang-=ip;
1823
+ T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1];
1824
+ for (size_t ik=0; ik<idl1; ++ik) // 125
1825
+ {
1826
+ CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1)
1827
+ +ar3*C2(ik,j +2)+ar4*C2(ik,j +3);
1828
+ CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1)
1829
+ +ai3*C2(ik,jc-2)+ai4*C2(ik,jc-3);
1830
+ }
1831
+ }
1832
+ for (; j<ipph-1; j+=2,jc-=2) // 126
1833
+ {
1834
+ iang+=l; if (iang>=ip) iang-=ip;
1835
+ T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
1836
+ iang+=l; if (iang>=ip) iang-=ip;
1837
+ T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
1838
+ for (size_t ik=0; ik<idl1; ++ik) // 125
1839
+ {
1840
+ CH2(ik,l ) += ar1*C2(ik,j )+ar2*C2(ik,j +1);
1841
+ CH2(ik,lc) += ai1*C2(ik,jc)+ai2*C2(ik,jc-1);
1842
+ }
1843
+ }
1844
+ for (; j<ipph; ++j,--jc) // 126
1845
+ {
1846
+ iang+=l; if (iang>=ip) iang-=ip;
1847
+ T0 ar=csarr[2*iang], ai=csarr[2*iang+1];
1848
+ for (size_t ik=0; ik<idl1; ++ik) // 125
1849
+ {
1850
+ CH2(ik,l ) += ar*C2(ik,j );
1851
+ CH2(ik,lc) += ai*C2(ik,jc);
1852
+ }
1853
+ }
1854
+ }
1855
+ for (size_t ik=0; ik<idl1; ++ik) // 101
1856
+ CH2(ik,0) = C2(ik,0);
1857
+ for (size_t j=1; j<ipph; ++j) // 129
1858
+ for (size_t ik=0; ik<idl1; ++ik) // 128
1859
+ CH2(ik,0) += C2(ik,j);
1860
+
1861
+ // everything in CH at this point!
1862
+ //memset(cc,0,ip*l1*ido*sizeof(double));
1863
+
1864
+ for (size_t k=0; k<l1; ++k) // 131
1865
+ for (size_t i=0; i<ido; ++i) // 130
1866
+ CC(i,0,k) = CH(i,k,0);
1867
+
1868
+ for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 137
1869
+ {
1870
+ size_t j2=2*j-1;
1871
+ for (size_t k=0; k<l1; ++k) // 136
1872
+ {
1873
+ CC(ido-1,j2,k) = CH(0,k,j);
1874
+ CC(0,j2+1,k) = CH(0,k,jc);
1875
+ }
1876
+ }
1877
+
1878
+ if (ido==1) return;
1879
+
1880
+ for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 140
1881
+ {
1882
+ size_t j2=2*j-1;
1883
+ for(size_t k=0; k<l1; ++k) // 139
1884
+ for(size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2) // 138
1885
+ {
1886
+ CC(i ,j2+1,k) = CH(i ,k,j )+CH(i ,k,jc);
1887
+ CC(ic ,j2 ,k) = CH(i ,k,j )-CH(i ,k,jc);
1888
+ CC(i+1 ,j2+1,k) = CH(i+1,k,j )+CH(i+1,k,jc);
1889
+ CC(ic+1,j2 ,k) = CH(i+1,k,jc)-CH(i+1,k,j );
1890
+ }
1891
+ }
1892
+ }
1893
+
1894
+ template<typename T> void radb2(size_t ido, size_t l1,
1895
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1896
+ const T0 * POCKETFFT_RESTRICT wa) const
1897
+ {
1898
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
1899
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
1900
+ { return cc[a+ido*(b+2*c)]; };
1901
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1902
+ { return ch[a+ido*(b+l1*c)]; };
1903
+
1904
+ for (size_t k=0; k<l1; k++)
1905
+ PM (CH(0,k,0),CH(0,k,1),CC(0,0,k),CC(ido-1,1,k));
1906
+ if ((ido&1)==0)
1907
+ for (size_t k=0; k<l1; k++)
1908
+ {
1909
+ CH(ido-1,k,0) = 2*CC(ido-1,0,k);
1910
+ CH(ido-1,k,1) =-2*CC(0 ,1,k);
1911
+ }
1912
+ if (ido<=2) return;
1913
+ for (size_t k=0; k<l1;++k)
1914
+ for (size_t i=2; i<ido; i+=2)
1915
+ {
1916
+ size_t ic=ido-i;
1917
+ T ti2, tr2;
1918
+ PM (CH(i-1,k,0),tr2,CC(i-1,0,k),CC(ic-1,1,k));
1919
+ PM (ti2,CH(i ,k,0),CC(i ,0,k),CC(ic ,1,k));
1920
+ MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ti2,tr2);
1921
+ }
1922
+ }
1923
+
1924
+ template<typename T> void radb3(size_t ido, size_t l1,
1925
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1926
+ const T0 * POCKETFFT_RESTRICT wa) const
1927
+ {
1928
+ constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L);
1929
+
1930
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
1931
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
1932
+ { return cc[a+ido*(b+3*c)]; };
1933
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1934
+ { return ch[a+ido*(b+l1*c)]; };
1935
+
1936
+ for (size_t k=0; k<l1; k++)
1937
+ {
1938
+ T tr2=2*CC(ido-1,1,k);
1939
+ T cr2=CC(0,0,k)+taur*tr2;
1940
+ CH(0,k,0)=CC(0,0,k)+tr2;
1941
+ T ci3=2*taui*CC(0,2,k);
1942
+ PM (CH(0,k,2),CH(0,k,1),cr2,ci3);
1943
+ }
1944
+ if (ido==1) return;
1945
+ for (size_t k=0; k<l1; k++)
1946
+ for (size_t i=2, ic=ido-2; i<ido; i+=2, ic-=2)
1947
+ {
1948
+ T tr2=CC(i-1,2,k)+CC(ic-1,1,k); // t2=CC(I) + conj(CC(ic))
1949
+ T ti2=CC(i ,2,k)-CC(ic ,1,k);
1950
+ T cr2=CC(i-1,0,k)+taur*tr2; // c2=CC +taur*t2
1951
+ T ci2=CC(i ,0,k)+taur*ti2;
1952
+ CH(i-1,k,0)=CC(i-1,0,k)+tr2; // CH=CC+t2
1953
+ CH(i ,k,0)=CC(i ,0,k)+ti2;
1954
+ T cr3=taui*(CC(i-1,2,k)-CC(ic-1,1,k));// c3=taui*(CC(i)-conj(CC(ic)))
1955
+ T ci3=taui*(CC(i ,2,k)+CC(ic ,1,k));
1956
+ T di2, di3, dr2, dr3;
1957
+ PM(dr3,dr2,cr2,ci3); // d2= (cr2-ci3, ci2+cr3) = c2+i*c3
1958
+ PM(di2,di3,ci2,cr3); // d3= (cr2+ci3, ci2-cr3) = c2-i*c3
1959
+ MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2); // ch = WA*d2
1960
+ MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3);
1961
+ }
1962
+ }
1963
+
1964
+ template<typename T> void radb4(size_t ido, size_t l1,
1965
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
1966
+ const T0 * POCKETFFT_RESTRICT wa) const
1967
+ {
1968
+ constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
1969
+
1970
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
1971
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
1972
+ { return cc[a+ido*(b+4*c)]; };
1973
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
1974
+ { return ch[a+ido*(b+l1*c)]; };
1975
+
1976
+ for (size_t k=0; k<l1; k++)
1977
+ {
1978
+ T tr1, tr2;
1979
+ PM (tr2,tr1,CC(0,0,k),CC(ido-1,3,k));
1980
+ T tr3=2*CC(ido-1,1,k);
1981
+ T tr4=2*CC(0,2,k);
1982
+ PM (CH(0,k,0),CH(0,k,2),tr2,tr3);
1983
+ PM (CH(0,k,3),CH(0,k,1),tr1,tr4);
1984
+ }
1985
+ if ((ido&1)==0)
1986
+ for (size_t k=0; k<l1; k++)
1987
+ {
1988
+ T tr1,tr2,ti1,ti2;
1989
+ PM (ti1,ti2,CC(0 ,3,k),CC(0 ,1,k));
1990
+ PM (tr2,tr1,CC(ido-1,0,k),CC(ido-1,2,k));
1991
+ CH(ido-1,k,0)=tr2+tr2;
1992
+ CH(ido-1,k,1)=sqrt2*(tr1-ti1);
1993
+ CH(ido-1,k,2)=ti2+ti2;
1994
+ CH(ido-1,k,3)=-sqrt2*(tr1+ti1);
1995
+ }
1996
+ if (ido<=2) return;
1997
+ for (size_t k=0; k<l1;++k)
1998
+ for (size_t i=2; i<ido; i+=2)
1999
+ {
2000
+ T ci2, ci3, ci4, cr2, cr3, cr4, ti1, ti2, ti3, ti4, tr1, tr2, tr3, tr4;
2001
+ size_t ic=ido-i;
2002
+ PM (tr2,tr1,CC(i-1,0,k),CC(ic-1,3,k));
2003
+ PM (ti1,ti2,CC(i ,0,k),CC(ic ,3,k));
2004
+ PM (tr4,ti3,CC(i ,2,k),CC(ic ,1,k));
2005
+ PM (tr3,ti4,CC(i-1,2,k),CC(ic-1,1,k));
2006
+ PM (CH(i-1,k,0),cr3,tr2,tr3);
2007
+ PM (CH(i ,k,0),ci3,ti2,ti3);
2008
+ PM (cr4,cr2,tr1,tr4);
2009
+ PM (ci2,ci4,ti1,ti4);
2010
+ MULPM (CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),ci2,cr2);
2011
+ MULPM (CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),ci3,cr3);
2012
+ MULPM (CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),ci4,cr4);
2013
+ }
2014
+ }
2015
+
2016
+ template<typename T> void radb5(size_t ido, size_t l1,
2017
+ const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
2018
+ const T0 * POCKETFFT_RESTRICT wa) const
2019
+ {
2020
+ constexpr T0 tr11= T0(0.3090169943749474241022934171828191L),
2021
+ ti11= T0(0.9510565162951535721164393333793821L),
2022
+ tr12= T0(-0.8090169943749474241022934171828191L),
2023
+ ti12= T0(0.5877852522924731291687059546390728L);
2024
+
2025
+ auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; };
2026
+ auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T&
2027
+ { return cc[a+ido*(b+5*c)]; };
2028
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
2029
+ { return ch[a+ido*(b+l1*c)]; };
2030
+
2031
+ for (size_t k=0; k<l1; k++)
2032
+ {
2033
+ T ti5=CC(0,2,k)+CC(0,2,k);
2034
+ T ti4=CC(0,4,k)+CC(0,4,k);
2035
+ T tr2=CC(ido-1,1,k)+CC(ido-1,1,k);
2036
+ T tr3=CC(ido-1,3,k)+CC(ido-1,3,k);
2037
+ CH(0,k,0)=CC(0,0,k)+tr2+tr3;
2038
+ T cr2=CC(0,0,k)+tr11*tr2+tr12*tr3;
2039
+ T cr3=CC(0,0,k)+tr12*tr2+tr11*tr3;
2040
+ T ci4, ci5;
2041
+ MULPM(ci5,ci4,ti5,ti4,ti11,ti12);
2042
+ PM(CH(0,k,4),CH(0,k,1),cr2,ci5);
2043
+ PM(CH(0,k,3),CH(0,k,2),cr3,ci4);
2044
+ }
2045
+ if (ido==1) return;
2046
+ for (size_t k=0; k<l1;++k)
2047
+ for (size_t i=2, ic=ido-2; i<ido; i+=2, ic-=2)
2048
+ {
2049
+ T tr2, tr3, tr4, tr5, ti2, ti3, ti4, ti5;
2050
+ PM(tr2,tr5,CC(i-1,2,k),CC(ic-1,1,k));
2051
+ PM(ti5,ti2,CC(i ,2,k),CC(ic ,1,k));
2052
+ PM(tr3,tr4,CC(i-1,4,k),CC(ic-1,3,k));
2053
+ PM(ti4,ti3,CC(i ,4,k),CC(ic ,3,k));
2054
+ CH(i-1,k,0)=CC(i-1,0,k)+tr2+tr3;
2055
+ CH(i ,k,0)=CC(i ,0,k)+ti2+ti3;
2056
+ T cr2=CC(i-1,0,k)+tr11*tr2+tr12*tr3;
2057
+ T ci2=CC(i ,0,k)+tr11*ti2+tr12*ti3;
2058
+ T cr3=CC(i-1,0,k)+tr12*tr2+tr11*tr3;
2059
+ T ci3=CC(i ,0,k)+tr12*ti2+tr11*ti3;
2060
+ T ci4, ci5, cr5, cr4;
2061
+ MULPM(cr5,cr4,tr5,tr4,ti11,ti12);
2062
+ MULPM(ci5,ci4,ti5,ti4,ti11,ti12);
2063
+ T dr2, dr3, dr4, dr5, di2, di3, di4, di5;
2064
+ PM(dr4,dr3,cr3,ci4);
2065
+ PM(di3,di4,ci3,cr4);
2066
+ PM(dr5,dr2,cr2,ci5);
2067
+ PM(di2,di5,ci2,cr5);
2068
+ MULPM(CH(i,k,1),CH(i-1,k,1),WA(0,i-2),WA(0,i-1),di2,dr2);
2069
+ MULPM(CH(i,k,2),CH(i-1,k,2),WA(1,i-2),WA(1,i-1),di3,dr3);
2070
+ MULPM(CH(i,k,3),CH(i-1,k,3),WA(2,i-2),WA(2,i-1),di4,dr4);
2071
+ MULPM(CH(i,k,4),CH(i-1,k,4),WA(3,i-2),WA(3,i-1),di5,dr5);
2072
+ }
2073
+ }
2074
+
2075
+ template<typename T> void radbg(size_t ido, size_t ip, size_t l1,
2076
+ T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch,
2077
+ const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const
2078
+ {
2079
+ const size_t cdim=ip;
2080
+ size_t ipph=(ip+1)/ 2;
2081
+ size_t idl1 = ido*l1;
2082
+
2083
+ auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T&
2084
+ { return cc[a+ido*(b+cdim*c)]; };
2085
+ auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T&
2086
+ { return ch[a+ido*(b+l1*c)]; };
2087
+ auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T&
2088
+ { return cc[a+ido*(b+l1*c)]; };
2089
+ auto C2 = [cc,idl1](size_t a, size_t b) -> T&
2090
+ { return cc[a+idl1*b]; };
2091
+ auto CH2 = [ch,idl1](size_t a, size_t b) -> T&
2092
+ { return ch[a+idl1*b]; };
2093
+
2094
+ for (size_t k=0; k<l1; ++k) // 102
2095
+ for (size_t i=0; i<ido; ++i) // 101
2096
+ CH(i,k,0) = CC(i,0,k);
2097
+ for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc) // 108
2098
+ {
2099
+ size_t j2=2*j-1;
2100
+ for (size_t k=0; k<l1; ++k)
2101
+ {
2102
+ CH(0,k,j ) = 2*CC(ido-1,j2,k);
2103
+ CH(0,k,jc) = 2*CC(0,j2+1,k);
2104
+ }
2105
+ }
2106
+
2107
+ if (ido!=1)
2108
+ {
2109
+ for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 111
2110
+ {
2111
+ size_t j2=2*j-1;
2112
+ for (size_t k=0; k<l1; ++k)
2113
+ for (size_t i=1, ic=ido-i-2; i<=ido-2; i+=2, ic-=2) // 109
2114
+ {
2115
+ CH(i ,k,j ) = CC(i ,j2+1,k)+CC(ic ,j2,k);
2116
+ CH(i ,k,jc) = CC(i ,j2+1,k)-CC(ic ,j2,k);
2117
+ CH(i+1,k,j ) = CC(i+1,j2+1,k)-CC(ic+1,j2,k);
2118
+ CH(i+1,k,jc) = CC(i+1,j2+1,k)+CC(ic+1,j2,k);
2119
+ }
2120
+ }
2121
+ }
2122
+ for (size_t l=1,lc=ip-1; l<ipph; ++l,--lc)
2123
+ {
2124
+ for (size_t ik=0; ik<idl1; ++ik)
2125
+ {
2126
+ C2(ik,l ) = CH2(ik,0)+csarr[2*l]*CH2(ik,1)+csarr[4*l]*CH2(ik,2);
2127
+ C2(ik,lc) = csarr[2*l+1]*CH2(ik,ip-1)+csarr[4*l+1]*CH2(ik,ip-2);
2128
+ }
2129
+ size_t iang=2*l;
2130
+ size_t j=3,jc=ip-3;
2131
+ for(; j<ipph-3; j+=4,jc-=4)
2132
+ {
2133
+ iang+=l; if(iang>ip) iang-=ip;
2134
+ T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
2135
+ iang+=l; if(iang>ip) iang-=ip;
2136
+ T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
2137
+ iang+=l; if(iang>ip) iang-=ip;
2138
+ T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1];
2139
+ iang+=l; if(iang>ip) iang-=ip;
2140
+ T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1];
2141
+ for (size_t ik=0; ik<idl1; ++ik)
2142
+ {
2143
+ C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1)
2144
+ +ar3*CH2(ik,j +2)+ar4*CH2(ik,j +3);
2145
+ C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1)
2146
+ +ai3*CH2(ik,jc-2)+ai4*CH2(ik,jc-3);
2147
+ }
2148
+ }
2149
+ for(; j<ipph-1; j+=2,jc-=2)
2150
+ {
2151
+ iang+=l; if(iang>ip) iang-=ip;
2152
+ T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1];
2153
+ iang+=l; if(iang>ip) iang-=ip;
2154
+ T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1];
2155
+ for (size_t ik=0; ik<idl1; ++ik)
2156
+ {
2157
+ C2(ik,l ) += ar1*CH2(ik,j )+ar2*CH2(ik,j +1);
2158
+ C2(ik,lc) += ai1*CH2(ik,jc)+ai2*CH2(ik,jc-1);
2159
+ }
2160
+ }
2161
+ for(; j<ipph; ++j,--jc)
2162
+ {
2163
+ iang+=l; if(iang>ip) iang-=ip;
2164
+ T0 war=csarr[2*iang], wai=csarr[2*iang+1];
2165
+ for (size_t ik=0; ik<idl1; ++ik)
2166
+ {
2167
+ C2(ik,l ) += war*CH2(ik,j );
2168
+ C2(ik,lc) += wai*CH2(ik,jc);
2169
+ }
2170
+ }
2171
+ }
2172
+ for (size_t j=1; j<ipph; ++j)
2173
+ for (size_t ik=0; ik<idl1; ++ik)
2174
+ CH2(ik,0) += CH2(ik,j);
2175
+ for (size_t j=1, jc=ip-1; j<ipph; ++j,--jc) // 124
2176
+ for (size_t k=0; k<l1; ++k)
2177
+ PM(CH(0,k,jc),CH(0,k,j),C1(0,k,j),C1(0,k,jc));
2178
+
2179
+ if (ido==1) return;
2180
+
2181
+ for (size_t j=1, jc=ip-1; j<ipph; ++j, --jc) // 127
2182
+ for (size_t k=0; k<l1; ++k)
2183
+ for (size_t i=1; i<=ido-2; i+=2)
2184
+ {
2185
+ CH(i ,k,j ) = C1(i ,k,j)-C1(i+1,k,jc);
2186
+ CH(i ,k,jc) = C1(i ,k,j)+C1(i+1,k,jc);
2187
+ CH(i+1,k,j ) = C1(i+1,k,j)+C1(i ,k,jc);
2188
+ CH(i+1,k,jc) = C1(i+1,k,j)-C1(i ,k,jc);
2189
+ }
2190
+
2191
+ // All in CH
2192
+
2193
+ for (size_t j=1; j<ip; ++j)
2194
+ {
2195
+ size_t is = (j-1)*(ido-1);
2196
+ for (size_t k=0; k<l1; ++k)
2197
+ {
2198
+ size_t idij = is;
2199
+ for (size_t i=1; i<=ido-2; i+=2)
2200
+ {
2201
+ T t1=CH(i,k,j), t2=CH(i+1,k,j);
2202
+ CH(i ,k,j) = wa[idij]*t1-wa[idij+1]*t2;
2203
+ CH(i+1,k,j) = wa[idij]*t2+wa[idij+1]*t1;
2204
+ idij+=2;
2205
+ }
2206
+ }
2207
+ }
2208
+ }
2209
+
2210
+ template<typename T> void copy_and_norm(T *c, T *p1, T0 fct) const
2211
+ {
2212
+ if (p1!=c)
2213
+ {
2214
+ if (fct!=1.)
2215
+ for (size_t i=0; i<length; ++i)
2216
+ c[i] = fct*p1[i];
2217
+ else
2218
+ std::copy_n (p1, length, c);
2219
+ }
2220
+ else
2221
+ if (fct!=1.)
2222
+ for (size_t i=0; i<length; ++i)
2223
+ c[i] *= fct;
2224
+ }
2225
+
2226
+ public:
2227
+ template<typename T> void exec(T c[], T0 fct, bool r2hc) const
2228
+ {
2229
+ if (length==1) { c[0]*=fct; return; }
2230
+ size_t nf=fact.size();
2231
+ arr<T> ch(length);
2232
+ T *p1=c, *p2=ch.data();
2233
+
2234
+ if (r2hc)
2235
+ for(size_t k1=0, l1=length; k1<nf;++k1)
2236
+ {
2237
+ size_t k=nf-k1-1;
2238
+ size_t ip=fact[k].fct;
2239
+ size_t ido=length / l1;
2240
+ l1 /= ip;
2241
+ if(ip==4)
2242
+ radf4(ido, l1, p1, p2, fact[k].tw);
2243
+ else if(ip==2)
2244
+ radf2(ido, l1, p1, p2, fact[k].tw);
2245
+ else if(ip==3)
2246
+ radf3(ido, l1, p1, p2, fact[k].tw);
2247
+ else if(ip==5)
2248
+ radf5(ido, l1, p1, p2, fact[k].tw);
2249
+ else
2250
+ { radfg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws); std::swap (p1,p2); }
2251
+ std::swap (p1,p2);
2252
+ }
2253
+ else
2254
+ for(size_t k=0, l1=1; k<nf; k++)
2255
+ {
2256
+ size_t ip = fact[k].fct,
2257
+ ido= length/(ip*l1);
2258
+ if(ip==4)
2259
+ radb4(ido, l1, p1, p2, fact[k].tw);
2260
+ else if(ip==2)
2261
+ radb2(ido, l1, p1, p2, fact[k].tw);
2262
+ else if(ip==3)
2263
+ radb3(ido, l1, p1, p2, fact[k].tw);
2264
+ else if(ip==5)
2265
+ radb5(ido, l1, p1, p2, fact[k].tw);
2266
+ else
2267
+ radbg(ido, ip, l1, p1, p2, fact[k].tw, fact[k].tws);
2268
+ std::swap (p1,p2);
2269
+ l1*=ip;
2270
+ }
2271
+
2272
+ copy_and_norm(c,p1,fct);
2273
+ }
2274
+
2275
+ private:
2276
+ void factorize()
2277
+ {
2278
+ size_t len=length;
2279
+ while ((len%4)==0)
2280
+ { add_factor(4); len>>=2; }
2281
+ if ((len%2)==0)
2282
+ {
2283
+ len>>=1;
2284
+ // factor 2 should be at the front of the factor list
2285
+ add_factor(2);
2286
+ std::swap(fact[0].fct, fact.back().fct);
2287
+ }
2288
+ for (size_t divisor=3; divisor*divisor<=len; divisor+=2)
2289
+ while ((len%divisor)==0)
2290
+ {
2291
+ add_factor(divisor);
2292
+ len/=divisor;
2293
+ }
2294
+ if (len>1) add_factor(len);
2295
+ }
2296
+
2297
+ size_t twsize() const
2298
+ {
2299
+ size_t twsz=0, l1=1;
2300
+ for (size_t k=0; k<fact.size(); ++k)
2301
+ {
2302
+ size_t ip=fact[k].fct, ido=length/(l1*ip);
2303
+ twsz+=(ip-1)*(ido-1);
2304
+ if (ip>5) twsz+=2*ip;
2305
+ l1*=ip;
2306
+ }
2307
+ return twsz;
2308
+ }
2309
+
2310
+ void comp_twiddle()
2311
+ {
2312
+ sincos_2pibyn<T0> twid(length);
2313
+ size_t l1=1;
2314
+ T0 *ptr=mem.data();
2315
+ for (size_t k=0; k<fact.size(); ++k)
2316
+ {
2317
+ size_t ip=fact[k].fct, ido=length/(l1*ip);
2318
+ if (k<fact.size()-1) // last factor doesn't need twiddles
2319
+ {
2320
+ fact[k].tw=ptr; ptr+=(ip-1)*(ido-1);
2321
+ for (size_t j=1; j<ip; ++j)
2322
+ for (size_t i=1; i<=(ido-1)/2; ++i)
2323
+ {
2324
+ fact[k].tw[(j-1)*(ido-1)+2*i-2] = twid[j*l1*i].r;
2325
+ fact[k].tw[(j-1)*(ido-1)+2*i-1] = twid[j*l1*i].i;
2326
+ }
2327
+ }
2328
+ if (ip>5) // special factors required by *g functions
2329
+ {
2330
+ fact[k].tws=ptr; ptr+=2*ip;
2331
+ fact[k].tws[0] = 1.;
2332
+ fact[k].tws[1] = 0.;
2333
+ for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2)
2334
+ {
2335
+ fact[k].tws[i ] = twid[i/2*(length/ip)].r;
2336
+ fact[k].tws[i+1] = twid[i/2*(length/ip)].i;
2337
+ fact[k].tws[ic] = twid[i/2*(length/ip)].r;
2338
+ fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i;
2339
+ }
2340
+ }
2341
+ l1*=ip;
2342
+ }
2343
+ }
2344
+
2345
+ public:
2346
+ POCKETFFT_NOINLINE rfftp(size_t length_)
2347
+ : length(length_)
2348
+ {
2349
+ if (length==0) throw std::runtime_error("zero-length FFT requested");
2350
+ if (length==1) return;
2351
+ factorize();
2352
+ mem.resize(twsize());
2353
+ comp_twiddle();
2354
+ }
2355
+ };
2356
+
2357
+ //
2358
+ // complex Bluestein transforms
2359
+ //
2360
+
2361
+ template<typename T0> class fftblue
2362
+ {
2363
+ private:
2364
+ size_t n, n2;
2365
+ cfftp<T0> plan;
2366
+ arr<cmplx<T0>> mem;
2367
+ cmplx<T0> *bk, *bkf;
2368
+
2369
+ template<bool fwd, typename T> void fft(cmplx<T> c[], T0 fct) const
2370
+ {
2371
+ arr<cmplx<T>> akf(n2);
2372
+
2373
+ /* initialize a_k and FFT it */
2374
+ for (size_t m=0; m<n; ++m)
2375
+ special_mul<fwd>(c[m],bk[m],akf[m]);
2376
+ auto zero = akf[0]*T0(0);
2377
+ for (size_t m=n; m<n2; ++m)
2378
+ akf[m]=zero;
2379
+
2380
+ plan.exec (akf.data(),1.,true);
2381
+
2382
+ /* do the convolution */
2383
+ akf[0] = akf[0].template special_mul<!fwd>(bkf[0]);
2384
+ for (size_t m=1; m<(n2+1)/2; ++m)
2385
+ {
2386
+ akf[m] = akf[m].template special_mul<!fwd>(bkf[m]);
2387
+ akf[n2-m] = akf[n2-m].template special_mul<!fwd>(bkf[m]);
2388
+ }
2389
+ if ((n2&1)==0)
2390
+ akf[n2/2] = akf[n2/2].template special_mul<!fwd>(bkf[n2/2]);
2391
+
2392
+ /* inverse FFT */
2393
+ plan.exec (akf.data(),1.,false);
2394
+
2395
+ /* multiply by b_k */
2396
+ for (size_t m=0; m<n; ++m)
2397
+ c[m] = akf[m].template special_mul<fwd>(bk[m])*fct;
2398
+ }
2399
+
2400
+ public:
2401
+ POCKETFFT_NOINLINE fftblue(size_t length)
2402
+ : n(length), n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1),
2403
+ bk(mem.data()), bkf(mem.data()+n)
2404
+ {
2405
+ /* initialize b_k */
2406
+ sincos_2pibyn<T0> tmp(2*n);
2407
+ bk[0].Set(1, 0);
2408
+
2409
+ size_t coeff=0;
2410
+ for (size_t m=1; m<n; ++m)
2411
+ {
2412
+ coeff+=2*m-1;
2413
+ if (coeff>=2*n) coeff-=2*n;
2414
+ bk[m] = tmp[coeff];
2415
+ }
2416
+
2417
+ /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */
2418
+ arr<cmplx<T0>> tbkf(n2);
2419
+ T0 xn2 = T0(1)/T0(n2);
2420
+ tbkf[0] = bk[0]*xn2;
2421
+ for (size_t m=1; m<n; ++m)
2422
+ tbkf[m] = tbkf[n2-m] = bk[m]*xn2;
2423
+ for (size_t m=n;m<=(n2-n);++m)
2424
+ tbkf[m].Set(0.,0.);
2425
+ plan.exec(tbkf.data(),1.,true);
2426
+ for (size_t i=0; i<n2/2+1; ++i)
2427
+ bkf[i] = tbkf[i];
2428
+ }
2429
+
2430
+ template<typename T> void exec(cmplx<T> c[], T0 fct, bool fwd) const
2431
+ { fwd ? fft<true>(c,fct) : fft<false>(c,fct); }
2432
+
2433
+ template<typename T> void exec_r(T c[], T0 fct, bool fwd)
2434
+ {
2435
+ arr<cmplx<T>> tmp(n);
2436
+ if (fwd)
2437
+ {
2438
+ auto zero = T0(0)*c[0];
2439
+ for (size_t m=0; m<n; ++m)
2440
+ tmp[m].Set(c[m], zero);
2441
+ fft<true>(tmp.data(),fct);
2442
+ c[0] = tmp[0].r;
2443
+ std::copy_n (&tmp[1].r, n-1, &c[1]);
2444
+ }
2445
+ else
2446
+ {
2447
+ tmp[0].Set(c[0],c[0]*0);
2448
+ std::copy_n (c+1, n-1, &tmp[1].r);
2449
+ if ((n&1)==0) tmp[n/2].i=T0(0)*c[0];
2450
+ for (size_t m=1; 2*m<n; ++m)
2451
+ tmp[n-m].Set(tmp[m].r, -tmp[m].i);
2452
+ fft<false>(tmp.data(),fct);
2453
+ for (size_t m=0; m<n; ++m)
2454
+ c[m] = tmp[m].r;
2455
+ }
2456
+ }
2457
+ };
2458
+
2459
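fftblue above implements Bluestein's chirp-z algorithm for the forward transform: with the chirp b_j = exp(i*pi*j^2/n), the length-n DFT can be written as X_k = conj(b_k) * sum_m (x_m*conj(b_m)) * b_{k-m}, i.e. a convolution with b, which the class evaluates through a zero-padded FFT of the convenient composite length n2 = good_size_cmplx(2*n-1); the inverse transform swaps the conjugations. The sketch below only checks that identity numerically against a naive DFT, using a direct O(n^2) convolution instead of the padded FFT (an illustration, not the class's code path):

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
  {
  using cd = std::complex<double>;
  const double pi = 3.141592653589793238462643383279502884;
  const std::size_t n = 7;                       // any length works, including primes
  std::vector<cd> x(n);
  for (std::size_t m=0; m<n; ++m) x[m] = cd(std::cos(0.3*m), std::sin(1.7*m));

  // naive DFT for reference: X_k = sum_m x_m * exp(-2*pi*i*m*k/n)
  std::vector<cd> ref(n);
  for (std::size_t k=0; k<n; ++k)
    for (std::size_t m=0; m<n; ++m)
      ref[k] += x[m]*std::polar(1.0, -2.0*pi*double(m*k)/double(n));

  // Bluestein identity: X_k = conj(b_k) * sum_m (x_m*conj(b_m)) * b_{k-m}
  auto b = [&](long j){ return std::polar(1.0, pi*double(j)*double(j)/double(n)); };
  double maxerr = 0.0;
  for (long k=0; k<long(n); ++k)
    {
    cd acc(0.0, 0.0);
    for (long m=0; m<long(n); ++m)
      acc += x[std::size_t(m)]*std::conj(b(m))*b(k-m);
    cd Xk = std::conj(b(k))*acc;
    maxerr = std::max(maxerr, std::abs(Xk-ref[std::size_t(k)]));
    }
  std::printf("max |Bluestein - naive DFT| = %.3e\n", maxerr);  // ~1e-15
  return 0;
  }

Because the convolution length n2 can always be chosen highly composite, this turns prime or otherwise awkward lengths into O(n log n) transforms.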
+ //
2460
+ // flexible (FFTPACK/Bluestein) complex 1D transform
2461
+ //
2462
+
2463
+ template<typename T0> class pocketfft_c
2464
+ {
2465
+ private:
2466
+ std::unique_ptr<cfftp<T0>> packplan;
2467
+ std::unique_ptr<fftblue<T0>> blueplan;
2468
+ size_t len;
2469
+
2470
+ public:
2471
+ POCKETFFT_NOINLINE pocketfft_c(size_t length)
2472
+ : len(length)
2473
+ {
2474
+ if (length==0) throw std::runtime_error("zero-length FFT requested");
2475
+ size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length);
2476
+ if (tmp*tmp <= length)
2477
+ {
2478
+ packplan=std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));
2479
+ return;
2480
+ }
2481
+ double comp1 = util::cost_guess(length);
2482
+ double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1));
2483
+ comp2*=1.5; /* fudge factor that appears to give good overall performance */
2484
+ if (comp2<comp1) // use Bluestein
2485
+ blueplan=std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));
2486
+ else
2487
+ packplan=std::unique_ptr<cfftp<T0>>(new cfftp<T0>(length));
2488
+ }
2489
+
2490
+ template<typename T> POCKETFFT_NOINLINE void exec(cmplx<T> c[], T0 fct, bool fwd) const
2491
+ { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); }
2492
+
2493
+ size_t length() const { return len; }
2494
+ };
2495
+
2496
+ //
2497
+ // flexible (FFTPACK/Bluestein) real-valued 1D transform
2498
+ //
2499
+
2500
+ template<typename T0> class pocketfft_r
2501
+ {
2502
+ private:
2503
+ std::unique_ptr<rfftp<T0>> packplan;
2504
+ std::unique_ptr<fftblue<T0>> blueplan;
2505
+ size_t len;
2506
+
2507
+ public:
2508
+ POCKETFFT_NOINLINE pocketfft_r(size_t length)
2509
+ : len(length)
2510
+ {
2511
+ if (length==0) throw std::runtime_error("zero-length FFT requested");
2512
+ size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length);
2513
+ if (tmp*tmp <= length)
2514
+ {
2515
+ packplan=std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));
2516
+ return;
2517
+ }
2518
+ double comp1 = 0.5*util::cost_guess(length);
2519
+ double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1));
2520
+ comp2*=1.5; /* fudge factor that appears to give good overall performance */
2521
+ if (comp2<comp1) // use Bluestein
2522
+ blueplan=std::unique_ptr<fftblue<T0>>(new fftblue<T0>(length));
2523
+ else
2524
+ packplan=std::unique_ptr<rfftp<T0>>(new rfftp<T0>(length));
2525
+ }
2526
+
2527
+ template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const
2528
+ { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); }
2529
+
2530
+ size_t length() const { return len; }
2531
+ };
2532
+
2533
+
2534
+ //
2535
+ // sine/cosine transforms
2536
+ //
2537
+
2538
+ template<typename T0> class T_dct1
2539
+ {
2540
+ private:
2541
+ pocketfft_r<T0> fftplan;
2542
+
2543
+ public:
2544
+ POCKETFFT_NOINLINE T_dct1(size_t length)
2545
+ : fftplan(2*(length-1)) {}
2546
+
2547
+ template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
2548
+ int /*type*/, bool /*cosine*/) const
2549
+ {
2550
+ constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
2551
+ size_t N=fftplan.length(), n=N/2+1;
2552
+ if (ortho)
2553
+ { c[0]*=sqrt2; c[n-1]*=sqrt2; }
2554
+ arr<T> tmp(N);
2555
+ tmp[0] = c[0];
2556
+ for (size_t i=1; i<n; ++i)
2557
+ tmp[i] = tmp[N-i] = c[i];
2558
+ fftplan.exec(tmp.data(), fct, true);
2559
+ c[0] = tmp[0];
2560
+ for (size_t i=1; i<n; ++i)
2561
+ c[i] = tmp[2*i-1];
2562
+ if (ortho)
2563
+ { c[0]*=sqrt2*T0(0.5); c[n-1]*=sqrt2*T0(0.5); }
2564
+ }
2565
+
2566
+ size_t length() const { return fftplan.length()/2+1; }
2567
+ };
2568
+
2569
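T_dct1 above computes an unnormalised DCT-I by writing the even extension of the input into tmp[] (length 2*(length-1)), running the real FFT on it, and keeping c[0]=tmp[0] and c[i]=tmp[2*i-1], i.e. the real parts of the half-spectrum, assuming the usual FFTPACK half-complex output ordering. The sketch below verifies the underlying identity with naive sums (illustration only, not the class's code path): the DCT-I coefficients equal the real part of the DFT of that even extension.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
  {
  const double pi = 3.141592653589793238462643383279502884;
  const std::size_t n = 6, N = 2*(n-1);
  std::vector<double> c = {0.7, -1.2, 0.4, 2.0, -0.3, 1.1};

  // even extension, as built in T_dct1::exec(): c0 c1 c2 c3 c4 c5 c4 c3 c2 c1
  std::vector<double> tmp(N);
  tmp[0] = c[0];
  for (std::size_t i=1; i<n; ++i) tmp[i] = tmp[N-i] = c[i];

  double maxerr = 0.0;
  for (std::size_t k=0; k<n; ++k)
    {
    // real part of the length-N DFT of the extension
    double dft = 0.0;
    for (std::size_t m=0; m<N; ++m) dft += tmp[m]*std::cos(2.0*pi*double(m*k)/double(N));
    // textbook unnormalised DCT-I
    double dct = c[0] + ((k&1) ? -c[n-1] : c[n-1]);
    for (std::size_t j=1; j<n-1; ++j) dct += 2.0*c[j]*std::cos(pi*double(j*k)/double(n-1));
    maxerr = std::max(maxerr, std::fabs(dft-dct));
    }
  std::printf("max |DFT(even ext.) - DCT-I| = %.3e\n", maxerr);  // ~1e-15
  return 0;
  }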
+ template<typename T0> class T_dst1
2570
+ {
2571
+ private:
2572
+ pocketfft_r<T0> fftplan;
2573
+
2574
+ public:
2575
+ POCKETFFT_NOINLINE T_dst1(size_t length)
2576
+ : fftplan(2*(length+1)) {}
2577
+
2578
+ template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
2579
+ bool /*ortho*/, int /*type*/, bool /*cosine*/) const
2580
+ {
2581
+ size_t N=fftplan.length(), n=N/2-1;
2582
+ arr<T> tmp(N);
2583
+ tmp[0] = tmp[n+1] = c[0]*0;
2584
+ for (size_t i=0; i<n; ++i)
2585
+ { tmp[i+1]=c[i]; tmp[N-1-i]=-c[i]; }
2586
+ fftplan.exec(tmp.data(), fct, true);
2587
+ for (size_t i=0; i<n; ++i)
2588
+ c[i] = -tmp[2*i+2];
2589
+ }
2590
+
2591
+ size_t length() const { return fftplan.length()/2-1; }
2592
+ };
2593
+
2594
+ template<typename T0> class T_dcst23
2595
+ {
2596
+ private:
2597
+ pocketfft_r<T0> fftplan;
2598
+ std::vector<T0> twiddle;
2599
+
2600
+ public:
2601
+ POCKETFFT_NOINLINE T_dcst23(size_t length)
2602
+ : fftplan(length), twiddle(length)
2603
+ {
2604
+ sincos_2pibyn<T0> tw(4*length);
2605
+ for (size_t i=0; i<length; ++i)
2606
+ twiddle[i] = tw[i+1].r;
2607
+ }
2608
+
2609
+ template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho,
2610
+ int type, bool cosine) const
2611
+ {
2612
+ constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
2613
+ size_t N=length();
2614
+ size_t NS2 = (N+1)/2;
2615
+ if (type==2)
2616
+ {
2617
+ if (!cosine)
2618
+ for (size_t k=1; k<N; k+=2)
2619
+ c[k] = -c[k];
2620
+ c[0] *= 2;
2621
+ if ((N&1)==0) c[N-1]*=2;
2622
+ for (size_t k=1; k<N-1; k+=2)
2623
+ MPINPLACE(c[k+1], c[k]);
2624
+ fftplan.exec(c, fct, false);
2625
+ for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
2626
+ {
2627
+ T t1 = twiddle[k-1]*c[kc]+twiddle[kc-1]*c[k];
2628
+ T t2 = twiddle[k-1]*c[k]-twiddle[kc-1]*c[kc];
2629
+ c[k] = T0(0.5)*(t1+t2); c[kc]=T0(0.5)*(t1-t2);
2630
+ }
2631
+ if ((N&1)==0)
2632
+ c[NS2] *= twiddle[NS2-1];
2633
+ if (!cosine)
2634
+ for (size_t k=0, kc=N-1; k<kc; ++k, --kc)
2635
+ std::swap(c[k], c[kc]);
2636
+ if (ortho) c[0]*=sqrt2*T0(0.5);
2637
+ }
2638
+ else
2639
+ {
2640
+ if (ortho) c[0]*=sqrt2;
2641
+ if (!cosine)
2642
+ for (size_t k=0, kc=N-1; k<NS2; ++k, --kc)
2643
+ std::swap(c[k], c[kc]);
2644
+ for (size_t k=1, kc=N-1; k<NS2; ++k, --kc)
2645
+ {
2646
+ T t1=c[k]+c[kc], t2=c[k]-c[kc];
2647
+ c[k] = twiddle[k-1]*t2+twiddle[kc-1]*t1;
2648
+ c[kc]= twiddle[k-1]*t1-twiddle[kc-1]*t2;
2649
+ }
2650
+ if ((N&1)==0)
2651
+ c[NS2] *= 2*twiddle[NS2-1];
2652
+ fftplan.exec(c, fct, true);
2653
+ for (size_t k=1; k<N-1; k+=2)
2654
+ MPINPLACE(c[k], c[k+1]);
2655
+ if (!cosine)
2656
+ for (size_t k=1; k<N; k+=2)
2657
+ c[k] = -c[k];
2658
+ }
2659
+ }
2660
+
2661
+ size_t length() const { return fftplan.length(); }
2662
+ };
2663
+
2664
+ template<typename T0> class T_dcst4
2665
+ {
2666
+ private:
2667
+ size_t N;
2668
+ std::unique_ptr<pocketfft_c<T0>> fft;
2669
+ std::unique_ptr<pocketfft_r<T0>> rfft;
2670
+ arr<cmplx<T0>> C2;
2671
+
2672
+ public:
2673
+ POCKETFFT_NOINLINE T_dcst4(size_t length)
2674
+ : N(length),
2675
+ fft((N&1) ? nullptr : new pocketfft_c<T0>(N/2)),
2676
+ rfft((N&1)? new pocketfft_r<T0>(N) : nullptr),
2677
+ C2((N&1) ? 0 : N/2)
2678
+ {
2679
+ if ((N&1)==0)
2680
+ {
2681
+ sincos_2pibyn<T0> tw(16*N);
2682
+ for (size_t i=0; i<N/2; ++i)
2683
+ C2[i] = conj(tw[8*i+1]);
2684
+ }
2685
+ }
2686
+
2687
+ template<typename T> POCKETFFT_NOINLINE void exec(T c[], T0 fct,
2688
+ bool /*ortho*/, int /*type*/, bool cosine) const
2689
+ {
2690
+ size_t n2 = N/2;
2691
+ if (!cosine)
2692
+ for (size_t k=0, kc=N-1; k<n2; ++k, --kc)
2693
+ std::swap(c[k], c[kc]);
2694
+ if (N&1)
2695
+ {
2696
+ // The following code is derived from the FFTW3 function apply_re11()
2697
+ // and is released under the 3-clause BSD license with friendly
2698
+ // permission of Matteo Frigo and Steven G. Johnson.
2699
+
2700
+ arr<T> y(N);
2701
+ {
2702
+ size_t i=0, m=n2;
2703
+ for (; m<N; ++i, m+=4)
2704
+ y[i] = c[m];
2705
+ for (; m<2*N; ++i, m+=4)
2706
+ y[i] = -c[2*N-m-1];
2707
+ for (; m<3*N; ++i, m+=4)
2708
+ y[i] = -c[m-2*N];
2709
+ for (; m<4*N; ++i, m+=4)
2710
+ y[i] = c[4*N-m-1];
2711
+ for (; i<N; ++i, m+=4)
2712
+ y[i] = c[m-4*N];
2713
+ }
2714
+ rfft->exec(y.data(), fct, true);
2715
+ {
2716
+ auto SGN = [](size_t i)
2717
+ {
2718
+ constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L);
2719
+ return (i&2) ? -sqrt2 : sqrt2;
2720
+ };
2721
+ c[n2] = y[0]*SGN(n2+1);
2722
+ size_t i=0, i1=1, k=1;
2723
+ for (; k<n2; ++i, ++i1, k+=2)
2724
+ {
2725
+ c[i ] = y[2*k-1]*SGN(i1) + y[2*k ]*SGN(i);
2726
+ c[N -i1] = y[2*k-1]*SGN(N -i) - y[2*k ]*SGN(N -i1);
2727
+ c[n2-i1] = y[2*k+1]*SGN(n2-i) - y[2*k+2]*SGN(n2-i1);
2728
+ c[n2+i1] = y[2*k+1]*SGN(n2+i+2) + y[2*k+2]*SGN(n2+i1);
2729
+ }
2730
+ if (k == n2)
2731
+ {
2732
+ c[i ] = y[2*k-1]*SGN(i+1) + y[2*k]*SGN(i);
2733
+ c[N-i1] = y[2*k-1]*SGN(i+2) + y[2*k]*SGN(i1);
2734
+ }
2735
+ }
2736
+
2737
+ // FFTW-derived code ends here
2738
+ }
2739
+ else
2740
+ {
2741
+ // even length algorithm from
2742
+ // https://www.appletonaudio.com/blog/2013/derivation-of-fast-dct-4-algorithm-based-on-dft/
2743
+ arr<cmplx<T>> y(n2);
2744
+ for(size_t i=0; i<n2; ++i)
2745
+ {
2746
+ y[i].Set(c[2*i],c[N-1-2*i]);
2747
+ y[i] *= C2[i];
2748
+ }
2749
+ fft->exec(y.data(), fct, true);
2750
+ for(size_t i=0, ic=n2-1; i<n2; ++i, --ic)
2751
+ {
2752
+ c[2*i ] = 2*(y[i ].r*C2[i ].r-y[i ].i*C2[i ].i);
2753
+ c[2*i+1] = -2*(y[ic].i*C2[ic].r+y[ic].r*C2[ic].i);
2754
+ }
2755
+ }
2756
+ if (!cosine)
2757
+ for (size_t k=1; k<N; k+=2)
2758
+ c[k] = -c[k];
2759
+ }
2760
+
2761
+ size_t length() const { return N; }
2762
+ };
2763
+
2764
+
2765
+ //
2766
+ // multi-D infrastructure
2767
+ //
2768
+
2769
+ template<typename T> std::shared_ptr<T> get_plan(size_t length)
2770
+ {
2771
+ #if POCKETFFT_CACHE_SIZE==0
2772
+ return std::make_shared<T>(length);
2773
+ #else
2774
+ constexpr size_t nmax=POCKETFFT_CACHE_SIZE;
2775
+ static std::array<std::shared_ptr<T>, nmax> cache;
2776
+ static std::array<size_t, nmax> last_access{{0}};
2777
+ static size_t access_counter = 0;
2778
+ static std::mutex mut;
2779
+
2780
+ auto find_in_cache = [&]() -> std::shared_ptr<T>
2781
+ {
2782
+ for (size_t i=0; i<nmax; ++i)
2783
+ if (cache[i] && (cache[i]->length()==length))
2784
+ {
2785
+ // no need to update if this is already the most recent entry
2786
+ if (last_access[i]!=access_counter)
2787
+ {
2788
+ last_access[i] = ++access_counter;
2789
+ // Guard against overflow
2790
+ if (access_counter == 0)
2791
+ last_access.fill(0);
2792
+ }
2793
+ return cache[i];
2794
+ }
2795
+
2796
+ return nullptr;
2797
+ };
2798
+
2799
+ {
2800
+ std::lock_guard<std::mutex> lock(mut);
2801
+ auto p = find_in_cache();
2802
+ if (p) return p;
2803
+ }
2804
+ auto plan = std::make_shared<T>(length);
2805
+ {
2806
+ std::lock_guard<std::mutex> lock(mut);
2807
+ auto p = find_in_cache();
2808
+ if (p) return p;
2809
+
2810
+ size_t lru = 0;
2811
+ for (size_t i=1; i<nmax; ++i)
2812
+ if (last_access[i] < last_access[lru])
2813
+ lru = i;
2814
+
2815
+ cache[lru] = plan;
2816
+ last_access[lru] = ++access_counter;
2817
+ }
2818
+ return plan;
2819
+ #endif
2820
+ }
2821
+
2822
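get_plan() above memoises plans in a small static array of POCKETFFT_CACHE_SIZE slots, uses a monotonically increasing access counter as an LRU timestamp, and looks the cache up twice so that the potentially expensive plan construction runs without the mutex held. The sketch below reduces that pattern to a generic map-based memoiser (hypothetical names; the LRU eviction and the counter-overflow guard are omitted):

#include <cstddef>
#include <cstdio>
#include <map>
#include <memory>
#include <mutex>

struct Plan { explicit Plan(std::size_t n) : length(n) {} std::size_t length; };

std::shared_ptr<Plan> get_cached_plan(std::size_t length)
  {
  static std::map<std::size_t, std::shared_ptr<Plan>> cache;
  static std::mutex mut;
  {
    std::lock_guard<std::mutex> lock(mut);       // first look-up
    auto it = cache.find(length);
    if (it != cache.end()) return it->second;
  }
  auto plan = std::make_shared<Plan>(length);    // build without holding the lock
  {
    std::lock_guard<std::mutex> lock(mut);       // another thread may have won the race
    auto it = cache.find(length);
    if (it != cache.end()) return it->second;
    cache[length] = plan;
  }
  return plan;
  }

int main()
  {
  auto a = get_cached_plan(128), b = get_cached_plan(128);
  std::printf("same plan object: %d\n", int(a.get() == b.get()));  // prints 1
  return 0;
  }

Re-checking after construction means a losing thread simply drops its freshly built plan instead of inserting a duplicate.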
+ class arr_info
2823
+ {
2824
+ protected:
2825
+ shape_t shp;
2826
+ stride_t str;
2827
+
2828
+ public:
2829
+ arr_info(const shape_t &shape_, const stride_t &stride_)
2830
+ : shp(shape_), str(stride_) {}
2831
+ size_t ndim() const { return shp.size(); }
2832
+ size_t size() const { return util::prod(shp); }
2833
+ const shape_t &shape() const { return shp; }
2834
+ size_t shape(size_t i) const { return shp[i]; }
2835
+ const stride_t &stride() const { return str; }
2836
+ const ptrdiff_t &stride(size_t i) const { return str[i]; }
2837
+ };
2838
+
2839
+ template<typename T> class cndarr: public arr_info
2840
+ {
2841
+ protected:
2842
+ const char *d;
2843
+
2844
+ public:
2845
+ cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_)
2846
+ : arr_info(shape_, stride_),
2847
+ d(reinterpret_cast<const char *>(data_)) {}
2848
+ const T &operator[](ptrdiff_t ofs) const
2849
+ { return *reinterpret_cast<const T *>(d+ofs); }
2850
+ };
2851
+
2852
+ template<typename T> class ndarr: public cndarr<T>
2853
+ {
2854
+ public:
2855
+ ndarr(void *data_, const shape_t &shape_, const stride_t &stride_)
2856
+ : cndarr<T>::cndarr(const_cast<const void *>(data_), shape_, stride_)
2857
+ {}
2858
+ T &operator[](ptrdiff_t ofs)
2859
+ { return *reinterpret_cast<T *>(const_cast<char *>(cndarr<T>::d+ofs)); }
2860
+ };
2861
+
2862
+ template<size_t N> class multi_iter
2863
+ {
2864
+ private:
2865
+ shape_t pos;
2866
+ const arr_info &iarr, &oarr;
2867
+ ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o;
2868
+ size_t idim, rem;
2869
+
2870
+ void advance_i()
2871
+ {
2872
+ for (int i_=int(pos.size())-1; i_>=0; --i_)
2873
+ {
2874
+ auto i = size_t(i_);
2875
+ if (i==idim) continue;
2876
+ p_ii += iarr.stride(i);
2877
+ p_oi += oarr.stride(i);
2878
+ if (++pos[i] < iarr.shape(i))
2879
+ return;
2880
+ pos[i] = 0;
2881
+ p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i);
2882
+ p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i);
2883
+ }
2884
+ }
2885
+
2886
+ public:
2887
+ multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
2888
+ : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
2889
+ str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
2890
+ idim(idim_), rem(iarr.size()/iarr.shape(idim))
2891
+ {
2892
+ auto nshares = threading::num_threads();
2893
+ if (nshares==1) return;
2894
+ if (nshares==0) throw std::runtime_error("can't run with zero threads");
2895
+ auto myshare = threading::thread_id();
2896
+ if (myshare>=nshares) throw std::runtime_error("impossible share requested");
2897
+ size_t nbase = rem/nshares;
2898
+ size_t additional = rem%nshares;
2899
+ size_t lo = myshare*nbase + ((myshare<additional) ? myshare : additional);
2900
+ size_t hi = lo+nbase+(myshare<additional);
2901
+ size_t todo = hi-lo;
2902
+
2903
+ size_t chunk = rem;
2904
+ for (size_t i=0; i<pos.size(); ++i)
2905
+ {
2906
+ if (i==idim) continue;
2907
+ chunk /= iarr.shape(i);
2908
+ size_t n_advance = lo/chunk;
2909
+ pos[i] += n_advance;
2910
+ p_ii += ptrdiff_t(n_advance)*iarr.stride(i);
2911
+ p_oi += ptrdiff_t(n_advance)*oarr.stride(i);
2912
+ lo -= n_advance*chunk;
2913
+ }
2914
+ rem = todo;
2915
+ }
2916
+ void advance(size_t n)
2917
+ {
2918
+ if (rem<n) throw std::runtime_error("underrun");
2919
+ for (size_t i=0; i<n; ++i)
2920
+ {
2921
+ p_i[i] = p_ii;
2922
+ p_o[i] = p_oi;
2923
+ advance_i();
2924
+ }
2925
+ rem -= n;
2926
+ }
2927
+ ptrdiff_t iofs(size_t i) const { return p_i[0] + ptrdiff_t(i)*str_i; }
2928
+ ptrdiff_t iofs(size_t j, size_t i) const { return p_i[j] + ptrdiff_t(i)*str_i; }
2929
+ ptrdiff_t oofs(size_t i) const { return p_o[0] + ptrdiff_t(i)*str_o; }
2930
+ ptrdiff_t oofs(size_t j, size_t i) const { return p_o[j] + ptrdiff_t(i)*str_o; }
2931
+ size_t length_in() const { return iarr.shape(idim); }
2932
+ size_t length_out() const { return oarr.shape(idim); }
2933
+ ptrdiff_t stride_in() const { return str_i; }
2934
+ ptrdiff_t stride_out() const { return str_o; }
2935
+ size_t remaining() const { return rem; }
2936
+ };
2937
+
2938
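The multi_iter constructor above splits the rem one-dimensional lines across nshares worker threads: each worker gets either floor(rem/nshares) lines or one more, with the first rem%nshares workers taking the larger share, and then fast-forwards its starting position by lo lines along the non-transform axes. The following sketch only checks that the lo/hi arithmetic yields contiguous, non-overlapping ranges covering [0, rem) exactly (illustration only):

#include <cstddef>
#include <cstdio>

int main()
  {
  const std::size_t rem = 103, nshares = 8;      // arbitrary example values
  std::size_t expected_lo = 0, total = 0;
  for (std::size_t myshare = 0; myshare < nshares; ++myshare)
    {
    std::size_t nbase = rem/nshares;
    std::size_t additional = rem%nshares;
    std::size_t lo = myshare*nbase + ((myshare<additional) ? myshare : additional);
    std::size_t hi = lo + nbase + (myshare<additional);
    std::printf("thread %zu: [%zu, %zu)\n", myshare, lo, hi);
    if (lo != expected_lo) std::printf("  gap or overlap!\n");
    expected_lo = hi;
    total += hi - lo;
    }
  std::printf("covered %zu of %zu lines\n", total, rem);  // covered 103 of 103
  return 0;
  }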
+ class simple_iter
2939
+ {
2940
+ private:
2941
+ shape_t pos;
2942
+ const arr_info &arr;
2943
+ ptrdiff_t p;
2944
+ size_t rem;
2945
+
2946
+ public:
2947
+ simple_iter(const arr_info &arr_)
2948
+ : pos(arr_.ndim(), 0), arr(arr_), p(0), rem(arr_.size()) {}
2949
+ void advance()
2950
+ {
2951
+ --rem;
2952
+ for (int i_=int(pos.size())-1; i_>=0; --i_)
2953
+ {
2954
+ auto i = size_t(i_);
2955
+ p += arr.stride(i);
2956
+ if (++pos[i] < arr.shape(i))
2957
+ return;
2958
+ pos[i] = 0;
2959
+ p -= ptrdiff_t(arr.shape(i))*arr.stride(i);
2960
+ }
2961
+ }
2962
+ ptrdiff_t ofs() const { return p; }
2963
+ size_t remaining() const { return rem; }
2964
+ };
2965
+
2966
+ class rev_iter
2967
+ {
2968
+ private:
2969
+ shape_t pos;
2970
+ const arr_info &arr;
2971
+ std::vector<char> rev_axis;
2972
+ std::vector<char> rev_jump;
2973
+ size_t last_axis, last_size;
2974
+ shape_t shp;
2975
+ ptrdiff_t p, rp;
2976
+ size_t rem;
2977
+
2978
+ public:
2979
+ rev_iter(const arr_info &arr_, const shape_t &axes)
2980
+ : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0),
2981
+ rev_jump(arr_.ndim(), 1), p(0), rp(0)
2982
+ {
2983
+ for (auto ax: axes)
2984
+ rev_axis[ax]=1;
2985
+ last_axis = axes.back();
2986
+ last_size = arr.shape(last_axis)/2 + 1;
2987
+ shp = arr.shape();
2988
+ shp[last_axis] = last_size;
2989
+ rem=1;
2990
+ for (auto i: shp)
2991
+ rem *= i;
2992
+ }
2993
+ void advance()
2994
+ {
2995
+ --rem;
2996
+ for (int i_=int(pos.size())-1; i_>=0; --i_)
2997
+ {
2998
+ auto i = size_t(i_);
2999
+ p += arr.stride(i);
3000
+ if (!rev_axis[i])
3001
+ rp += arr.stride(i);
3002
+ else
3003
+ {
3004
+ rp -= arr.stride(i);
3005
+ if (rev_jump[i])
3006
+ {
3007
+ rp += ptrdiff_t(arr.shape(i))*arr.stride(i);
3008
+ rev_jump[i] = 0;
3009
+ }
3010
+ }
3011
+ if (++pos[i] < shp[i])
3012
+ return;
3013
+ pos[i] = 0;
3014
+ p -= ptrdiff_t(shp[i])*arr.stride(i);
3015
+ if (rev_axis[i])
3016
+ {
3017
+ rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i);
3018
+ rev_jump[i] = 1;
3019
+ }
3020
+ else
3021
+ rp -= ptrdiff_t(shp[i])*arr.stride(i);
3022
+ }
3023
+ }
3024
+ ptrdiff_t ofs() const { return p; }
3025
+ ptrdiff_t rev_ofs() const { return rp; }
3026
+ size_t remaining() const { return rem; }
3027
+ };
3028
+
3029
+ template<typename T> struct VTYPE {};
3030
+ template <typename T> using vtype_t = typename VTYPE<T>::type;
3031
+
3032
+ #ifndef POCKETFFT_NO_VECTORS
3033
+ template<> struct VTYPE<float>
3034
+ {
3035
+ using type = float __attribute__ ((vector_size (VLEN<float>::val*sizeof(float))));
3036
+ };
3037
+ template<> struct VTYPE<double>
3038
+ {
3039
+ using type = double __attribute__ ((vector_size (VLEN<double>::val*sizeof(double))));
3040
+ };
3041
+ template<> struct VTYPE<long double>
3042
+ {
3043
+ using type = long double __attribute__ ((vector_size (VLEN<long double>::val*sizeof(long double))));
3044
+ };
3045
+ #endif
3046
+
3047
+ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
3048
+ size_t axsize, size_t elemsize)
3049
+ {
3050
+ auto othersize = util::prod(shape)/axsize;
3051
+ auto tmpsize = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
3052
+ return arr<char>(tmpsize*elemsize);
3053
+ }
3054
+ template<typename T> arr<char> alloc_tmp(const shape_t &shape,
3055
+ const shape_t &axes, size_t elemsize)
3056
+ {
3057
+ size_t fullsize=util::prod(shape);
3058
+ size_t tmpsize=0;
3059
+ for (size_t i=0; i<axes.size(); ++i)
3060
+ {
3061
+ auto axsize = shape[axes[i]];
3062
+ auto othersize = fullsize/axsize;
3063
+ auto sz = axsize*((othersize>=VLEN<T>::val) ? VLEN<T>::val : 1);
3064
+ if (sz>tmpsize) tmpsize=sz;
3065
+ }
3066
+ return arr<char>(tmpsize*elemsize);
3067
+ }
3068
+
3069
+ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
3070
+ const cndarr<cmplx<T>> &src, cmplx<vtype_t<T>> *POCKETFFT_RESTRICT dst)
3071
+ {
3072
+ for (size_t i=0; i<it.length_in(); ++i)
3073
+ for (size_t j=0; j<vlen; ++j)
3074
+ {
3075
+ dst[i].r[j] = src[it.iofs(j,i)].r;
3076
+ dst[i].i[j] = src[it.iofs(j,i)].i;
3077
+ }
3078
+ }
3079
+
3080
+ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
3081
+ const cndarr<T> &src, vtype_t<T> *POCKETFFT_RESTRICT dst)
3082
+ {
3083
+ for (size_t i=0; i<it.length_in(); ++i)
3084
+ for (size_t j=0; j<vlen; ++j)
3085
+ dst[i][j] = src[it.iofs(j,i)];
3086
+ }
3087
+
3088
+ template <typename T, size_t vlen> void copy_input(const multi_iter<vlen> &it,
3089
+ const cndarr<T> &src, T *POCKETFFT_RESTRICT dst)
3090
+ {
3091
+ if (dst == &src[it.iofs(0)]) return; // in-place
3092
+ for (size_t i=0; i<it.length_in(); ++i)
3093
+ dst[i] = src[it.iofs(i)];
3094
+ }
3095
+
3096
+ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
3097
+ const cmplx<vtype_t<T>> *POCKETFFT_RESTRICT src, ndarr<cmplx<T>> &dst)
3098
+ {
3099
+ for (size_t i=0; i<it.length_out(); ++i)
3100
+ for (size_t j=0; j<vlen; ++j)
3101
+ dst[it.oofs(j,i)].Set(src[i].r[j],src[i].i[j]);
3102
+ }
3103
+
3104
+ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
3105
+ const vtype_t<T> *POCKETFFT_RESTRICT src, ndarr<T> &dst)
3106
+ {
3107
+ for (size_t i=0; i<it.length_out(); ++i)
3108
+ for (size_t j=0; j<vlen; ++j)
3109
+ dst[it.oofs(j,i)] = src[i][j];
3110
+ }
3111
+
3112
+ template<typename T, size_t vlen> void copy_output(const multi_iter<vlen> &it,
3113
+ const T *POCKETFFT_RESTRICT src, ndarr<T> &dst)
3114
+ {
3115
+ if (src == &dst[it.oofs(0)]) return; // in-place
3116
+ for (size_t i=0; i<it.length_out(); ++i)
3117
+ dst[it.oofs(i)] = src[i];
3118
+ }
3119
+
3120
+ template <typename T> struct add_vec { using type = vtype_t<T>; };
3121
+ template <typename T> struct add_vec<cmplx<T>>
3122
+ { using type = cmplx<vtype_t<T>>; };
3123
+ template <typename T> using add_vec_t = typename add_vec<T>::type;
3124
+
+ template<typename Tplan, typename T, typename T0, typename Exec>
+ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,
+ const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec,
+ const bool allow_inplace=true)
+ {
+ std::shared_ptr<Tplan> plan;
+
+ for (size_t iax=0; iax<axes.size(); ++iax)
+ {
+ size_t len=in.shape(axes[iax]);
+ if ((!plan) || (len!=plan->length()))
+ plan = get_plan<Tplan>(len);
+
+ threading::thread_map(
+ util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
+ [&] {
+ constexpr auto vlen = VLEN<T0>::val;
+ auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
+ const auto &tin(iax==0? in : out);
+ multi_iter<vlen> it(tin, out, axes[iax]);
+ #ifndef POCKETFFT_NO_VECTORS
+ if (vlen>1)
+ while (it.remaining()>=vlen)
+ {
+ it.advance(vlen);
+ auto tdatav = reinterpret_cast<add_vec_t<T> *>(storage.data());
+ exec(it, tin, out, tdatav, *plan, fct);
+ }
+ #endif
+ while (it.remaining()>0)
+ {
+ it.advance(1);
+ auto buf = allow_inplace && it.stride_out() == sizeof(T) ?
+ &out[it.oofs(0)] : reinterpret_cast<T *>(storage.data());
+ exec(it, tin, out, buf, *plan, fct);
+ }
+ }); // end of parallel region
+ fct = T0(1); // factor has been applied, use 1 for remaining axes
+ }
+ }
+
+ struct ExecC2C
+ {
+ bool forward;
+
+ template <typename T0, typename T, size_t vlen> void operator () (
+ const multi_iter<vlen> &it, const cndarr<cmplx<T0>> &in,
+ ndarr<cmplx<T0>> &out, T * buf, const pocketfft_c<T0> &plan, T0 fct) const
+ {
+ copy_input(it, in, buf);
+ plan.exec(buf, fct, forward);
+ copy_output(it, buf, out);
+ }
+ };
+
+ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
+ const vtype_t<T> *POCKETFFT_RESTRICT src, ndarr<T> &dst)
+ {
+ for (size_t j=0; j<vlen; ++j)
+ dst[it.oofs(j,0)] = src[0][j];
+ size_t i=1, i1=1, i2=it.length_out()-1;
+ for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
+ for (size_t j=0; j<vlen; ++j)
+ {
+ dst[it.oofs(j,i1)] = src[i][j]+src[i+1][j];
+ dst[it.oofs(j,i2)] = src[i][j]-src[i+1][j];
+ }
+ if (i<it.length_out())
+ for (size_t j=0; j<vlen; ++j)
+ dst[it.oofs(j,i1)] = src[i][j];
+ }
+
+ template <typename T, size_t vlen> void copy_hartley(const multi_iter<vlen> &it,
+ const T *POCKETFFT_RESTRICT src, ndarr<T> &dst)
+ {
+ dst[it.oofs(0)] = src[0];
+ size_t i=1, i1=1, i2=it.length_out()-1;
+ for (i=1; i<it.length_out()-1; i+=2, ++i1, --i2)
+ {
+ dst[it.oofs(i1)] = src[i]+src[i+1];
+ dst[it.oofs(i2)] = src[i]-src[i+1];
+ }
+ if (i<it.length_out())
+ dst[it.oofs(i1)] = src[i];
+ }
+
+ struct ExecHartley
+ {
+ template <typename T0, typename T, size_t vlen> void operator () (
+ const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out,
+ T * buf, const pocketfft_r<T0> &plan, T0 fct) const
+ {
+ copy_input(it, in, buf);
+ plan.exec(buf, fct, true);
+ copy_hartley(it, buf, out);
+ }
+ };
+
+ struct ExecDcst
+ {
+ bool ortho;
+ int type;
+ bool cosine;
+
+ template <typename T0, typename T, typename Tplan, size_t vlen>
+ void operator () (const multi_iter<vlen> &it, const cndarr<T0> &in,
+ ndarr<T0> &out, T * buf, const Tplan &plan, T0 fct) const
+ {
+ copy_input(it, in, buf);
+ plan.exec(buf, fct, ortho, type, cosine);
+ copy_output(it, buf, out);
+ }
+ };
+
+ template<typename T> POCKETFFT_NOINLINE void general_r2c(
+ const cndarr<T> &in, ndarr<cmplx<T>> &out, size_t axis, bool forward, T fct,
+ size_t nthreads)
+ {
+ auto plan = get_plan<pocketfft_r<T>>(in.shape(axis));
+ size_t len=in.shape(axis);
+ threading::thread_map(
+ util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
+ [&] {
+ constexpr auto vlen = VLEN<T>::val;
+ auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
+ multi_iter<vlen> it(in, out, axis);
+ #ifndef POCKETFFT_NO_VECTORS
+ if (vlen>1)
+ while (it.remaining()>=vlen)
+ {
+ it.advance(vlen);
+ auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
+ copy_input(it, in, tdatav);
+ plan->exec(tdatav, fct, true);
+ for (size_t j=0; j<vlen; ++j)
+ out[it.oofs(j,0)].Set(tdatav[0][j]);
+ size_t i=1, ii=1;
+ if (forward)
+ for (; i<len-1; i+=2, ++ii)
+ for (size_t j=0; j<vlen; ++j)
+ out[it.oofs(j,ii)].Set(tdatav[i][j], tdatav[i+1][j]);
+ else
+ for (; i<len-1; i+=2, ++ii)
+ for (size_t j=0; j<vlen; ++j)
+ out[it.oofs(j,ii)].Set(tdatav[i][j], -tdatav[i+1][j]);
+ if (i<len)
+ for (size_t j=0; j<vlen; ++j)
+ out[it.oofs(j,ii)].Set(tdatav[i][j]);
+ }
+ #endif
+ while (it.remaining()>0)
+ {
+ it.advance(1);
+ auto tdata = reinterpret_cast<T *>(storage.data());
+ copy_input(it, in, tdata);
+ plan->exec(tdata, fct, true);
+ out[it.oofs(0)].Set(tdata[0]);
+ size_t i=1, ii=1;
+ if (forward)
+ for (; i<len-1; i+=2, ++ii)
+ out[it.oofs(ii)].Set(tdata[i], tdata[i+1]);
+ else
+ for (; i<len-1; i+=2, ++ii)
+ out[it.oofs(ii)].Set(tdata[i], -tdata[i+1]);
+ if (i<len)
+ out[it.oofs(ii)].Set(tdata[i]);
+ }
+ }); // end of parallel region
+ }
+ template<typename T> POCKETFFT_NOINLINE void general_c2r(
+ const cndarr<cmplx<T>> &in, ndarr<T> &out, size_t axis, bool forward, T fct,
+ size_t nthreads)
+ {
+ auto plan = get_plan<pocketfft_r<T>>(out.shape(axis));
+ size_t len=out.shape(axis);
+ threading::thread_map(
+ util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
+ [&] {
+ constexpr auto vlen = VLEN<T>::val;
+ auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
+ multi_iter<vlen> it(in, out, axis);
+ #ifndef POCKETFFT_NO_VECTORS
+ if (vlen>1)
+ while (it.remaining()>=vlen)
+ {
+ it.advance(vlen);
+ auto tdatav = reinterpret_cast<vtype_t<T> *>(storage.data());
+ for (size_t j=0; j<vlen; ++j)
+ tdatav[0][j]=in[it.iofs(j,0)].r;
+ {
+ size_t i=1, ii=1;
+ if (forward)
+ for (; i<len-1; i+=2, ++ii)
+ for (size_t j=0; j<vlen; ++j)
+ {
+ tdatav[i ][j] = in[it.iofs(j,ii)].r;
+ tdatav[i+1][j] = -in[it.iofs(j,ii)].i;
+ }
+ else
+ for (; i<len-1; i+=2, ++ii)
+ for (size_t j=0; j<vlen; ++j)
+ {
+ tdatav[i ][j] = in[it.iofs(j,ii)].r;
+ tdatav[i+1][j] = in[it.iofs(j,ii)].i;
+ }
+ if (i<len)
+ for (size_t j=0; j<vlen; ++j)
+ tdatav[i][j] = in[it.iofs(j,ii)].r;
+ }
+ plan->exec(tdatav, fct, false);
+ copy_output(it, tdatav, out);
+ }
+ #endif
+ while (it.remaining()>0)
+ {
+ it.advance(1);
+ auto tdata = reinterpret_cast<T *>(storage.data());
+ tdata[0]=in[it.iofs(0)].r;
+ {
+ size_t i=1, ii=1;
+ if (forward)
+ for (; i<len-1; i+=2, ++ii)
+ {
+ tdata[i ] = in[it.iofs(ii)].r;
+ tdata[i+1] = -in[it.iofs(ii)].i;
+ }
+ else
+ for (; i<len-1; i+=2, ++ii)
+ {
+ tdata[i ] = in[it.iofs(ii)].r;
+ tdata[i+1] = in[it.iofs(ii)].i;
+ }
+ if (i<len)
+ tdata[i] = in[it.iofs(ii)].r;
+ }
+ plan->exec(tdata, fct, false);
+ copy_output(it, tdata, out);
+ }
+ }); // end of parallel region
+ }
+
+ struct ExecR2R
+ {
+ bool r2h, forward;
+
+ template <typename T0, typename T, size_t vlen> void operator () (
+ const multi_iter<vlen> &it, const cndarr<T0> &in, ndarr<T0> &out, T * buf,
+ const pocketfft_r<T0> &plan, T0 fct) const
+ {
+ copy_input(it, in, buf);
+ if ((!r2h) && forward)
+ for (size_t i=2; i<it.length_out(); i+=2)
+ buf[i] = -buf[i];
+ plan.exec(buf, fct, r2h);
+ if (r2h && (!forward))
+ for (size_t i=2; i<it.length_out(); i+=2)
+ buf[i] = -buf[i];
+ copy_output(it, buf, out);
+ }
+ };
+
+ template<typename T> void c2c(const shape_t &shape, const stride_t &stride_in,
+ const stride_t &stride_out, const shape_t &axes, bool forward,
+ const std::complex<T> *data_in, std::complex<T> *data_out, T fct,
+ size_t nthreads=1)
+ {
+ if (util::prod(shape)==0) return;
+ util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
+ cndarr<cmplx<T>> ain(data_in, shape, stride_in);
+ ndarr<cmplx<T>> aout(data_out, shape, stride_out);
+ general_nd<pocketfft_c<T>>(ain, aout, axes, fct, nthreads, ExecC2C{forward});
+ }
+
+ template<typename T> void dct(const shape_t &shape,
+ const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
+ int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
+ {
+ if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type");
+ if (util::prod(shape)==0) return;
+ util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
+ cndarr<T> ain(data_in, shape, stride_in);
+ ndarr<T> aout(data_out, shape, stride_out);
+ const ExecDcst exec{ortho, type, true};
+ if (type==1)
+ general_nd<T_dct1<T>>(ain, aout, axes, fct, nthreads, exec);
+ else if (type==4)
+ general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
+ else
+ general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
+ }
+
+ template<typename T> void dst(const shape_t &shape,
+ const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
+ int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1)
+ {
+ if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type");
+ if (util::prod(shape)==0) return;
+ util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
+ cndarr<T> ain(data_in, shape, stride_in);
+ ndarr<T> aout(data_out, shape, stride_out);
+ const ExecDcst exec{ortho, type, false};
+ if (type==1)
+ general_nd<T_dst1<T>>(ain, aout, axes, fct, nthreads, exec);
+ else if (type==4)
+ general_nd<T_dcst4<T>>(ain, aout, axes, fct, nthreads, exec);
+ else
+ general_nd<T_dcst23<T>>(ain, aout, axes, fct, nthreads, exec);
+ }
+
+ template<typename T> void r2c(const shape_t &shape_in,
+ const stride_t &stride_in, const stride_t &stride_out, size_t axis,
+ bool forward, const T *data_in, std::complex<T> *data_out, T fct,
+ size_t nthreads=1)
+ {
+ if (util::prod(shape_in)==0) return;
+ util::sanity_check(shape_in, stride_in, stride_out, false, axis);
+ cndarr<T> ain(data_in, shape_in, stride_in);
+ shape_t shape_out(shape_in);
+ shape_out[axis] = shape_in[axis]/2 + 1;
+ ndarr<cmplx<T>> aout(data_out, shape_out, stride_out);
+ general_r2c(ain, aout, axis, forward, fct, nthreads);
+ }
+
+ template<typename T> void r2c(const shape_t &shape_in,
+ const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
+ bool forward, const T *data_in, std::complex<T> *data_out, T fct,
+ size_t nthreads=1)
+ {
+ if (util::prod(shape_in)==0) return;
+ util::sanity_check(shape_in, stride_in, stride_out, false, axes);
+ r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out,
+ fct, nthreads);
+ if (axes.size()==1) return;
+
+ shape_t shape_out(shape_in);
+ shape_out[axes.back()] = shape_in[axes.back()]/2 + 1;
+ auto newaxes = shape_t{axes.begin(), --axes.end()};
+ c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out,
+ T(1), nthreads);
+ }
+
+ template<typename T> void c2r(const shape_t &shape_out,
+ const stride_t &stride_in, const stride_t &stride_out, size_t axis,
+ bool forward, const std::complex<T> *data_in, T *data_out, T fct,
+ size_t nthreads=1)
+ {
+ if (util::prod(shape_out)==0) return;
+ util::sanity_check(shape_out, stride_in, stride_out, false, axis);
+ shape_t shape_in(shape_out);
+ shape_in[axis] = shape_out[axis]/2 + 1;
+ cndarr<cmplx<T>> ain(data_in, shape_in, stride_in);
+ ndarr<T> aout(data_out, shape_out, stride_out);
+ general_c2r(ain, aout, axis, forward, fct, nthreads);
+ }
+
+ template<typename T> void c2r(const shape_t &shape_out,
+ const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
+ bool forward, const std::complex<T> *data_in, T *data_out, T fct,
+ size_t nthreads=1)
+ {
+ if (util::prod(shape_out)==0) return;
+ if (axes.size()==1)
+ return c2r(shape_out, stride_in, stride_out, axes[0], forward,
+ data_in, data_out, fct, nthreads);
+ util::sanity_check(shape_out, stride_in, stride_out, false, axes);
+ auto shape_in = shape_out;
+ shape_in[axes.back()] = shape_out[axes.back()]/2 + 1;
+ auto nval = util::prod(shape_in);
+ stride_t stride_inter(shape_in.size());
+ stride_inter.back() = sizeof(cmplx<T>);
+ for (int i=int(shape_in.size())-2; i>=0; --i)
+ stride_inter[size_t(i)] =
+ stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]);
+ arr<std::complex<T>> tmp(nval);
+ auto newaxes = shape_t{axes.begin(), --axes.end()};
+ c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(),
+ T(1), nthreads);
+ c2r(shape_out, stride_inter, stride_out, axes.back(), forward,
+ tmp.data(), data_out, fct, nthreads);
+ }
+
+ template<typename T> void r2r_fftpack(const shape_t &shape,
+ const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
+ bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct,
+ size_t nthreads=1)
+ {
+ if (util::prod(shape)==0) return;
+ util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
+ cndarr<T> ain(data_in, shape, stride_in);
+ ndarr<T> aout(data_out, shape, stride_out);
+ general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads,
+ ExecR2R{real2hermitian, forward});
+ }
+
+ template<typename T> void r2r_separable_hartley(const shape_t &shape,
+ const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
+ const T *data_in, T *data_out, T fct, size_t nthreads=1)
+ {
+ if (util::prod(shape)==0) return;
+ util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
+ cndarr<T> ain(data_in, shape, stride_in);
+ ndarr<T> aout(data_out, shape, stride_out);
+ general_nd<pocketfft_r<T>>(ain, aout, axes, fct, nthreads, ExecHartley{},
+ false);
+ }
+
+ template<typename T> void r2r_genuine_hartley(const shape_t &shape,
+ const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes,
+ const T *data_in, T *data_out, T fct, size_t nthreads=1)
+ {
+ if (util::prod(shape)==0) return;
+ if (axes.size()==1)
+ return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in,
+ data_out, fct, nthreads);
+ util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes);
+ shape_t tshp(shape);
+ tshp[axes.back()] = tshp[axes.back()]/2+1;
+ arr<std::complex<T>> tdata(util::prod(tshp));
+ stride_t tstride(shape.size());
+ tstride.back()=sizeof(std::complex<T>);
+ for (size_t i=tstride.size()-1; i>0; --i)
+ tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]);
+ r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads);
+ cndarr<cmplx<T>> atmp(tdata.data(), tshp, tstride);
+ ndarr<T> aout(data_out, shape, stride_out);
+ simple_iter iin(atmp);
+ rev_iter iout(aout, axes);
+ while(iin.remaining()>0)
+ {
+ auto v = atmp[iin.ofs()];
+ aout[iout.ofs()] = v.r+v.i;
+ aout[iout.rev_ofs()] = v.r-v.i;
+ iin.advance(); iout.advance();
+ }
+ }
+
+ } // namespace detail
+
+ using detail::FORWARD;
+ using detail::BACKWARD;
+ using detail::shape_t;
+ using detail::stride_t;
+ using detail::c2c;
+ using detail::c2r;
+ using detail::r2c;
+ using detail::r2r_fftpack;
+ using detail::r2r_separable_hartley;
+ using detail::r2r_genuine_hartley;
+ using detail::dct;
+ using detail::dst;
+
+ } // namespace pocketfft
+
+ #undef POCKETFFT_NOINLINE
+ #undef POCKETFFT_RESTRICT
+
+ #endif // POCKETFFT_HDRONLY_H
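
For orientation only, here is a minimal usage sketch of the public interface added above (c2c with the exported FORWARD constant). This example is not part of the packaged header; the include path, the transform length n, and the double-precision instantiation are assumptions. It follows the signatures shown in the diff, where strides are expressed in bytes.

// Hypothetical usage sketch; not part of the packaged header.
#include <complex>
#include <cstddef>
#include <vector>
#include "pocketfft.h"  // include path is an assumption; use whatever path your build exposes

int main()
{
  const std::size_t n = 8;  // example length (assumption)
  std::vector<std::complex<double>> in(n), out(n);
  for (std::size_t i = 0; i < n; ++i) in[i] = {double(i), 0.0};

  pocketfft::shape_t shape{n};
  // strides are given in bytes, matching the stride_t convention used above
  pocketfft::stride_t stride{static_cast<std::ptrdiff_t>(sizeof(std::complex<double>))};
  pocketfft::shape_t axes{0};

  // forward c2c transform along axis 0, scaling factor 1.0, default nthreads=1
  pocketfft::c2c(shape, stride, stride, axes, pocketfft::FORWARD,
                 in.data(), out.data(), 1.0);
  return 0;
}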