scipy 1.15.3__cp313-cp313-macosx_12_0_arm64.whl → 1.16.0rc2__cp313-cp313-macosx_12_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (629)
  1. scipy/.dylibs/libscipy_openblas.dylib +0 -0
  2. scipy/__config__.py +8 -8
  3. scipy/__init__.py +3 -6
  4. scipy/_cyutility.cpython-313-darwin.so +0 -0
  5. scipy/_lib/_array_api.py +486 -161
  6. scipy/_lib/_array_api_compat_vendor.py +9 -0
  7. scipy/_lib/_bunch.py +4 -0
  8. scipy/_lib/_ccallback_c.cpython-313-darwin.so +0 -0
  9. scipy/_lib/_docscrape.py +1 -1
  10. scipy/_lib/_elementwise_iterative_method.py +15 -26
  11. scipy/_lib/_sparse.py +41 -0
  12. scipy/_lib/_test_deprecation_call.cpython-313-darwin.so +0 -0
  13. scipy/_lib/_test_deprecation_def.cpython-313-darwin.so +0 -0
  14. scipy/_lib/_testutils.py +6 -2
  15. scipy/_lib/_util.py +222 -125
  16. scipy/_lib/array_api_compat/__init__.py +4 -4
  17. scipy/_lib/array_api_compat/_internal.py +19 -6
  18. scipy/_lib/array_api_compat/common/__init__.py +1 -1
  19. scipy/_lib/array_api_compat/common/_aliases.py +365 -193
  20. scipy/_lib/array_api_compat/common/_fft.py +94 -64
  21. scipy/_lib/array_api_compat/common/_helpers.py +413 -180
  22. scipy/_lib/array_api_compat/common/_linalg.py +116 -40
  23. scipy/_lib/array_api_compat/common/_typing.py +179 -10
  24. scipy/_lib/array_api_compat/cupy/__init__.py +1 -4
  25. scipy/_lib/array_api_compat/cupy/_aliases.py +61 -41
  26. scipy/_lib/array_api_compat/cupy/_info.py +16 -6
  27. scipy/_lib/array_api_compat/cupy/_typing.py +24 -39
  28. scipy/_lib/array_api_compat/dask/array/__init__.py +6 -3
  29. scipy/_lib/array_api_compat/dask/array/_aliases.py +267 -108
  30. scipy/_lib/array_api_compat/dask/array/_info.py +105 -34
  31. scipy/_lib/array_api_compat/dask/array/fft.py +5 -8
  32. scipy/_lib/array_api_compat/dask/array/linalg.py +21 -22
  33. scipy/_lib/array_api_compat/numpy/__init__.py +13 -15
  34. scipy/_lib/array_api_compat/numpy/_aliases.py +98 -49
  35. scipy/_lib/array_api_compat/numpy/_info.py +36 -16
  36. scipy/_lib/array_api_compat/numpy/_typing.py +27 -43
  37. scipy/_lib/array_api_compat/numpy/fft.py +11 -5
  38. scipy/_lib/array_api_compat/numpy/linalg.py +75 -22
  39. scipy/_lib/array_api_compat/torch/__init__.py +3 -5
  40. scipy/_lib/array_api_compat/torch/_aliases.py +262 -159
  41. scipy/_lib/array_api_compat/torch/_info.py +27 -16
  42. scipy/_lib/array_api_compat/torch/_typing.py +3 -0
  43. scipy/_lib/array_api_compat/torch/fft.py +17 -18
  44. scipy/_lib/array_api_compat/torch/linalg.py +16 -16
  45. scipy/_lib/array_api_extra/__init__.py +26 -3
  46. scipy/_lib/array_api_extra/_delegation.py +171 -0
  47. scipy/_lib/array_api_extra/_lib/__init__.py +1 -0
  48. scipy/_lib/array_api_extra/_lib/_at.py +463 -0
  49. scipy/_lib/array_api_extra/_lib/_backends.py +46 -0
  50. scipy/_lib/array_api_extra/_lib/_funcs.py +937 -0
  51. scipy/_lib/array_api_extra/_lib/_lazy.py +357 -0
  52. scipy/_lib/array_api_extra/_lib/_testing.py +278 -0
  53. scipy/_lib/array_api_extra/_lib/_utils/__init__.py +1 -0
  54. scipy/_lib/array_api_extra/_lib/_utils/_compat.py +74 -0
  55. scipy/_lib/array_api_extra/_lib/_utils/_compat.pyi +45 -0
  56. scipy/_lib/array_api_extra/_lib/_utils/_helpers.py +559 -0
  57. scipy/_lib/array_api_extra/_lib/_utils/_typing.py +10 -0
  58. scipy/_lib/array_api_extra/_lib/_utils/_typing.pyi +105 -0
  59. scipy/_lib/array_api_extra/testing.py +359 -0
  60. scipy/_lib/decorator.py +2 -2
  61. scipy/_lib/doccer.py +1 -7
  62. scipy/_lib/messagestream.cpython-313-darwin.so +0 -0
  63. scipy/_lib/pyprima/__init__.py +212 -0
  64. scipy/_lib/pyprima/cobyla/__init__.py +0 -0
  65. scipy/_lib/pyprima/cobyla/cobyla.py +559 -0
  66. scipy/_lib/pyprima/cobyla/cobylb.py +714 -0
  67. scipy/_lib/pyprima/cobyla/geometry.py +226 -0
  68. scipy/_lib/pyprima/cobyla/initialize.py +215 -0
  69. scipy/_lib/pyprima/cobyla/trustregion.py +492 -0
  70. scipy/_lib/pyprima/cobyla/update.py +289 -0
  71. scipy/_lib/pyprima/common/__init__.py +0 -0
  72. scipy/_lib/pyprima/common/_bounds.py +34 -0
  73. scipy/_lib/pyprima/common/_linear_constraints.py +46 -0
  74. scipy/_lib/pyprima/common/_nonlinear_constraints.py +54 -0
  75. scipy/_lib/pyprima/common/_project.py +173 -0
  76. scipy/_lib/pyprima/common/checkbreak.py +93 -0
  77. scipy/_lib/pyprima/common/consts.py +47 -0
  78. scipy/_lib/pyprima/common/evaluate.py +99 -0
  79. scipy/_lib/pyprima/common/history.py +38 -0
  80. scipy/_lib/pyprima/common/infos.py +30 -0
  81. scipy/_lib/pyprima/common/linalg.py +435 -0
  82. scipy/_lib/pyprima/common/message.py +290 -0
  83. scipy/_lib/pyprima/common/powalg.py +131 -0
  84. scipy/_lib/pyprima/common/preproc.py +277 -0
  85. scipy/_lib/pyprima/common/present.py +5 -0
  86. scipy/_lib/pyprima/common/ratio.py +54 -0
  87. scipy/_lib/pyprima/common/redrho.py +47 -0
  88. scipy/_lib/pyprima/common/selectx.py +296 -0
  89. scipy/_lib/tests/test__util.py +105 -121
  90. scipy/_lib/tests/test_array_api.py +166 -35
  91. scipy/_lib/tests/test_bunch.py +7 -0
  92. scipy/_lib/tests/test_ccallback.py +2 -10
  93. scipy/_lib/tests/test_public_api.py +13 -0
  94. scipy/cluster/_hierarchy.cpython-313-darwin.so +0 -0
  95. scipy/cluster/_optimal_leaf_ordering.cpython-313-darwin.so +0 -0
  96. scipy/cluster/_vq.cpython-313-darwin.so +0 -0
  97. scipy/cluster/hierarchy.py +393 -223
  98. scipy/cluster/tests/test_hierarchy.py +273 -335
  99. scipy/cluster/tests/test_vq.py +45 -61
  100. scipy/cluster/vq.py +39 -35
  101. scipy/conftest.py +263 -157
  102. scipy/constants/_constants.py +4 -1
  103. scipy/constants/tests/test_codata.py +2 -2
  104. scipy/constants/tests/test_constants.py +11 -18
  105. scipy/datasets/_download_all.py +15 -1
  106. scipy/datasets/_fetchers.py +7 -1
  107. scipy/datasets/_utils.py +1 -1
  108. scipy/differentiate/_differentiate.py +25 -25
  109. scipy/differentiate/tests/test_differentiate.py +24 -25
  110. scipy/fft/_basic.py +20 -0
  111. scipy/fft/_helper.py +3 -34
  112. scipy/fft/_pocketfft/helper.py +29 -1
  113. scipy/fft/_pocketfft/tests/test_basic.py +2 -4
  114. scipy/fft/_pocketfft/tests/test_real_transforms.py +4 -4
  115. scipy/fft/_realtransforms.py +13 -0
  116. scipy/fft/tests/test_basic.py +27 -25
  117. scipy/fft/tests/test_fftlog.py +16 -7
  118. scipy/fft/tests/test_helper.py +18 -34
  119. scipy/fft/tests/test_real_transforms.py +8 -10
  120. scipy/fftpack/convolve.cpython-313-darwin.so +0 -0
  121. scipy/fftpack/tests/test_basic.py +2 -4
  122. scipy/fftpack/tests/test_real_transforms.py +8 -9
  123. scipy/integrate/_bvp.py +9 -3
  124. scipy/integrate/_cubature.py +3 -2
  125. scipy/integrate/_dop.cpython-313-darwin.so +0 -0
  126. scipy/integrate/_lsoda.cpython-313-darwin.so +0 -0
  127. scipy/integrate/_ode.py +9 -2
  128. scipy/integrate/_odepack.cpython-313-darwin.so +0 -0
  129. scipy/integrate/_quad_vec.py +21 -29
  130. scipy/integrate/_quadpack.cpython-313-darwin.so +0 -0
  131. scipy/integrate/_quadpack_py.py +11 -7
  132. scipy/integrate/_quadrature.py +3 -3
  133. scipy/integrate/_rules/_base.py +2 -2
  134. scipy/integrate/_tanhsinh.py +48 -47
  135. scipy/integrate/_test_odeint_banded.cpython-313-darwin.so +0 -0
  136. scipy/integrate/_vode.cpython-313-darwin.so +0 -0
  137. scipy/integrate/tests/test__quad_vec.py +0 -6
  138. scipy/integrate/tests/test_banded_ode_solvers.py +85 -0
  139. scipy/integrate/tests/test_cubature.py +21 -35
  140. scipy/integrate/tests/test_quadrature.py +6 -8
  141. scipy/integrate/tests/test_tanhsinh.py +56 -48
  142. scipy/interpolate/__init__.py +70 -58
  143. scipy/interpolate/_bary_rational.py +22 -22
  144. scipy/interpolate/_bsplines.py +119 -66
  145. scipy/interpolate/_cubic.py +65 -50
  146. scipy/interpolate/_dfitpack.cpython-313-darwin.so +0 -0
  147. scipy/interpolate/_dierckx.cpython-313-darwin.so +0 -0
  148. scipy/interpolate/_fitpack.cpython-313-darwin.so +0 -0
  149. scipy/interpolate/_fitpack2.py +9 -6
  150. scipy/interpolate/_fitpack_impl.py +32 -26
  151. scipy/interpolate/_fitpack_repro.py +23 -19
  152. scipy/interpolate/_interpnd.cpython-313-darwin.so +0 -0
  153. scipy/interpolate/_interpolate.py +30 -12
  154. scipy/interpolate/_ndbspline.py +13 -18
  155. scipy/interpolate/_ndgriddata.py +5 -8
  156. scipy/interpolate/_polyint.py +95 -31
  157. scipy/interpolate/_ppoly.cpython-313-darwin.so +0 -0
  158. scipy/interpolate/_rbf.py +2 -2
  159. scipy/interpolate/_rbfinterp.py +1 -1
  160. scipy/interpolate/_rbfinterp_pythran.cpython-313-darwin.so +0 -0
  161. scipy/interpolate/_rgi.py +31 -26
  162. scipy/interpolate/_rgi_cython.cpython-313-darwin.so +0 -0
  163. scipy/interpolate/dfitpack.py +0 -20
  164. scipy/interpolate/interpnd.py +1 -2
  165. scipy/interpolate/tests/test_bary_rational.py +2 -2
  166. scipy/interpolate/tests/test_bsplines.py +97 -1
  167. scipy/interpolate/tests/test_fitpack2.py +39 -1
  168. scipy/interpolate/tests/test_interpnd.py +32 -20
  169. scipy/interpolate/tests/test_interpolate.py +48 -4
  170. scipy/interpolate/tests/test_rgi.py +2 -1
  171. scipy/io/_fast_matrix_market/__init__.py +2 -0
  172. scipy/io/_harwell_boeing/_fortran_format_parser.py +19 -16
  173. scipy/io/_harwell_boeing/hb.py +7 -11
  174. scipy/io/_idl.py +5 -7
  175. scipy/io/_netcdf.py +15 -5
  176. scipy/io/_test_fortran.cpython-313-darwin.so +0 -0
  177. scipy/io/arff/tests/test_arffread.py +3 -3
  178. scipy/io/matlab/__init__.py +5 -3
  179. scipy/io/matlab/_mio.py +4 -1
  180. scipy/io/matlab/_mio5.py +19 -13
  181. scipy/io/matlab/_mio5_utils.cpython-313-darwin.so +0 -0
  182. scipy/io/matlab/_mio_utils.cpython-313-darwin.so +0 -0
  183. scipy/io/matlab/_miobase.py +4 -1
  184. scipy/io/matlab/_streams.cpython-313-darwin.so +0 -0
  185. scipy/io/matlab/tests/test_mio.py +46 -18
  186. scipy/io/matlab/tests/test_mio_funcs.py +1 -1
  187. scipy/io/tests/test_mmio.py +7 -1
  188. scipy/io/tests/test_wavfile.py +41 -0
  189. scipy/io/wavfile.py +57 -10
  190. scipy/linalg/_basic.py +113 -86
  191. scipy/linalg/_cythonized_array_utils.cpython-313-darwin.so +0 -0
  192. scipy/linalg/_decomp.py +22 -9
  193. scipy/linalg/_decomp_cholesky.py +28 -13
  194. scipy/linalg/_decomp_cossin.py +45 -30
  195. scipy/linalg/_decomp_interpolative.cpython-313-darwin.so +0 -0
  196. scipy/linalg/_decomp_ldl.py +4 -1
  197. scipy/linalg/_decomp_lu.py +18 -6
  198. scipy/linalg/_decomp_lu_cython.cpython-313-darwin.so +0 -0
  199. scipy/linalg/_decomp_polar.py +2 -0
  200. scipy/linalg/_decomp_qr.py +6 -2
  201. scipy/linalg/_decomp_qz.py +3 -0
  202. scipy/linalg/_decomp_schur.py +3 -1
  203. scipy/linalg/_decomp_svd.py +13 -2
  204. scipy/linalg/_decomp_update.cpython-313-darwin.so +0 -0
  205. scipy/linalg/_expm_frechet.py +4 -0
  206. scipy/linalg/_fblas.cpython-313-darwin.so +0 -0
  207. scipy/linalg/_flapack.cpython-313-darwin.so +0 -0
  208. scipy/linalg/_linalg_pythran.cpython-313-darwin.so +0 -0
  209. scipy/linalg/_matfuncs.py +187 -4
  210. scipy/linalg/_matfuncs_expm.cpython-313-darwin.so +0 -0
  211. scipy/linalg/_matfuncs_schur_sqrtm.cpython-313-darwin.so +0 -0
  212. scipy/linalg/_matfuncs_sqrtm.py +1 -99
  213. scipy/linalg/_matfuncs_sqrtm_triu.cpython-313-darwin.so +0 -0
  214. scipy/linalg/_procrustes.py +2 -0
  215. scipy/linalg/_sketches.py +17 -6
  216. scipy/linalg/_solve_toeplitz.cpython-313-darwin.so +0 -0
  217. scipy/linalg/_solvers.py +7 -2
  218. scipy/linalg/_special_matrices.py +26 -36
  219. scipy/linalg/cython_blas.cpython-313-darwin.so +0 -0
  220. scipy/linalg/cython_lapack.cpython-313-darwin.so +0 -0
  221. scipy/linalg/lapack.py +22 -2
  222. scipy/linalg/tests/_cython_examples/meson.build +7 -0
  223. scipy/linalg/tests/test_basic.py +31 -16
  224. scipy/linalg/tests/test_batch.py +588 -0
  225. scipy/linalg/tests/test_cythonized_array_utils.py +0 -2
  226. scipy/linalg/tests/test_decomp.py +40 -3
  227. scipy/linalg/tests/test_decomp_cossin.py +14 -0
  228. scipy/linalg/tests/test_decomp_ldl.py +1 -1
  229. scipy/linalg/tests/test_lapack.py +115 -7
  230. scipy/linalg/tests/test_matfuncs.py +157 -102
  231. scipy/linalg/tests/test_procrustes.py +0 -7
  232. scipy/linalg/tests/test_solve_toeplitz.py +1 -1
  233. scipy/linalg/tests/test_special_matrices.py +1 -5
  234. scipy/ndimage/__init__.py +1 -0
  235. scipy/ndimage/_cytest.cpython-313-darwin.so +0 -0
  236. scipy/ndimage/_delegators.py +8 -2
  237. scipy/ndimage/_filters.py +453 -5
  238. scipy/ndimage/_interpolation.py +36 -6
  239. scipy/ndimage/_measurements.py +4 -2
  240. scipy/ndimage/_morphology.py +5 -0
  241. scipy/ndimage/_nd_image.cpython-313-darwin.so +0 -0
  242. scipy/ndimage/_ni_docstrings.py +5 -1
  243. scipy/ndimage/_ni_label.cpython-313-darwin.so +0 -0
  244. scipy/ndimage/_ni_support.py +1 -5
  245. scipy/ndimage/_rank_filter_1d.cpython-313-darwin.so +0 -0
  246. scipy/ndimage/_support_alternative_backends.py +18 -6
  247. scipy/ndimage/tests/test_filters.py +370 -259
  248. scipy/ndimage/tests/test_fourier.py +7 -9
  249. scipy/ndimage/tests/test_interpolation.py +68 -61
  250. scipy/ndimage/tests/test_measurements.py +18 -35
  251. scipy/ndimage/tests/test_morphology.py +143 -131
  252. scipy/ndimage/tests/test_splines.py +1 -3
  253. scipy/odr/__odrpack.cpython-313-darwin.so +0 -0
  254. scipy/optimize/_basinhopping.py +13 -7
  255. scipy/optimize/_bglu_dense.cpython-313-darwin.so +0 -0
  256. scipy/optimize/_bracket.py +17 -24
  257. scipy/optimize/_chandrupatla.py +9 -10
  258. scipy/optimize/_cobyla_py.py +104 -123
  259. scipy/optimize/_constraints.py +14 -10
  260. scipy/optimize/_differentiable_functions.py +371 -230
  261. scipy/optimize/_differentialevolution.py +4 -3
  262. scipy/optimize/_direct.cpython-313-darwin.so +0 -0
  263. scipy/optimize/_dual_annealing.py +1 -1
  264. scipy/optimize/_elementwise.py +1 -4
  265. scipy/optimize/_group_columns.cpython-313-darwin.so +0 -0
  266. scipy/optimize/_lbfgsb.cpython-313-darwin.so +0 -0
  267. scipy/optimize/_lbfgsb_py.py +57 -16
  268. scipy/optimize/_linprog_doc.py +2 -2
  269. scipy/optimize/_linprog_highs.py +2 -2
  270. scipy/optimize/_linprog_ip.py +25 -10
  271. scipy/optimize/_linprog_util.py +14 -16
  272. scipy/optimize/_lsap.cpython-313-darwin.so +0 -0
  273. scipy/optimize/_lsq/common.py +3 -3
  274. scipy/optimize/_lsq/dogbox.py +16 -2
  275. scipy/optimize/_lsq/givens_elimination.cpython-313-darwin.so +0 -0
  276. scipy/optimize/_lsq/least_squares.py +198 -126
  277. scipy/optimize/_lsq/lsq_linear.py +6 -6
  278. scipy/optimize/_lsq/trf.py +35 -8
  279. scipy/optimize/_milp.py +3 -1
  280. scipy/optimize/_minimize.py +105 -36
  281. scipy/optimize/_minpack.cpython-313-darwin.so +0 -0
  282. scipy/optimize/_minpack_py.py +21 -14
  283. scipy/optimize/_moduleTNC.cpython-313-darwin.so +0 -0
  284. scipy/optimize/_nnls.py +20 -21
  285. scipy/optimize/_nonlin.py +34 -3
  286. scipy/optimize/_numdiff.py +288 -110
  287. scipy/optimize/_optimize.py +86 -48
  288. scipy/optimize/_pava_pybind.cpython-313-darwin.so +0 -0
  289. scipy/optimize/_remove_redundancy.py +5 -5
  290. scipy/optimize/_root_scalar.py +1 -1
  291. scipy/optimize/_shgo.py +6 -0
  292. scipy/optimize/_shgo_lib/_complex.py +1 -1
  293. scipy/optimize/_slsqp_py.py +216 -124
  294. scipy/optimize/_slsqplib.cpython-313-darwin.so +0 -0
  295. scipy/optimize/_spectral.py +1 -1
  296. scipy/optimize/_tnc.py +8 -1
  297. scipy/optimize/_trlib/_trlib.cpython-313-darwin.so +0 -0
  298. scipy/optimize/_trustregion.py +20 -6
  299. scipy/optimize/_trustregion_constr/canonical_constraint.py +7 -7
  300. scipy/optimize/_trustregion_constr/equality_constrained_sqp.py +1 -1
  301. scipy/optimize/_trustregion_constr/minimize_trustregion_constr.py +11 -3
  302. scipy/optimize/_trustregion_constr/projections.py +12 -8
  303. scipy/optimize/_trustregion_constr/qp_subproblem.py +9 -9
  304. scipy/optimize/_trustregion_constr/tests/test_projections.py +7 -7
  305. scipy/optimize/_trustregion_constr/tests/test_qp_subproblem.py +77 -77
  306. scipy/optimize/_trustregion_constr/tr_interior_point.py +5 -5
  307. scipy/optimize/_trustregion_exact.py +0 -1
  308. scipy/optimize/_zeros.cpython-313-darwin.so +0 -0
  309. scipy/optimize/_zeros_py.py +97 -17
  310. scipy/optimize/cython_optimize/_zeros.cpython-313-darwin.so +0 -0
  311. scipy/optimize/slsqp.py +0 -1
  312. scipy/optimize/tests/test__basinhopping.py +1 -1
  313. scipy/optimize/tests/test__differential_evolution.py +4 -4
  314. scipy/optimize/tests/test__linprog_clean_inputs.py +5 -3
  315. scipy/optimize/tests/test__numdiff.py +66 -22
  316. scipy/optimize/tests/test__remove_redundancy.py +2 -2
  317. scipy/optimize/tests/test__shgo.py +9 -1
  318. scipy/optimize/tests/test_bracket.py +36 -46
  319. scipy/optimize/tests/test_chandrupatla.py +133 -135
  320. scipy/optimize/tests/test_cobyla.py +74 -45
  321. scipy/optimize/tests/test_constraints.py +1 -1
  322. scipy/optimize/tests/test_differentiable_functions.py +226 -6
  323. scipy/optimize/tests/test_lbfgsb_hessinv.py +22 -0
  324. scipy/optimize/tests/test_least_squares.py +125 -13
  325. scipy/optimize/tests/test_linear_assignment.py +3 -3
  326. scipy/optimize/tests/test_linprog.py +3 -3
  327. scipy/optimize/tests/test_lsq_linear.py +6 -6
  328. scipy/optimize/tests/test_minimize_constrained.py +2 -2
  329. scipy/optimize/tests/test_minpack.py +4 -4
  330. scipy/optimize/tests/test_nnls.py +43 -3
  331. scipy/optimize/tests/test_nonlin.py +36 -0
  332. scipy/optimize/tests/test_optimize.py +95 -17
  333. scipy/optimize/tests/test_slsqp.py +36 -4
  334. scipy/optimize/tests/test_zeros.py +34 -1
  335. scipy/signal/__init__.py +12 -23
  336. scipy/signal/_delegators.py +568 -0
  337. scipy/signal/_filter_design.py +459 -241
  338. scipy/signal/_fir_filter_design.py +262 -90
  339. scipy/signal/_lti_conversion.py +3 -2
  340. scipy/signal/_ltisys.py +118 -91
  341. scipy/signal/_max_len_seq_inner.cpython-313-darwin.so +0 -0
  342. scipy/signal/_peak_finding_utils.cpython-313-darwin.so +0 -0
  343. scipy/signal/_polyutils.py +172 -0
  344. scipy/signal/_short_time_fft.py +519 -70
  345. scipy/signal/_signal_api.py +30 -0
  346. scipy/signal/_signaltools.py +719 -399
  347. scipy/signal/_sigtools.cpython-313-darwin.so +0 -0
  348. scipy/signal/_sosfilt.cpython-313-darwin.so +0 -0
  349. scipy/signal/_spectral_py.py +230 -50
  350. scipy/signal/_spline.cpython-313-darwin.so +0 -0
  351. scipy/signal/_spline_filters.py +108 -68
  352. scipy/signal/_support_alternative_backends.py +73 -0
  353. scipy/signal/_upfirdn.py +4 -1
  354. scipy/signal/_upfirdn_apply.cpython-313-darwin.so +0 -0
  355. scipy/signal/_waveforms.py +2 -11
  356. scipy/signal/_wavelets.py +1 -1
  357. scipy/signal/fir_filter_design.py +1 -0
  358. scipy/signal/spline.py +4 -11
  359. scipy/signal/tests/_scipy_spectral_test_shim.py +2 -171
  360. scipy/signal/tests/test_bsplines.py +114 -79
  361. scipy/signal/tests/test_cont2discrete.py +9 -2
  362. scipy/signal/tests/test_filter_design.py +721 -481
  363. scipy/signal/tests/test_fir_filter_design.py +332 -140
  364. scipy/signal/tests/test_savitzky_golay.py +4 -3
  365. scipy/signal/tests/test_short_time_fft.py +221 -3
  366. scipy/signal/tests/test_signaltools.py +2144 -1348
  367. scipy/signal/tests/test_spectral.py +50 -6
  368. scipy/signal/tests/test_splines.py +161 -96
  369. scipy/signal/tests/test_upfirdn.py +84 -50
  370. scipy/signal/tests/test_waveforms.py +20 -0
  371. scipy/signal/tests/test_windows.py +607 -466
  372. scipy/signal/windows/_windows.py +287 -148
  373. scipy/sparse/__init__.py +23 -4
  374. scipy/sparse/_base.py +270 -108
  375. scipy/sparse/_bsr.py +7 -4
  376. scipy/sparse/_compressed.py +59 -231
  377. scipy/sparse/_construct.py +90 -38
  378. scipy/sparse/_coo.py +115 -181
  379. scipy/sparse/_csc.py +4 -4
  380. scipy/sparse/_csparsetools.cpython-313-darwin.so +0 -0
  381. scipy/sparse/_csr.py +2 -2
  382. scipy/sparse/_data.py +48 -48
  383. scipy/sparse/_dia.py +105 -18
  384. scipy/sparse/_dok.py +0 -23
  385. scipy/sparse/_index.py +4 -4
  386. scipy/sparse/_matrix.py +23 -0
  387. scipy/sparse/_sparsetools.cpython-313-darwin.so +0 -0
  388. scipy/sparse/_sputils.py +37 -22
  389. scipy/sparse/base.py +0 -9
  390. scipy/sparse/bsr.py +0 -14
  391. scipy/sparse/compressed.py +0 -23
  392. scipy/sparse/construct.py +0 -6
  393. scipy/sparse/coo.py +0 -14
  394. scipy/sparse/csc.py +0 -3
  395. scipy/sparse/csgraph/_flow.cpython-313-darwin.so +0 -0
  396. scipy/sparse/csgraph/_matching.cpython-313-darwin.so +0 -0
  397. scipy/sparse/csgraph/_min_spanning_tree.cpython-313-darwin.so +0 -0
  398. scipy/sparse/csgraph/_reordering.cpython-313-darwin.so +0 -0
  399. scipy/sparse/csgraph/_shortest_path.cpython-313-darwin.so +0 -0
  400. scipy/sparse/csgraph/_tools.cpython-313-darwin.so +0 -0
  401. scipy/sparse/csgraph/_traversal.cpython-313-darwin.so +0 -0
  402. scipy/sparse/csgraph/tests/test_matching.py +14 -2
  403. scipy/sparse/csgraph/tests/test_pydata_sparse.py +4 -1
  404. scipy/sparse/csgraph/tests/test_shortest_path.py +83 -27
  405. scipy/sparse/csr.py +0 -5
  406. scipy/sparse/data.py +1 -6
  407. scipy/sparse/dia.py +0 -7
  408. scipy/sparse/dok.py +0 -10
  409. scipy/sparse/linalg/_dsolve/_superlu.cpython-313-darwin.so +0 -0
  410. scipy/sparse/linalg/_dsolve/linsolve.py +9 -0
  411. scipy/sparse/linalg/_dsolve/tests/test_linsolve.py +35 -28
  412. scipy/sparse/linalg/_eigen/arpack/_arpack.cpython-313-darwin.so +0 -0
  413. scipy/sparse/linalg/_eigen/arpack/arpack.py +23 -17
  414. scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py +6 -6
  415. scipy/sparse/linalg/_interface.py +17 -18
  416. scipy/sparse/linalg/_isolve/_gcrotmk.py +4 -4
  417. scipy/sparse/linalg/_isolve/iterative.py +51 -45
  418. scipy/sparse/linalg/_isolve/lgmres.py +6 -6
  419. scipy/sparse/linalg/_isolve/minres.py +5 -5
  420. scipy/sparse/linalg/_isolve/tfqmr.py +7 -7
  421. scipy/sparse/linalg/_isolve/utils.py +2 -8
  422. scipy/sparse/linalg/_matfuncs.py +1 -1
  423. scipy/sparse/linalg/_norm.py +1 -1
  424. scipy/sparse/linalg/_propack/_cpropack.cpython-313-darwin.so +0 -0
  425. scipy/sparse/linalg/_propack/_dpropack.cpython-313-darwin.so +0 -0
  426. scipy/sparse/linalg/_propack/_spropack.cpython-313-darwin.so +0 -0
  427. scipy/sparse/linalg/_propack/_zpropack.cpython-313-darwin.so +0 -0
  428. scipy/sparse/linalg/_special_sparse_arrays.py +39 -38
  429. scipy/sparse/linalg/tests/test_pydata_sparse.py +14 -0
  430. scipy/sparse/tests/test_arithmetic1d.py +5 -2
  431. scipy/sparse/tests/test_base.py +214 -42
  432. scipy/sparse/tests/test_common1d.py +7 -7
  433. scipy/sparse/tests/test_construct.py +1 -1
  434. scipy/sparse/tests/test_coo.py +272 -4
  435. scipy/sparse/tests/test_sparsetools.py +5 -0
  436. scipy/sparse/tests/test_sputils.py +36 -7
  437. scipy/spatial/_ckdtree.cpython-313-darwin.so +0 -0
  438. scipy/spatial/_distance_pybind.cpython-313-darwin.so +0 -0
  439. scipy/spatial/_distance_wrap.cpython-313-darwin.so +0 -0
  440. scipy/spatial/_hausdorff.cpython-313-darwin.so +0 -0
  441. scipy/spatial/_qhull.cpython-313-darwin.so +0 -0
  442. scipy/spatial/_voronoi.cpython-313-darwin.so +0 -0
  443. scipy/spatial/distance.py +49 -42
  444. scipy/spatial/tests/test_distance.py +15 -1
  445. scipy/spatial/tests/test_kdtree.py +1 -0
  446. scipy/spatial/tests/test_qhull.py +7 -2
  447. scipy/spatial/transform/__init__.py +5 -3
  448. scipy/spatial/transform/_rigid_transform.cpython-313-darwin.so +0 -0
  449. scipy/spatial/transform/_rotation.cpython-313-darwin.so +0 -0
  450. scipy/spatial/transform/tests/test_rigid_transform.py +1221 -0
  451. scipy/spatial/transform/tests/test_rotation.py +1213 -832
  452. scipy/spatial/transform/tests/test_rotation_groups.py +3 -3
  453. scipy/spatial/transform/tests/test_rotation_spline.py +29 -8
  454. scipy/special/__init__.py +1 -47
  455. scipy/special/_add_newdocs.py +34 -772
  456. scipy/special/_basic.py +22 -25
  457. scipy/special/_comb.cpython-313-darwin.so +0 -0
  458. scipy/special/_ellip_harm_2.cpython-313-darwin.so +0 -0
  459. scipy/special/_gufuncs.cpython-313-darwin.so +0 -0
  460. scipy/special/_logsumexp.py +67 -58
  461. scipy/special/_orthogonal.pyi +1 -1
  462. scipy/special/_specfun.cpython-313-darwin.so +0 -0
  463. scipy/special/_special_ufuncs.cpython-313-darwin.so +0 -0
  464. scipy/special/_spherical_bessel.py +4 -4
  465. scipy/special/_support_alternative_backends.py +212 -119
  466. scipy/special/_test_internal.cpython-313-darwin.so +0 -0
  467. scipy/special/_testutils.py +4 -4
  468. scipy/special/_ufuncs.cpython-313-darwin.so +0 -0
  469. scipy/special/_ufuncs.pyi +1 -0
  470. scipy/special/_ufuncs.pyx +215 -1400
  471. scipy/special/_ufuncs_cxx.cpython-313-darwin.so +0 -0
  472. scipy/special/_ufuncs_cxx.pxd +2 -15
  473. scipy/special/_ufuncs_cxx.pyx +5 -44
  474. scipy/special/_ufuncs_cxx_defs.h +2 -16
  475. scipy/special/_ufuncs_defs.h +0 -8
  476. scipy/special/cython_special.cpython-313-darwin.so +0 -0
  477. scipy/special/cython_special.pxd +1 -1
  478. scipy/special/tests/_cython_examples/meson.build +10 -1
  479. scipy/special/tests/test_basic.py +153 -20
  480. scipy/special/tests/test_boost_ufuncs.py +3 -0
  481. scipy/special/tests/test_cdflib.py +35 -11
  482. scipy/special/tests/test_gammainc.py +16 -0
  483. scipy/special/tests/test_hyp2f1.py +2 -2
  484. scipy/special/tests/test_log1mexp.py +85 -0
  485. scipy/special/tests/test_logsumexp.py +206 -64
  486. scipy/special/tests/test_mpmath.py +1 -0
  487. scipy/special/tests/test_nan_inputs.py +1 -1
  488. scipy/special/tests/test_orthogonal.py +17 -18
  489. scipy/special/tests/test_sf_error.py +3 -2
  490. scipy/special/tests/test_sph_harm.py +6 -7
  491. scipy/special/tests/test_support_alternative_backends.py +211 -76
  492. scipy/stats/__init__.py +4 -1
  493. scipy/stats/_ansari_swilk_statistics.cpython-313-darwin.so +0 -0
  494. scipy/stats/_axis_nan_policy.py +5 -12
  495. scipy/stats/_biasedurn.cpython-313-darwin.so +0 -0
  496. scipy/stats/_continued_fraction.py +387 -0
  497. scipy/stats/_continuous_distns.py +277 -310
  498. scipy/stats/_correlation.py +1 -1
  499. scipy/stats/_covariance.py +6 -3
  500. scipy/stats/_discrete_distns.py +39 -32
  501. scipy/stats/_distn_infrastructure.py +39 -12
  502. scipy/stats/_distribution_infrastructure.py +900 -238
  503. scipy/stats/_entropy.py +9 -10
  504. scipy/{_lib → stats}/_finite_differences.py +1 -1
  505. scipy/stats/_hypotests.py +83 -50
  506. scipy/stats/_kde.py +53 -49
  507. scipy/stats/_ksstats.py +1 -1
  508. scipy/stats/_levy_stable/__init__.py +7 -15
  509. scipy/stats/_levy_stable/levyst.cpython-313-darwin.so +0 -0
  510. scipy/stats/_morestats.py +118 -73
  511. scipy/stats/_mstats_basic.py +13 -17
  512. scipy/stats/_mstats_extras.py +8 -8
  513. scipy/stats/_multivariate.py +89 -113
  514. scipy/stats/_new_distributions.py +97 -20
  515. scipy/stats/_page_trend_test.py +12 -5
  516. scipy/stats/_probability_distribution.py +265 -43
  517. scipy/stats/_qmc.py +14 -9
  518. scipy/stats/_qmc_cy.cpython-313-darwin.so +0 -0
  519. scipy/stats/_qmvnt.py +16 -95
  520. scipy/stats/_qmvnt_cy.cpython-313-darwin.so +0 -0
  521. scipy/stats/_quantile.py +335 -0
  522. scipy/stats/_rcont/rcont.cpython-313-darwin.so +0 -0
  523. scipy/stats/_resampling.py +4 -29
  524. scipy/stats/_sampling.py +1 -1
  525. scipy/stats/_sobol.cpython-313-darwin.so +0 -0
  526. scipy/stats/_stats.cpython-313-darwin.so +0 -0
  527. scipy/stats/_stats_mstats_common.py +21 -2
  528. scipy/stats/_stats_py.py +550 -476
  529. scipy/stats/_stats_pythran.cpython-313-darwin.so +0 -0
  530. scipy/stats/_unuran/unuran_wrapper.cpython-313-darwin.so +0 -0
  531. scipy/stats/_unuran/unuran_wrapper.pyi +2 -1
  532. scipy/stats/_variation.py +6 -8
  533. scipy/stats/_wilcoxon.py +13 -7
  534. scipy/stats/tests/common_tests.py +6 -4
  535. scipy/stats/tests/test_axis_nan_policy.py +62 -24
  536. scipy/stats/tests/test_continued_fraction.py +173 -0
  537. scipy/stats/tests/test_continuous.py +379 -60
  538. scipy/stats/tests/test_continuous_basic.py +18 -12
  539. scipy/stats/tests/test_discrete_basic.py +14 -8
  540. scipy/stats/tests/test_discrete_distns.py +16 -16
  541. scipy/stats/tests/test_distributions.py +95 -75
  542. scipy/stats/tests/test_entropy.py +40 -48
  543. scipy/stats/tests/test_fit.py +4 -3
  544. scipy/stats/tests/test_hypotests.py +153 -24
  545. scipy/stats/tests/test_kdeoth.py +109 -41
  546. scipy/stats/tests/test_marray.py +289 -0
  547. scipy/stats/tests/test_morestats.py +79 -47
  548. scipy/stats/tests/test_mstats_basic.py +3 -3
  549. scipy/stats/tests/test_multivariate.py +434 -83
  550. scipy/stats/tests/test_qmc.py +13 -10
  551. scipy/stats/tests/test_quantile.py +199 -0
  552. scipy/stats/tests/test_rank.py +119 -112
  553. scipy/stats/tests/test_resampling.py +47 -56
  554. scipy/stats/tests/test_sampling.py +9 -4
  555. scipy/stats/tests/test_stats.py +799 -939
  556. scipy/stats/tests/test_variation.py +8 -6
  557. scipy/version.py +2 -2
  558. {scipy-1.15.3.dist-info → scipy-1.16.0rc2.dist-info}/LICENSE.txt +4 -4
  559. {scipy-1.15.3.dist-info → scipy-1.16.0rc2.dist-info}/METADATA +11 -11
  560. {scipy-1.15.3.dist-info → scipy-1.16.0rc2.dist-info}/RECORD +561 -568
  561. scipy-1.16.0rc2.dist-info/WHEEL +6 -0
  562. scipy/_lib/array_api_extra/_funcs.py +0 -484
  563. scipy/_lib/array_api_extra/_typing.py +0 -8
  564. scipy/interpolate/_bspl.cpython-313-darwin.so +0 -0
  565. scipy/optimize/_cobyla.cpython-313-darwin.so +0 -0
  566. scipy/optimize/_cython_nnls.cpython-313-darwin.so +0 -0
  567. scipy/optimize/_slsqp.cpython-313-darwin.so +0 -0
  568. scipy/spatial/qhull_src/COPYING.txt +0 -38
  569. scipy/special/libsf_error_state.dylib +0 -0
  570. scipy/special/tests/test_log_softmax.py +0 -109
  571. scipy/special/tests/test_xsf_cuda.py +0 -114
  572. scipy/special/xsf/binom.h +0 -89
  573. scipy/special/xsf/cdflib.h +0 -100
  574. scipy/special/xsf/cephes/airy.h +0 -307
  575. scipy/special/xsf/cephes/besselpoly.h +0 -51
  576. scipy/special/xsf/cephes/beta.h +0 -257
  577. scipy/special/xsf/cephes/cbrt.h +0 -131
  578. scipy/special/xsf/cephes/chbevl.h +0 -85
  579. scipy/special/xsf/cephes/chdtr.h +0 -193
  580. scipy/special/xsf/cephes/const.h +0 -87
  581. scipy/special/xsf/cephes/ellie.h +0 -293
  582. scipy/special/xsf/cephes/ellik.h +0 -251
  583. scipy/special/xsf/cephes/ellpe.h +0 -107
  584. scipy/special/xsf/cephes/ellpk.h +0 -117
  585. scipy/special/xsf/cephes/expn.h +0 -260
  586. scipy/special/xsf/cephes/gamma.h +0 -398
  587. scipy/special/xsf/cephes/hyp2f1.h +0 -596
  588. scipy/special/xsf/cephes/hyperg.h +0 -361
  589. scipy/special/xsf/cephes/i0.h +0 -149
  590. scipy/special/xsf/cephes/i1.h +0 -158
  591. scipy/special/xsf/cephes/igam.h +0 -421
  592. scipy/special/xsf/cephes/igam_asymp_coeff.h +0 -195
  593. scipy/special/xsf/cephes/igami.h +0 -313
  594. scipy/special/xsf/cephes/j0.h +0 -225
  595. scipy/special/xsf/cephes/j1.h +0 -198
  596. scipy/special/xsf/cephes/jv.h +0 -715
  597. scipy/special/xsf/cephes/k0.h +0 -164
  598. scipy/special/xsf/cephes/k1.h +0 -163
  599. scipy/special/xsf/cephes/kn.h +0 -243
  600. scipy/special/xsf/cephes/lanczos.h +0 -112
  601. scipy/special/xsf/cephes/ndtr.h +0 -275
  602. scipy/special/xsf/cephes/poch.h +0 -85
  603. scipy/special/xsf/cephes/polevl.h +0 -167
  604. scipy/special/xsf/cephes/psi.h +0 -194
  605. scipy/special/xsf/cephes/rgamma.h +0 -111
  606. scipy/special/xsf/cephes/scipy_iv.h +0 -811
  607. scipy/special/xsf/cephes/shichi.h +0 -248
  608. scipy/special/xsf/cephes/sici.h +0 -224
  609. scipy/special/xsf/cephes/sindg.h +0 -221
  610. scipy/special/xsf/cephes/tandg.h +0 -139
  611. scipy/special/xsf/cephes/trig.h +0 -58
  612. scipy/special/xsf/cephes/unity.h +0 -186
  613. scipy/special/xsf/cephes/zeta.h +0 -172
  614. scipy/special/xsf/config.h +0 -304
  615. scipy/special/xsf/digamma.h +0 -205
  616. scipy/special/xsf/error.h +0 -57
  617. scipy/special/xsf/evalpoly.h +0 -47
  618. scipy/special/xsf/expint.h +0 -266
  619. scipy/special/xsf/hyp2f1.h +0 -694
  620. scipy/special/xsf/iv_ratio.h +0 -173
  621. scipy/special/xsf/lambertw.h +0 -150
  622. scipy/special/xsf/loggamma.h +0 -163
  623. scipy/special/xsf/sici.h +0 -200
  624. scipy/special/xsf/tools.h +0 -427
  625. scipy/special/xsf/trig.h +0 -164
  626. scipy/special/xsf/wright_bessel.h +0 -843
  627. scipy/special/xsf/zlog1.h +0 -35
  628. scipy/stats/_mvn.cpython-313-darwin.so +0 -0
  629. scipy-1.15.3.dist-info/WHEEL +0 -4
@@ -26,6 +26,14 @@ from ._filter_design import cheby1, _validate_sos, zpk2sos
26
26
  from ._fir_filter_design import firwin
27
27
  from ._sosfilt import _sosfilt
28
28
 
29
+ from scipy._lib._array_api import (
30
+ array_namespace, is_torch, is_numpy, xp_copy, xp_size, xp_default_dtype
31
+
32
+ )
33
+ from scipy._lib.array_api_compat import is_array_api_obj
34
+ import scipy._lib.array_api_compat.numpy as np_compat
35
+ import scipy._lib.array_api_extra as xpx
36
+
29
37
 
30
38
  __all__ = ['correlate', 'correlation_lags', 'correlate2d',
31
39
  'convolve', 'convolve2d', 'fftconvolve', 'oaconvolve',
@@ -168,24 +176,23 @@ def correlate(in1, in2, mode='full', method='auto'):
168
176
 
169
177
  z[...,k,...] = sum[..., i_l, ...] x[..., i_l,...] * conj(y[..., i_l - k,...])
170
178
 
171
- This way, if x and y are 1-D arrays and ``z = correlate(x, y, 'full')``
179
+ This way, if ``x`` and ``y`` are 1-D arrays and ``z = correlate(x, y, 'full')``
172
180
  then
173
181
 
174
182
  .. math::
175
183
 
176
- z[k] = (x * y)(k - N + 1)
177
- = \sum_{l=0}^{||x||-1}x_l y_{l-k+N-1}^{*}
178
-
179
- for :math:`k = 0, 1, ..., ||x|| + ||y|| - 2`
180
-
181
- where :math:`||x||` is the length of ``x``, :math:`N = \max(||x||,||y||)`,
182
- and :math:`y_m` is 0 when m is outside the range of y.
184
+ z[k] = \sum_{l=0}^{N-1} x_l \, y_{l-k}^{*}
183
185
 
186
+ for :math:`k = -(M-1), \dots, (N-1)`, where :math:`N` is the length of ``x``,
187
+ :math:`M` is the length of ``y``, and :math:`y_m = 0` when :math:`m` is outside the
188
+ valid range :math:`[0, M-1]`. The size of :math:`z` is :math:`N + M - 1` and
189
+ :math:`y^*` denotes the complex conjugate of :math:`y`.
190
+
184
191
  ``method='fft'`` only works for numerical arrays as it relies on
185
192
  `fftconvolve`. In certain cases (i.e., arrays of objects or when
186
193
  rounding integers can lose precision), ``method='direct'`` is always used.
187
194
 
188
- When using "same" mode with even-length inputs, the outputs of `correlate`
195
+ When using ``mode='same'`` with even-length inputs, the outputs of `correlate`
189
196
  and `correlate2d` differ: There is a 1-index offset between them.
190
197
 
191
198
  Examples
@@ -243,13 +250,24 @@ def correlate(in1, in2, mode='full', method='auto'):
243
250
  >>> plt.show()
244
251
 
245
252
  """
246
- in1 = np.asarray(in1)
247
- in2 = np.asarray(in2)
248
- _reject_objects(in1, 'correlate')
249
- _reject_objects(in2, 'correlate')
253
+ try:
254
+ xp = array_namespace(in1, in2)
255
+ except TypeError:
256
+ # either in1 or in2 are object arrays
257
+ xp = np_compat
258
+
259
+ if is_numpy(xp):
260
+ _reject_objects(in1, 'correlate')
261
+ _reject_objects(in2, 'correlate')
262
+
263
+ in1 = xp.asarray(in1)
264
+ in2 = xp.asarray(in2)
250
265
 
251
266
  if in1.ndim == in2.ndim == 0:
252
- return in1 * in2.conj()
267
+ in2_conj = (xp.conj(in2)
268
+ if xp.isdtype(in2.dtype, 'complex floating')
269
+ else in2)
270
+ return in1 * in2_conj
253
271
  elif in1.ndim != in2.ndim:
254
272
  raise ValueError("in1 and in2 should have the same dimensionality")
255
273
 
@@ -262,47 +280,56 @@ def correlate(in1, in2, mode='full', method='auto'):
262
280
 
263
281
  # this either calls fftconvolve or this function with method=='direct'
264
282
  if method in ('fft', 'auto'):
265
- return convolve(in1, _reverse_and_conj(in2), mode, method)
283
+ return convolve(in1, _reverse_and_conj(in2, xp), mode, method)
266
284
 
267
285
  elif method == 'direct':
268
286
  # fastpath to faster numpy.correlate for 1d inputs when possible
269
- if _np_conv_ok(in1, in2, mode):
270
- return np.correlate(in1, in2, mode)
287
+ if _np_conv_ok(in1, in2, mode, xp):
288
+ a_in1 = np.asarray(in1)
289
+ a_in2 = np.asarray(in2)
290
+ out = np.correlate(a_in1, a_in2, mode)
291
+ return xp.asarray(out)
271
292
 
272
293
  # _correlateND is far slower when in2.size > in1.size, so swap them
273
294
  # and then undo the effect afterward if mode == 'full'. Also, it fails
274
295
  # with 'valid' mode if in2 is larger than in1, so swap those, too.
275
296
  # Don't swap inputs for 'same' mode, since shape of in1 matters.
276
- swapped_inputs = ((mode == 'full') and (in2.size > in1.size) or
297
+ swapped_inputs = ((mode == 'full') and (xp_size(in2) > xp_size(in1)) or
277
298
  _inputs_swap_needed(mode, in1.shape, in2.shape))
278
299
 
279
300
  if swapped_inputs:
280
301
  in1, in2 = in2, in1
281
302
 
303
+ # convert to numpy & back for _sigtools._correlateND
304
+ a_in1 = np.asarray(in1)
305
+ a_in2 = np.asarray(in2)
306
+
282
307
  if mode == 'valid':
283
308
  ps = [i - j + 1 for i, j in zip(in1.shape, in2.shape)]
284
- out = np.empty(ps, in1.dtype)
309
+ out = np.empty(ps, a_in1.dtype)
285
310
 
286
- z = _sigtools._correlateND(in1, in2, out, val)
311
+ z = _sigtools._correlateND(a_in1, a_in2, out, val)
287
312
 
288
313
  else:
289
314
  ps = [i + j - 1 for i, j in zip(in1.shape, in2.shape)]
290
315
 
291
316
  # zero pad input
292
- in1zpadded = np.zeros(ps, in1.dtype)
317
+ in1zpadded = np.zeros(ps, a_in1.dtype)
293
318
  sc = tuple(slice(0, i) for i in in1.shape)
294
- in1zpadded[sc] = in1.copy()
319
+ in1zpadded[sc] = a_in1.copy()
295
320
 
296
321
  if mode == 'full':
297
- out = np.empty(ps, in1.dtype)
322
+ out = np.empty(ps, a_in1.dtype)
298
323
  elif mode == 'same':
299
- out = np.empty(in1.shape, in1.dtype)
324
+ out = np.empty(in1.shape, a_in1.dtype)
325
+
326
+ z = _sigtools._correlateND(in1zpadded, a_in2, out, val)
300
327
 
301
- z = _sigtools._correlateND(in1zpadded, in2, out, val)
328
+ z = xp.asarray(z)
302
329
 
303
330
  if swapped_inputs:
304
331
  # Reverse and conjugate to undo the effect of swapping inputs
305
- z = _reverse_and_conj(z)
332
+ z = _reverse_and_conj(z, xp)
306
333
 
307
334
  return z
308
335
 
@@ -481,7 +508,7 @@ def _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False):
481
508
  return in1, in2, axes
482
509
 
483
510
 
484
- def _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=False):
511
+ def _freq_domain_conv(xp, in1, in2, axes, shape, calc_fast_len=False):
485
512
  """Convolve two arrays in the frequency domain.
486
513
 
487
514
  This function implements only base the FFT-related operations.
@@ -515,7 +542,8 @@ def _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=False):
515
542
  if not len(axes):
516
543
  return in1 * in2
517
544
 
518
- complex_result = (in1.dtype.kind == 'c' or in2.dtype.kind == 'c')
545
+ complex_result = (xp.isdtype(in1.dtype, 'complex floating') or
546
+ xp.isdtype(in2.dtype, 'complex floating'))
519
547
 
520
548
  if calc_fast_len:
521
549
  # Speed up FFT by padding to optimal size.
@@ -529,6 +557,11 @@ def _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=False):
529
557
  else:
530
558
  fft, ifft = sp_fft.fftn, sp_fft.ifftn
531
559
 
560
+ if xp.isdtype(in1.dtype, 'integral'):
561
+ in1 = xp.astype(in1, xp.float64)
562
+ if xp.isdtype(in2.dtype, 'integral'):
563
+ in2 = xp.astype(in2, xp.float64)
564
+
532
565
  sp1 = fft(in1, fshape, axes=axes)
533
566
  sp2 = fft(in2, fshape, axes=axes)
534
567
 
@@ -541,7 +574,7 @@ def _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=False):
541
574
  return ret
542
575
 
543
576
 
544
- def _apply_conv_mode(ret, s1, s2, mode, axes):
577
+ def _apply_conv_mode(ret, s1, s2, mode, axes, xp):
545
578
  """Calculate the convolution result shape based on the `mode` argument.
546
579
 
547
580
  Returns the result sliced to the correct size for the given mode.
@@ -567,13 +600,13 @@ def _apply_conv_mode(ret, s1, s2, mode, axes):
567
600
 
568
601
  """
569
602
  if mode == "full":
570
- return ret.copy()
603
+ return xp_copy(ret, xp=xp)
571
604
  elif mode == "same":
572
- return _centered(ret, s1).copy()
605
+ return xp_copy(_centered(ret, s1), xp=xp)
573
606
  elif mode == "valid":
574
607
  shape_valid = [ret.shape[a] if a not in axes else s1[a] - s2[a] + 1
575
608
  for a in range(ret.ndim)]
576
- return _centered(ret, shape_valid).copy()
609
+ return xp_copy(_centered(ret, shape_valid), xp=xp)
577
610
  else:
578
611
  raise ValueError("acceptable mode flags are 'valid',"
579
612
  " 'same', or 'full'")
@@ -673,15 +706,17 @@ def fftconvolve(in1, in2, mode="full", axes=None):
673
706
  >>> fig.show()
674
707
 
675
708
  """
676
- in1 = np.asarray(in1)
677
- in2 = np.asarray(in2)
709
+ xp = array_namespace(in1, in2)
710
+
711
+ in1 = xp.asarray(in1)
712
+ in2 = xp.asarray(in2)
678
713
 
679
714
  if in1.ndim == in2.ndim == 0: # scalar inputs
680
715
  return in1 * in2
681
716
  elif in1.ndim != in2.ndim:
682
717
  raise ValueError("in1 and in2 should have the same dimensionality")
683
- elif in1.size == 0 or in2.size == 0: # empty arrays
684
- return np.array([])
718
+ elif xp_size(in1) == 0 or xp_size(in2) == 0: # empty arrays
719
+ return xp.asarray([])
685
720
 
686
721
  in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes,
687
722
  sorted_axes=False)
@@ -692,9 +727,9 @@ def fftconvolve(in1, in2, mode="full", axes=None):
692
727
  shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1
693
728
  for i in range(in1.ndim)]
694
729
 
695
- ret = _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=True)
730
+ ret = _freq_domain_conv(xp, in1, in2, axes, shape, calc_fast_len=True)
696
731
 
697
- return _apply_conv_mode(ret, s1, s2, mode, axes)
732
+ return _apply_conv_mode(ret, s1, s2, mode, axes, xp=xp)
698
733
 
699
734
 
700
735
  def _calc_oa_lens(s1, s2):
@@ -808,6 +843,35 @@ def _calc_oa_lens(s1, s2):
808
843
  return block_size, overlap, in1_step, in2_step
809
844
 
810
845
 
846
+ def _swapaxes(x, ax1, ax2, xp):
847
+ """np.swapaxes"""
848
+ shp = list(range(x.ndim))
849
+ shp[ax1], shp[ax2] = shp[ax2], shp[ax1]
850
+ return xp.permute_dims(x, shp)
851
+
852
+
853
+ # may want to look at moving _swapaxes and this to array-api-extra,
854
+ # cross-ref https://github.com/data-apis/array-api-extra/issues/97
855
+ def _split(x, indices_or_sections, axis, xp):
856
+ """A simplified version of np.split, with `indices` being an list.
857
+ """
858
+ # https://github.com/numpy/numpy/blob/v2.2.0/numpy/lib/_shape_base_impl.py#L743
859
+ Ntotal = x.shape[axis]
860
+
861
+ # handle array case.
862
+ Nsections = len(indices_or_sections) + 1
863
+ div_points = [0] + list(indices_or_sections) + [Ntotal]
864
+
865
+ sub_arys = []
866
+ sary = _swapaxes(x, axis, 0, xp=xp)
867
+ for i in range(Nsections):
868
+ st = div_points[i]
869
+ end = div_points[i + 1]
870
+ sub_arys.append(_swapaxes(sary[st:end, ...], axis, 0, xp=xp))
871
+
872
+ return sub_arys
873
+
874
+
811
875
  def oaconvolve(in1, in2, mode="full", axes=None):
812
876
  """Convolve two N-dimensional arrays using the overlap-add method.
813
877
 
@@ -888,15 +952,17 @@ def oaconvolve(in1, in2, mode="full", axes=None):
888
952
  >>> fig.show()
889
953
 
890
954
  """
891
- in1 = np.asarray(in1)
892
- in2 = np.asarray(in2)
955
+ xp = array_namespace(in1, in2)
956
+
957
+ in1 = xp.asarray(in1)
958
+ in2 = xp.asarray(in2)
893
959
 
894
960
  if in1.ndim == in2.ndim == 0: # scalar inputs
895
961
  return in1 * in2
896
962
  elif in1.ndim != in2.ndim:
897
963
  raise ValueError("in1 and in2 should have the same dimensionality")
898
964
  elif in1.size == 0 or in2.size == 0: # empty arrays
899
- return np.array([])
965
+ return xp.asarray([])
900
966
  elif in1.shape == in2.shape: # Equivalent to fftconvolve
901
967
  return fftconvolve(in1, in2, mode=mode, axes=axes)
902
968
 
@@ -908,7 +974,7 @@ def oaconvolve(in1, in2, mode="full", axes=None):
908
974
 
909
975
  if not axes:
910
976
  ret = in1 * in2
911
- return _apply_conv_mode(ret, s1, s2, mode, axes)
977
+ return _apply_conv_mode(ret, s1, s2, mode, axes, xp)
912
978
 
913
979
  # Calculate this now since in1 is changed later
914
980
  shape_final = [None if i not in axes else
@@ -966,10 +1032,10 @@ def oaconvolve(in1, in2, mode="full", axes=None):
966
1032
  # Pad the array to a size that can be reshaped to the desired shape
967
1033
  # if necessary.
968
1034
  if not all(curpad == (0, 0) for curpad in pad_size1):
969
- in1 = np.pad(in1, pad_size1, mode='constant', constant_values=0)
1035
+ in1 = xpx.pad(in1, pad_size1, mode='constant', constant_values=0, xp=xp)
970
1036
 
971
1037
  if not all(curpad == (0, 0) for curpad in pad_size2):
972
- in2 = np.pad(in2, pad_size2, mode='constant', constant_values=0)
1038
+ in2 = xpx.pad(in2, pad_size2, mode='constant', constant_values=0, xp=xp)
973
1039
 
974
1040
  # Reshape the overlap-add parts to input block sizes.
975
1041
  split_axes = [iax+i for i, iax in enumerate(axes)]
@@ -983,12 +1049,12 @@ def oaconvolve(in1, in2, mode="full", axes=None):
983
1049
  reshape_size1.insert(iax, nsteps1[i])
984
1050
  reshape_size2.insert(iax, nsteps2[i])
985
1051
 
986
- in1 = in1.reshape(*reshape_size1)
987
- in2 = in2.reshape(*reshape_size2)
1052
+ in1 = xp.reshape(in1, tuple(reshape_size1))
1053
+ in2 = xp.reshape(in2, tuple(reshape_size2))
988
1054
 
989
1055
  # Do the convolution.
990
1056
  fft_shape = [block_size[i] for i in axes]
991
- ret = _freq_domain_conv(in1, in2, fft_axes, fft_shape, calc_fast_len=False)
1057
+ ret = _freq_domain_conv(xp, in1, in2, fft_axes, fft_shape, calc_fast_len=False)
992
1058
 
993
1059
  # Do the overlap-add.
994
1060
  for ax, ax_fft, ax_split in zip(axes, fft_axes, split_axes):
@@ -996,27 +1062,27 @@ def oaconvolve(in1, in2, mode="full", axes=None):
996
1062
  if overlap is None:
997
1063
  continue
998
1064
 
999
- ret, overpart = np.split(ret, [-overlap], ax_fft)
1000
- overpart = np.split(overpart, [-1], ax_split)[0]
1065
+ ret, overpart = _split(ret, [-overlap], ax_fft, xp=xp)
1066
+ overpart = _split(overpart, [-1], ax_split, xp=xp)[0]
1001
1067
 
1002
- ret_overpart = np.split(ret, [overlap], ax_fft)[0]
1003
- ret_overpart = np.split(ret_overpart, [1], ax_split)[1]
1068
+ ret_overpart = _split(ret, [overlap], ax_fft, xp=xp)[0]
1069
+ ret_overpart = _split(ret_overpart, [1], ax_split, xp)[1]
1004
1070
  ret_overpart += overpart
1005
1071
 
1006
1072
  # Reshape back to the correct dimensionality.
1007
1073
  shape_ret = [ret.shape[i] if i not in fft_axes else
1008
1074
  ret.shape[i]*ret.shape[i-1]
1009
1075
  for i in range(ret.ndim) if i not in split_axes]
1010
- ret = ret.reshape(*shape_ret)
1076
+ ret = xp.reshape(ret, shape_ret)
1011
1077
 
1012
1078
  # Slice to the correct size.
1013
1079
  slice_final = tuple([slice(islice) for islice in shape_final])
1014
1080
  ret = ret[slice_final]
1015
1081
 
1016
- return _apply_conv_mode(ret, s1, s2, mode, axes)
1082
+ return _apply_conv_mode(ret, s1, s2, mode, axes, xp)
1017
1083
 
1018
1084
 
1019
- def _numeric_arrays(arrays, kinds='buifc'):
1085
+ def _numeric_arrays(arrays, kinds='buifc', xp=None):
1020
1086
  """
1021
1087
  See if a list of arrays are all numeric.
1022
1088
 
@@ -1029,7 +1095,12 @@ def _numeric_arrays(arrays, kinds='buifc'):
1029
1095
  the ndarrays are not in this string the function returns False and
1030
1096
  otherwise returns True.
1031
1097
  """
1032
- if isinstance(arrays, np.ndarray):
1098
+ if xp is None:
1099
+ xp = array_namespace(*arrays)
1100
+ if not is_numpy(xp):
1101
+ return True
1102
+
1103
+ if type(arrays) is np.ndarray:
1033
1104
  return arrays.dtype.kind in kinds
1034
1105
  for array_ in arrays:
1035
1106
  if array_.dtype.kind not in kinds:
@@ -1122,15 +1193,26 @@ def _fftconv_faster(x, h, mode):
1122
1193
  return O_fft * fft_ops < O_direct * direct_ops + O_offset
1123
1194
 
1124
1195
 
1125
- def _reverse_and_conj(x):
1196
+ def _reverse_and_conj(x, xp):
1126
1197
  """
1127
1198
  Reverse array `x` in all dimensions and perform the complex conjugate
1128
1199
  """
1129
- reverse = (slice(None, None, -1),) * x.ndim
1130
- return x[reverse].conj()
1200
+ if not is_torch(xp):
1201
+ reverse = (slice(None, None, -1),) * x.ndim
1202
+ x_rev = x[reverse]
1203
+ else:
1204
+ # NB: is a copy, not a view as torch does not allow negative indices
1205
+ # in slices, x-ref https://github.com/pytorch/pytorch/issues/59786
1206
+ x_rev = xp.flip(x)
1207
+
1208
+ # cf https://github.com/data-apis/array-api/issues/824
1209
+ if xp.isdtype(x.dtype, 'complex floating'):
1210
+ return xp.conj(x_rev)
1211
+ else:
1212
+ return x_rev
1131
1213
 
1132
1214
 
1133
- def _np_conv_ok(volume, kernel, mode):
1215
+ def _np_conv_ok(volume, kernel, mode, xp):
1134
1216
  """
1135
1217
  See if numpy supports convolution of `volume` and `kernel` (i.e. both are
1136
1218
  1D ndarrays and of the appropriate shape). NumPy's 'same' mode uses the
@@ -1142,7 +1224,7 @@ def _np_conv_ok(volume, kernel, mode):
1142
1224
  if mode in ('full', 'valid'):
1143
1225
  return True
1144
1226
  elif mode == 'same':
1145
- return volume.size >= kernel.size
1227
+ return xp_size(volume) >= xp_size(kernel)
1146
1228
  else:
1147
1229
  return False
1148
1230
 
@@ -1290,11 +1372,18 @@ def choose_conv_method(in1, in2, mode='full', measure=False):
1290
1372
  `convolve`.
1291
1373
 
1292
1374
  """
1293
- volume = np.asarray(in1)
1294
- kernel = np.asarray(in2)
1375
+ try:
1376
+ xp = array_namespace(in1, in2)
1377
+ except TypeError:
1378
+ # either in1 or in2 are object arrays
1379
+ xp = np_compat
1295
1380
 
1296
- _reject_objects(volume, 'choose_conv_method')
1297
- _reject_objects(kernel, 'choose_conv_method')
1381
+ if is_numpy(xp):
1382
+ _reject_objects(in1, 'choose_conv_method')
1383
+ _reject_objects(in2, 'choose_conv_method')
1384
+
1385
+ volume = xp.asarray(in1)
1386
+ kernel = xp.asarray(in2)
1298
1387
 
1299
1388
  if measure:
1300
1389
  times = {}
@@ -1308,16 +1397,16 @@ def choose_conv_method(in1, in2, mode='full', measure=False):
1308
1397
  # for integer input,
1309
1398
  # catch when more precision required than float provides (representing an
1310
1399
  # integer as float can lose precision in fftconvolve if larger than 2**52)
1311
- if any([_numeric_arrays([x], kinds='ui') for x in [volume, kernel]]):
1312
- max_value = int(np.abs(volume).max()) * int(np.abs(kernel).max())
1313
- max_value *= int(min(volume.size, kernel.size))
1400
+ if any([_numeric_arrays([x], kinds='ui', xp=xp) for x in [volume, kernel]]):
1401
+ max_value = int(xp.max(xp.abs(volume))) * int(xp.max(xp.abs(kernel)))
1402
+ max_value *= int(min(xp_size(volume), xp_size(kernel)))
1314
1403
  if max_value > 2**np.finfo('float').nmant - 1:
1315
1404
  return 'direct'
1316
1405
 
1317
- if _numeric_arrays([volume, kernel], kinds='b'):
1406
+ if _numeric_arrays([volume, kernel], kinds='b', xp=xp):
1318
1407
  return 'direct'
1319
1408
 
1320
- if _numeric_arrays([volume, kernel]):
1409
+ if _numeric_arrays([volume, kernel], xp=xp):
1321
1410
  if _fftconv_faster(volume, kernel, mode):
1322
1411
  return 'fft'
1323
1412
 
@@ -1422,11 +1511,18 @@ def convolve(in1, in2, mode='full', method='auto'):
1422
1511
  >>> fig.show()
1423
1512
 
1424
1513
  """
1425
- volume = np.asarray(in1)
1426
- kernel = np.asarray(in2)
1514
+ try:
1515
+ xp = array_namespace(in1, in2)
1516
+ except TypeError:
1517
+ # either in1 or in2 are object arrays
1518
+ xp = np_compat
1519
+
1520
+ if is_numpy(xp):
1521
+ _reject_objects(in1, 'correlate')
1522
+ _reject_objects(in2, 'correlate')
1427
1523
 
1428
- _reject_objects(volume, 'correlate')
1429
- _reject_objects(kernel, 'correlate')
1524
+ volume = xp.asarray(in1)
1525
+ kernel = xp.asarray(in2)
1430
1526
 
1431
1527
  if volume.ndim == kernel.ndim == 0:
1432
1528
  return volume * kernel
@@ -1443,23 +1539,27 @@ def convolve(in1, in2, mode='full', method='auto'):
1443
1539
 
1444
1540
  if method == 'fft':
1445
1541
  out = fftconvolve(volume, kernel, mode=mode)
1446
- result_type = np.result_type(volume, kernel)
1447
- if result_type.kind in {'u', 'i'}:
1448
- out = np.around(out)
1542
+ result_type = xp.result_type(volume, kernel)
1543
+ if xp.isdtype(result_type, 'integral'):
1544
+ out = xp.round(out)
1449
1545
 
1450
- if np.isnan(out.flat[0]) or np.isinf(out.flat[0]):
1546
+ if xp.isnan(xp.reshape(out, (-1,))[0]) or xp.isinf(xp.reshape(out, (-1,))[0]):
1451
1547
  warnings.warn("Use of fft convolution on input with NAN or inf"
1452
1548
  " results in NAN or inf output. Consider using"
1453
1549
  " method='direct' instead.",
1454
1550
  category=RuntimeWarning, stacklevel=2)
1455
1551
 
1456
- return out.astype(result_type)
1552
+ return xp.astype(out, result_type)
1457
1553
  elif method == 'direct':
1458
1554
  # fastpath to faster numpy.convolve for 1d inputs when possible
1459
- if _np_conv_ok(volume, kernel, mode):
1460
- return np.convolve(volume, kernel, mode)
1461
-
1462
- return correlate(volume, _reverse_and_conj(kernel), mode, 'direct')
1555
+ if _np_conv_ok(volume, kernel, mode, xp):
1556
+ # convert to numpy and back
1557
+ a_volume = np.asarray(volume)
1558
+ a_kernel = np.asarray(kernel)
1559
+ out = np.convolve(a_volume, a_kernel, mode)
1560
+ return xp.asarray(out)
1561
+
1562
+ return correlate(volume, _reverse_and_conj(kernel, xp), mode, 'direct')
1463
1563
  else:
1464
1564
  raise ValueError("Acceptable method flags are 'auto',"
1465
1565
  " 'direct', or 'fft'.")
@@ -1519,17 +1619,19 @@ def order_filter(a, domain, rank):
1519
1619
  [ 20, 21, 22, 23, 24]])
1520
1620
 
1521
1621
  """
1522
- domain = np.asarray(domain)
1622
+ xp = array_namespace(a, domain)
1623
+
1624
+ domain = xp.asarray(domain)
1523
1625
  for dimsize in domain.shape:
1524
1626
  if (dimsize % 2) != 1:
1525
1627
  raise ValueError("Each dimension of domain argument "
1526
1628
  "should have an odd number of elements.")
1527
1629
 
1528
- a = np.asarray(a)
1529
- if not (np.issubdtype(a.dtype, np.integer)
1530
- or a.dtype in [np.float32, np.float64]):
1630
+ a = xp.asarray(a)
1631
+ if not (
1632
+ xp.isdtype(a.dtype, "integral") or a.dtype in (xp.float32, xp.float64)
1633
+ ):
1531
1634
  raise ValueError(f"dtype={a.dtype} is not supported by order_filter")
1532
-
1533
1635
  result = ndimage.rank_filter(a, rank, footprint=domain, mode='constant')
1534
1636
  return result
1535
1637
 
@@ -1576,16 +1678,20 @@ def medfilt(volume, kernel_size=None):
1576
1678
  the specialised function `scipy.signal.medfilt2d` may be faster.
1577
1679
 
1578
1680
  """
1579
- volume = np.atleast_1d(volume)
1580
- if not (np.issubdtype(volume.dtype, np.integer)
1581
- or volume.dtype in [np.float32, np.float64]):
1681
+ xp = array_namespace(volume)
1682
+ volume = xp.asarray(volume)
1683
+ if volume.ndim == 0:
1684
+ volume = xpx.atleast_nd(volume, ndim=1, xp=xp)
1685
+
1686
+ if not (xp.isdtype(volume.dtype, "integral") or
1687
+ volume.dtype in [xp.float32, xp.float64]):
1582
1688
  raise ValueError(f"dtype={volume.dtype} is not supported by medfilt")
1583
1689
 
1584
1690
  if kernel_size is None:
1585
1691
  kernel_size = [3] * volume.ndim
1586
- kernel_size = np.asarray(kernel_size)
1692
+ kernel_size = xp.asarray(kernel_size)
1587
1693
  if kernel_size.shape == ():
1588
- kernel_size = np.repeat(kernel_size.item(), volume.ndim)
1694
+ kernel_size = xp.repeat(kernel_size, volume.ndim)
1589
1695
 
1590
1696
  for k in range(volume.ndim):
1591
1697
  if (kernel_size[k] % 2) != 1:
@@ -1651,28 +1757,32 @@ def wiener(im, mysize=None, noise=None):
1651
1757
  >>> plt.show()
1652
1758
 
1653
1759
  """
1654
- im = np.asarray(im)
1760
+ xp = array_namespace(im)
1761
+
1762
+ im = xp.asarray(im)
1655
1763
  if mysize is None:
1656
1764
  mysize = [3] * im.ndim
1657
- mysize = np.asarray(mysize)
1658
- if mysize.shape == ():
1659
- mysize = np.repeat(mysize.item(), im.ndim)
1765
+ mysize_arr = xp.asarray(mysize)
1766
+ if mysize_arr.shape == ():
1767
+ mysize = [mysize] * im.ndim
1660
1768
 
1661
1769
  # Estimate the local mean
1662
1770
  size = math.prod(mysize)
1663
- lMean = correlate(im, np.ones(mysize), 'same') / size
1771
+ lMean = correlate(im, xp.ones(mysize), 'same')
1772
+ lsize = float(size)
1773
+ lMean = lMean / lsize
1664
1774
 
1665
1775
  # Estimate the local variance
1666
- lVar = (correlate(im ** 2, np.ones(mysize), 'same') / size - lMean ** 2)
1776
+ lVar = (correlate(im ** 2, xp.ones(mysize), 'same') / lsize - lMean ** 2)
1667
1777
 
1668
1778
  # Estimate the noise power if needed.
1669
1779
  if noise is None:
1670
- noise = np.mean(np.ravel(lVar), axis=0)
1780
+ noise = xp.mean(xp.reshape(lVar, (-1,)), axis=0)
1671
1781
 
1672
1782
  res = (im - lMean)
1673
1783
  res *= (1 - noise / lVar)
1674
1784
  res += lMean
1675
- out = np.where(lVar < noise, lMean, res)
1785
+ out = xp.where(lVar < noise, lMean, res)
1676
1786
 
1677
1787
  return out
1678
1788
 
@@ -1752,6 +1862,10 @@ def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
1752
1862
  >>> fig.show()
1753
1863
 
1754
1864
  """
1865
+ xp = array_namespace(in1, in2)
1866
+
1867
+ # NB: do work in NumPy, only convert the output
1868
+
1755
1869
  in1 = np.asarray(in1)
1756
1870
  in2 = np.asarray(in2)
1757
1871
 
@@ -1764,7 +1878,7 @@ def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
1764
1878
  val = _valfrommode(mode)
1765
1879
  bval = _bvalfromboundary(boundary)
1766
1880
  out = _sigtools._convolve2d(in1, in2, 1, val, bval, fillvalue)
1767
- return out
1881
+ return xp.asarray(out)
1768
1882
 
1769
1883
 
1770
1884
  def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
@@ -1849,6 +1963,7 @@ def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
1849
1963
  >>> fig.show()
1850
1964
 
1851
1965
  """
1966
+ xp = array_namespace(in1, in2)
1852
1967
  in1 = np.asarray(in1)
1853
1968
  in2 = np.asarray(in2)
1854
1969
 
@@ -1866,7 +1981,7 @@ def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
1866
1981
  if swapped_inputs:
1867
1982
  out = out[::-1, ::-1]
1868
1983
 
1869
- return out
1984
+ return xp.asarray(out)
1870
1985
 
1871
1986
 
1872
1987
  def medfilt2d(input, kernel_size=3):
@@ -1957,12 +2072,14 @@ def medfilt2d(input, kernel_size=3):
1957
2072
  # kernel numbers must be odd and not exceed original array dim
1958
2073
 
1959
2074
  """
2075
+ xp = array_namespace(input)
2076
+
1960
2077
  image = np.asarray(input)
1961
2078
 
1962
2079
  # checking dtype.type, rather than just dtype, is necessary for
1963
2080
  # excluding np.longdouble with MS Visual C.
1964
2081
  if image.dtype.type not in (np.ubyte, np.float32, np.float64):
1965
- return medfilt(image, kernel_size)
2082
+ return xp.asarray(medfilt(image, kernel_size))
1966
2083
 
1967
2084
  if kernel_size is None:
1968
2085
  kernel_size = [3] * 2
@@ -1974,7 +2091,8 @@ def medfilt2d(input, kernel_size=3):
1974
2091
  if (size % 2) != 1:
1975
2092
  raise ValueError("Each element of kernel_size should be odd.")
1976
2093
 
1977
- return _sigtools._medfilt2d(image, kernel_size)
2094
+ result_np = _sigtools._medfilt2d(image, kernel_size)
2095
+ return xp.asarray(result_np)
1978
2096
 
1979
2097
 
1980
2098
  def lfilter(b, a, x, axis=-1, zi=None):
@@ -2101,12 +2219,22 @@ def lfilter(b, a, x, axis=-1, zi=None):
2101
2219
  >>> plt.show()
2102
2220
 
2103
2221
  """
2222
+ try:
2223
+ xp = array_namespace(b, a, x, zi)
2224
+ except TypeError:
2225
+ # either in1 or in2 are object arrays
2226
+ xp = np_compat
2227
+
2228
+ if is_numpy(xp):
2229
+ _reject_objects(x, 'lfilter')
2230
+ _reject_objects(a, 'lfilter')
2231
+ _reject_objects(b, 'lfilter')
2232
+
2104
2233
  b = np.atleast_1d(b)
2105
2234
  a = np.atleast_1d(a)
2106
-
2107
- _reject_objects(x, 'lfilter')
2108
- _reject_objects(a, 'lfilter')
2109
- _reject_objects(b, 'lfilter')
2235
+ x = np.asarray(x)
2236
+ if zi is not None:
2237
+ zi = np.asarray(zi)
2110
2238
 
2111
2239
  if len(a) == 1:
2112
2240
  # This path only supports types fdgFDGO to mirror _linear_filter below.
@@ -2165,16 +2293,18 @@ def lfilter(b, a, x, axis=-1, zi=None):
2165
2293
  out = out_full[tuple(ind)]
2166
2294
 
2167
2295
  if zi is None:
2168
- return out
2296
+ return xp.asarray(out)
2169
2297
  else:
2170
2298
  ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None)
2171
2299
  zf = out_full[tuple(ind)]
2172
- return out, zf
2300
+ return xp.asarray(out), xp.asarray(zf)
2173
2301
  else:
2174
2302
  if zi is None:
2175
- return _sigtools._linear_filter(b, a, x, axis)
2303
+ result =_sigtools._linear_filter(b, a, x, axis)
2304
+ return xp.asarray(result)
2176
2305
  else:
2177
- return _sigtools._linear_filter(b, a, x, axis, zi)
2306
+ out, zf = _sigtools._linear_filter(b, a, x, axis, zi)
2307
+ return xp.asarray(out), xp.asarray(zf)
2178
2308
 
2179
2309
 
2180
2310
  def lfiltic(b, a, y, x=None):
@@ -2217,40 +2347,66 @@ def lfiltic(b, a, y, x=None):
2217
2347
  lfilter, lfilter_zi
2218
2348
 
2219
2349
  """
2220
- N = np.size(a) - 1
2221
- M = np.size(b) - 1
2350
+ try:
2351
+ xp = array_namespace(a, b, y, x)
2352
+ except TypeError:
2353
+ xp = np_compat
2354
+
2355
+ if is_numpy(xp):
2356
+ _reject_objects(a, 'lfiltic')
2357
+ _reject_objects(b, 'lfiltic')
2358
+ _reject_objects(y, 'lfiltic')
2359
+ if x is not None:
2360
+ _reject_objects(x, 'lfiltic')
2361
+
2362
+ a = xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp)
2363
+ b = xpx.atleast_nd(xp.asarray(b), ndim=1, xp=xp)
2364
+ if a.ndim > 1:
2365
+ raise ValueError('Filter coefficients `a` must be 1-D.')
2366
+ if b.ndim > 1:
2367
+ raise ValueError('Filter coefficients `b` must be 1-D.')
2368
+ N = a.shape[0] - 1
2369
+ M = b.shape[0] - 1
2222
2370
  K = max(M, N)
2223
- y = np.asarray(y)
2371
+ y = xp.asarray(y)
2372
+
2373
+ if N < 0:
2374
+ raise ValueError("There must be at least one `a` coefficient.")
2224
2375
 
2225
2376
  if x is None:
2226
- result_type = np.result_type(np.asarray(b), np.asarray(a), y)
2227
- if result_type.kind in 'bui':
2228
- result_type = np.float64
2229
- x = np.zeros(M, dtype=result_type)
2377
+ result_type = xp.result_type(b, a, y)
2378
+ if xp.isdtype(result_type, ('bool', 'integral')): #'bui':
2379
+ result_type = xp.float64
2380
+ x = xp.zeros(M, dtype=result_type)
2230
2381
  else:
2231
- x = np.asarray(x)
2382
+ x = xp.asarray(x)
2232
2383
 
2233
- result_type = np.result_type(np.asarray(b), np.asarray(a), y, x)
2234
- if result_type.kind in 'bui':
2235
- result_type = np.float64
2236
- x = x.astype(result_type)
2384
+ result_type = xp.result_type(b, a, y, x)
2385
+ if xp.isdtype(result_type, ('bool', 'integral')): #'bui':
2386
+ result_type = xp.float64
2387
+ x = xp.astype(x, result_type)
2237
2388
 
2238
- L = np.size(x)
2389
+ L = xp_size(x)
2239
2390
  if L < M:
2240
- x = np.r_[x, np.zeros(M - L)]
2391
+ x = xp.concat((x, xp.zeros(M - L)))
2241
2392
 
2242
- y = y.astype(result_type)
2243
- zi = np.zeros(K, result_type)
2393
+ y = xp.astype(y, result_type)
2394
+ zi = xp.zeros(K, dtype=result_type)
2244
2395
 
2245
- L = np.size(y)
2396
+ L = xp_size(y)
2246
2397
  if L < N:
2247
- y = np.r_[y, np.zeros(N - L)]
2398
+ y = xp.concat((y, xp.zeros(N - L)))
2248
2399
 
2249
2400
  for m in range(M):
2250
- zi[m] = np.sum(b[m + 1:] * x[:M - m], axis=0)
2401
+ zi[m] = xp.sum(b[m + 1:] * x[:M - m], axis=0)
2251
2402
 
2252
2403
  for m in range(N):
2253
- zi[m] -= np.sum(a[m + 1:] * y[:N - m], axis=0)
2404
+ zi[m] -= xp.sum(a[m + 1:] * y[:N - m], axis=0)
2405
+
2406
+ if a[0] != 1.:
2407
+ if a[0] == 0.:
2408
+ raise ValueError("First `a` filter coefficient must be non-zero.")
2409
+ zi /= a[0]
2254
2410
 
2255
2411
  return zi
2256
2412
 
@@ -2296,19 +2452,21 @@ def deconvolve(signal, divisor):
2296
2452
  array([ 0., 1., 0., 0., 1., 1., 0., 0.])
2297
2453
 
2298
2454
  """
2299
- num = np.atleast_1d(signal)
2300
- den = np.atleast_1d(divisor)
2455
+ xp = array_namespace(signal, divisor)
2456
+
2457
+ num = xpx.atleast_nd(xp.asarray(signal), ndim=1, xp=xp)
2458
+ den = xpx.atleast_nd(xp.asarray(divisor), ndim=1, xp=xp)
2301
2459
  if num.ndim > 1:
2302
2460
  raise ValueError("signal must be 1-D.")
2303
2461
  if den.ndim > 1:
2304
2462
  raise ValueError("divisor must be 1-D.")
2305
- N = len(num)
2306
- D = len(den)
2463
+ N = num.shape[0]
2464
+ D = den.shape[0]
2307
2465
  if D > N:
2308
2466
  quot = []
2309
2467
  rem = num
2310
2468
  else:
2311
- input = np.zeros(N - D + 1, float)
2469
+ input = xp.zeros(N - D + 1, dtype=xp.float64)
2312
2470
  input[0] = 1
2313
2471
  quot = lfilter(num, den, input)
2314
2472
  rem = num - convolve(den, quot, mode='full')
@@ -2408,16 +2566,19 @@ def hilbert(x, N=None, axis=-1):
2408
2566
  >>> plt.show()
2409
2567
 
2410
2568
  """
2411
- x = np.asarray(x)
2412
- if np.iscomplexobj(x):
2569
+ xp = array_namespace(x)
2570
+
2571
+ x = xp.asarray(x)
2572
+ if xp.isdtype(x.dtype, 'complex floating'):
2413
2573
  raise ValueError("x must be real.")
2574
+
2414
2575
  if N is None:
2415
2576
  N = x.shape[axis]
2416
2577
  if N <= 0:
2417
2578
  raise ValueError("N must be positive.")
2418
2579
 
2419
2580
  Xf = sp_fft.fft(x, N, axis=axis)
2420
- h = np.zeros(N, dtype=Xf.dtype)
2581
+ h = xp.zeros(N, dtype=Xf.dtype)
2421
2582
  if N % 2 == 0:
2422
2583
  h[0] = h[N // 2] = 1
2423
2584
  h[1:N // 2] = 2
@@ -2426,7 +2587,7 @@ def hilbert(x, N=None, axis=-1):
2426
2587
  h[1:(N + 1) // 2] = 2
2427
2588
 
2428
2589
  if x.ndim > 1:
2429
- ind = [np.newaxis] * x.ndim
2590
+ ind = [xp.newaxis] * x.ndim
2430
2591
  ind[axis] = slice(None)
2431
2592
  h = h[tuple(ind)]
2432
2593
  x = sp_fft.ifft(Xf * h, axis=axis)
@@ -2455,24 +2616,26 @@ def hilbert2(x, N=None):
2455
2616
  https://en.wikipedia.org/wiki/Analytic_signal
2456
2617
 
2457
2618
  """
2458
- x = np.atleast_2d(x)
2619
+ xp = array_namespace(x)
2620
+ x = xpx.atleast_nd(xp.asarray(x), ndim=2, xp=xp)
2459
2621
  if x.ndim > 2:
2460
2622
  raise ValueError("x must be 2-D.")
2461
- if np.iscomplexobj(x):
2623
+ if xp.isdtype(x.dtype, 'complex floating'):
2462
2624
  raise ValueError("x must be real.")
2625
+
2463
2626
  if N is None:
2464
2627
  N = x.shape
2465
2628
  elif isinstance(N, int):
2466
2629
  if N <= 0:
2467
2630
  raise ValueError("N must be positive.")
2468
2631
  N = (N, N)
2469
- elif len(N) != 2 or np.any(np.asarray(N) <= 0):
2632
+ elif len(N) != 2 or xp.any(xp.asarray(N) <= 0):
2470
2633
  raise ValueError("When given as a tuple, N must hold exactly "
2471
2634
  "two positive integers")
2472
2635
 
2473
2636
  Xf = sp_fft.fft2(x, N, axes=(0, 1))
2474
- h1 = np.zeros(N[0], dtype=Xf.dtype)
2475
- h2 = np.zeros(N[1], dtype=Xf.dtype)
2637
+ h1 = xp.zeros(N[0], dtype=Xf.dtype)
2638
+ h2 = xp.zeros(N[1], dtype=Xf.dtype)
2476
2639
  for h in (h1, h2):
2477
2640
  N1 = h.shape[0]
2478
2641
  if N1 % 2 == 0:
@@ -2482,19 +2645,19 @@ def hilbert2(x, N=None):
2482
2645
  h[0] = 1
2483
2646
  h[1:(N1 + 1) // 2] = 2
2484
2647
 
2485
- h = h1[:, np.newaxis] * h2[np.newaxis, :]
2648
+ h = h1[:, xp.newaxis] * h2[xp.newaxis, :]
2486
2649
  k = x.ndim
2487
2650
  while k > 2:
2488
- h = h[:, np.newaxis]
2651
+ h = h[:, xp.newaxis]
2489
2652
  k -= 1
2490
2653
  x = sp_fft.ifft2(Xf * h, axes=(0, 1))
2491
2654
  return x
2492
2655
 
2493
2656
 
2494
- def envelope(z: np.ndarray, bp_in: tuple[int | None, int | None] = (1, None), *,
2657
+ def envelope(z, bp_in: tuple[int | None, int | None] = (1, None), *,
2495
2658
  n_out: int | None = None, squared: bool = False,
2496
2659
  residual: Literal['lowpass', 'all', None] = 'lowpass',
2497
- axis: int = -1) -> np.ndarray:
2660
+ axis: int = -1):
2498
2661
  r"""Compute the envelope of a real- or complex-valued signal.
2499
2662
 
2500
2663
  Parameters
@@ -2709,6 +2872,7 @@ def envelope(z: np.ndarray, bp_in: tuple[int | None, int | None] = (1, None), *,
2709
2872
  >>> fg0.subplots_adjust(left=0.08, right=0.97, wspace=0.15)
2710
2873
  >>> plt.show()
2711
2874
  """
2875
+ xp = array_namespace(z)
2712
2876
  if not (-z.ndim <= axis < z.ndim):
2713
2877
  raise ValueError(f"Invalid parameter {axis=} for {z.shape=}!")
2714
2878
  if not (z.shape[axis] > 0):
@@ -2731,12 +2895,13 @@ def envelope(z: np.ndarray, bp_in: tuple[int | None, int | None] = (1, None), *,
2731
2895
  f"for n={z.shape[axis]=} and {bp_in=}!")
2732
2896
 
2733
2897
  # moving active axis to end allows to use `...` for indexing:
2734
- z = np.moveaxis(z, axis, -1)
2898
+ z = xp.moveaxis(z, axis, -1)
2735
2899
 
2736
- if np.iscomplexobj(z):
2900
+ if xp.isdtype(z.dtype, 'complex floating'):
2737
2901
  Z = sp_fft.fft(z)
2738
2902
  else: # avoid calculating negative frequency bins for real signals:
2739
- Z = np.zeros_like(z, dtype=sp_fft.rfft(z.flat[:1]).dtype)
2903
+ dt = sp_fft.rfft(z[..., :1]).dtype
2904
+ Z = xp.zeros_like(z, dtype=dt)
2740
2905
  Z[..., :n//2 + 1] = sp_fft.rfft(z)
2741
2906
  if bp.start > 0: # make signal analytic within bp_in band:
2742
2907
  Z[..., bp] *= 2
@@ -2748,8 +2913,8 @@ def envelope(z: np.ndarray, bp_in: tuple[int | None, int | None] = (1, None), *,
2748
2913
  bp_shift = slice(bp.start + n//2, bp.stop + n//2)
2749
2914
  z_bb = sp_fft.ifft(sp_fft.fftshift(Z, axes=-1)[..., bp_shift], n=n_out) * fak
2750
2915
 
2751
- z_env = np.abs(z_bb) if not squared else z_bb.real ** 2 + z_bb.imag ** 2
2752
- z_env = np.moveaxis(z_env, -1, axis)
2916
+ z_env = xp.abs(z_bb) if not squared else xp.real(z_bb) ** 2 + xp.imag(z_bb) ** 2
2917
+ z_env = xp.moveaxis(z_env, -1, axis)
2753
2918
 
2754
2919
  # Calculate the residual from the input bandpass filter:
2755
2920
  if residual is None:
@@ -2764,9 +2929,14 @@ def envelope(z: np.ndarray, bp_in: tuple[int | None, int | None] = (1, None), *,
2764
2929
  else:
2765
2930
  Z[..., bp.start:], Z[..., 0:(n + 1) // 2] = 0, 0
2766
2931
 
2767
- z_res = fak * (sp_fft.ifft(Z, n=n_out) if np.iscomplexobj(z) else
2768
- sp_fft.irfft(Z, n=n_out))
2769
- return np.stack((z_env, np.moveaxis(z_res, -1, axis)), axis=0)
2932
+ if xp.isdtype(z.dtype, 'complex floating'): # resample accounts for unpaired bins:
2933
+ z_res = resample(Z, n_out, axis=-1, domain='freq') # ifft() with corrections
2934
+ else: # account for unpaired bin at m//2 before doing irfft():
2935
+ if n_out != n and (m := min(n, n_out)) % 2 == 0:
2936
+ Z[..., m//2] *= 2 if n_out < n else 0.5
2937
+ z_res = fak * sp_fft.irfft(Z, n=n_out)
2938
+ return xp.stack((z_env, xp.moveaxis(z_res, -1, axis)), axis=0)
2939
+
2770
2940
 
2771
2941
  def _cmplx_sort(p):
2772
2942
  """Sort roots based on magnitude.
@@ -3338,210 +3508,301 @@ def invresz(r, p, k, tol=1e-3, rtype='avg'):
3338
3508
 
3339
3509
 
3340
3510
  def resample(x, num, t=None, axis=0, window=None, domain='time'):
3341
- """
3342
- Resample `x` to `num` samples using Fourier method along the given axis.
3511
+ r"""Resample `x` to `num` samples using the Fourier method along the given `axis`.
3343
3512
 
3344
- The resampled signal starts at the same value as `x` but is sampled
3345
- with a spacing of ``len(x) / num * (spacing of x)``. Because a
3346
- Fourier method is used, the signal is assumed to be periodic.
3513
+ The resampling is performed by shortening or zero-padding the FFT of `x`. This has
3514
+ the advantages of providing an ideal antialiasing filter and allowing arbitrary
3515
+ up- or down-sampling ratios. The main drawback is the requirement of assuming `x`
3516
+ to be a periodic signal.
3347
3517
 
3348
3518
  Parameters
3349
3519
  ----------
3350
3520
  x : array_like
3351
- The data to be resampled.
3521
+ The input signal made up of equidistant samples. If `x` is a multidimensional
3522
+ array, the parameter `axis` specifies the time/frequency axis. It is assumed
3523
+ here that ``n_x = x.shape[axis]`` specifies the number of samples and ``T`` the
3524
+ sampling interval.
3352
3525
  num : int
3353
- The number of samples in the resampled signal.
3526
+ The number of samples of the resampled output signal. It may be larger or
3527
+ smaller than ``n_x``.
3354
3528
  t : array_like, optional
3355
- If `t` is given, it is assumed to be the equally spaced sample
3356
- positions associated with the signal data in `x`.
3529
+ If `t` is not ``None``, then the timestamps of the resampled signal are also
3530
+ returned. `t` must contain at least the first two timestamps of the input
3531
+ signal `x` (all others are ignored). The timestamps of the output signal are
3532
+ determined by ``t[0] + T * n_x / num * np.arange(num)`` with
3533
+ ``T = t[1] - t[0]``. Default is ``None``.
3357
3534
  axis : int, optional
3358
- The axis of `x` that is resampled. Default is 0.
3535
+ The time/frequency axis of `x` along which the resampling take place.
3536
+ The Default is 0.
3359
3537
  window : array_like, callable, string, float, or tuple, optional
3360
- Specifies the window applied to the signal in the Fourier
3361
- domain. See below for details.
3362
- domain : string, optional
3363
- A string indicating the domain of the input `x`:
3364
- ``time`` Consider the input `x` as time-domain (Default),
3365
- ``freq`` Consider the input `x` as frequency-domain.
3538
+ If not ``None``, it specifies a filter in the Fourier domain, which is applied
3539
+ before resampling. I.e., the FFT ``X`` of `x` is calculated by
3540
+ ``X = W * fft(x, axis=axis)``. ``W`` may be interpreted as a spectral windowing
3541
+ function ``W(f_X)`` which consumes the frequencies ``f_X = fftfreq(n_x, T)``.
3542
+
3543
+ If `window` is a 1d array of length ``n_x`` then ``W=window``.
3544
+ If `window` is a callable then ``W = window(f_X)``.
3545
+ Otherwise, `window` is passed to `~scipy.signal.get_window`, i.e.,
3546
+ ``W = fftshift(signal.get_window(window, n_x))``. Default is ``None``.
3547
+
3548
+ domain : 'time' | 'freq', optional
3549
+ If set to ``'time'`` (default) then an FFT is applied to `x`, otherwise
3550
+ (``'freq'``) it is asssmued that an FFT was already applied, i.e.,
3551
+ ``x = fft(x_t, axis=axis)`` with ``x_t`` being the input signal in the time
3552
+ domain.
3366
3553
 
3367
3554
  Returns
3368
3555
  -------
3369
- resampled_x or (resampled_x, resampled_t)
3370
- Either the resampled array, or, if `t` was given, a tuple
3371
- containing the resampled array and the corresponding resampled
3372
- positions.
3556
+ x_r : ndarray
3557
+ The resampled signal made up of `num` samples and sampling interval
3558
+ ``T * n_x / num``.
3559
+ t_r : ndarray, optional
3560
+ The `num` equidistant timestamps of `x_r`.
3561
+ This is only returned if paramater `t` is not ``None``.
3373
3562
 
3374
3563
  See Also
3375
3564
  --------
3376
- decimate : Downsample the signal after applying an FIR or IIR filter.
3377
- resample_poly : Resample using polyphase filtering and an FIR filter.
3565
+ decimate : Downsample a (periodic/non-periodic) signal after applying an FIR
3566
+ or IIR filter.
3567
+ resample_poly : Resample a (periodic/non-periodic) signal using polyphase filtering
3568
+ and an FIR filter.
3378
3569
 
3379
3570
  Notes
3380
3571
  -----
3381
- The argument `window` controls a Fourier-domain window that tapers
3382
- the Fourier spectrum before zero-padding to alleviate ringing in
3383
- the resampled values for sampled signals you didn't intend to be
3384
- interpreted as band-limited.
3385
-
3386
- If `window` is a function, then it is called with a vector of inputs
3387
- indicating the frequency bins (i.e. fftfreq(x.shape[axis]) ).
3388
-
3389
- If `window` is an array of the same length as `x.shape[axis]` it is
3390
- assumed to be the window to be applied directly in the Fourier
3391
- domain (with dc and low-frequency first).
3392
-
3393
- For any other type of `window`, the function `scipy.signal.get_window`
3394
- is called to generate the window.
3395
-
3396
- The first sample of the returned vector is the same as the first
3397
- sample of the input vector. The spacing between samples is changed
3398
- from ``dx`` to ``dx * len(x) / num``.
3399
-
3400
- If `t` is not None, then it is used solely to calculate the resampled
3401
- positions `resampled_t`
3402
-
3403
- As noted, `resample` uses FFT transformations, which can be very
3404
- slow if the number of input or output samples is large and prime;
3405
- see :func:`~scipy.fft.fft`. In such cases, it can be faster to first downsample
3406
- a signal of length ``n`` with :func:`~scipy.signal.resample_poly` by a factor of
3407
- ``n//num`` before using `resample`. Note that this approach changes the
3408
- characteristics of the antialiasing filter.
3572
+ This function uses the more efficient one-sided FFT, i.e. `~scipy.fft.rfft` /
3573
+ `~scipy.fft.irfft`, if `x` is real-valued and in the time domain.
3574
+ Else, the two-sided FFT, i.e., `~scipy.fft.fft` / `~scipy.fft.ifft`, is used
3575
+ (all FFT functions are taken from the `scipy.fft` module).
3576
+
3577
+ If a `window` is applied to a real-valued `x`, the one-sided spectral windowing
3578
+ function is determined by taking the average of the negative and the positive
3579
+ frequency component. This ensures that real-valued signals and complex signals with
3580
+ zero imaginary part are treated identically. I.e., passing `x` or passing
3581
+ ``x.astype(np.complex128)`` produce the same numeric result.
3582
+
3583
+ If the number of input or output samples are prime or have few prime factors, this
3584
+ function may be slow due to utilizing FFTs. Consult `~scipy.fft.prev_fast_len` and
3585
+ `~scipy.fft.next_fast_len` for determining efficient signals lengths.
3586
+ Alternatively, utilizing `resample_poly` to calculate an intermediate signal (as
3587
+ illustrated in the example below) can result in significant speed increases.
3588
+
3589
+ `resample` is intended to be used for periodic signals with equidistant sampling
3590
+ intervals. For non-periodic signals, `resample_poly` may be a better choice.
3591
+ Consult the `scipy.interpolate` module for methods of resampling signals with
3592
+ non-constant sampling intervals.
3409
3593
 
3410
3594
  Examples
3411
3595
  --------
3412
- Note that the end of the resampled data rises to meet the first
3413
- sample of the next cycle:
3596
+ The following example depicts a signal being up-sampled from 20 samples to 100
3597
+ samples. The ringing at the beginning of the up-sampled signal is due to
3598
+ interpreting the signal being periodic. The red square in the plot illustrates that
3599
+ periodictiy by showing the first sample of the next cycle of the signal.
3414
3600
 
3415
3601
  >>> import numpy as np
3416
- >>> from scipy import signal
3417
-
3418
- >>> x = np.linspace(0, 10, 20, endpoint=False)
3419
- >>> y = np.cos(-x**2/6.0)
3420
- >>> f = signal.resample(y, 100)
3421
- >>> xnew = np.linspace(0, 10, 100, endpoint=False)
3422
-
3423
3602
  >>> import matplotlib.pyplot as plt
3424
- >>> plt.plot(x, y, 'go-', xnew, f, '.-', 10, y[0], 'ro')
3425
- >>> plt.legend(['data', 'resampled'], loc='best')
3603
+ >>> from scipy.signal import resample
3604
+ ...
3605
+ >>> n0, n1 = 20, 100 # number of samples
3606
+ >>> t0 = np.linspace(0, 10, n0, endpoint=False) # input time stamps
3607
+ >>> x0 = np.cos(-t0**2/6) # input signal
3608
+ ...
3609
+ >>> x1 = resample(x0, n1) # resampled signal
3610
+ >>> t1 = np.linspace(0, 10, n1, endpoint=False) # timestamps of x1
3611
+ ...
3612
+ >>> fig0, ax0 = plt.subplots(1, 1, tight_layout=True)
3613
+ >>> ax0.set_title(f"Resampling $x(t)$ from {n0} samples to {n1} samples")
3614
+ >>> ax0.set(xlabel="Time $t$", ylabel="Amplitude $x(t)$")
3615
+ >>> ax0.plot(t1, x1, '.-', alpha=.5, label=f"Resampled")
3616
+ >>> ax0.plot(t0, x0, 'o-', alpha=.5, label="Original")
3617
+ >>> ax0.plot(10, x0[0], 'rs', alpha=.5, label="Next Cycle")
3618
+ >>> ax0.legend(loc='best')
3619
+ >>> ax0.grid(True)
3426
3620
  >>> plt.show()
3427
3621
 
3428
- Consider the following signal ``y`` where ``len(y)`` is a large prime number:
3429
-
3430
- >>> N = 55949
3431
- >>> freq = 100
3432
- >>> x = np.linspace(0, 1, N)
3433
- >>> y = np.cos(2 * np.pi * freq * x)
3434
-
3435
- Due to ``N`` being prime,
3622
+ The following example compares this function with a naive `~scipy.fft.rfft` /
3623
+ `~scipy.fft.irfft` combination: An input signal with a sampling interval of one
3624
+ second is upsampled by a factor of eight. The first figure depicts an odd number of
3625
+ input samples whereas the second figure an even number. The upper subplots show the
3626
+ signals over time: The input samples are marked by large green dots, the upsampled
3627
+ signals by a continuous and a dashed line. The lower subplots show the magnitude
3628
+ spectrum: The FFT values of the input are depicted by large green dots, which lie
3629
+ in the frequency interval [-0.5, 0.5] Hz, whereas the frequency interval of the
3630
+ upsampled signal is [-4, 4] Hz. The continuous green line depicts the upsampled
3631
+ spectrum without antialiasing filter, which is a periodic continuation of the input
3632
+ spectrum. The blue x's and orange dots depict the FFT values of the signal created
3633
+ by the naive approach as well as this function's result.
3436
3634
 
3437
- >>> num = 5000
3438
- >>> f = signal.resample(signal.resample_poly(y, 1, N // num), num)
3635
+ >>> import matplotlib.pyplot as plt
3636
+ >>> import numpy as np
3637
+ >>> from scipy.fft import fftshift, fftfreq, fft, rfft, irfft
3638
+ >>> from scipy.signal import resample, resample_poly
3639
+ ...
3640
+ >>> fac, T0, T1 = 8, 1, 1/8 # upsampling factor and sampling intervals
3641
+ >>> for n0 in (15, 16): # number of samples of input signal
3642
+ ... n1 = fac * n0 # number of samples of upsampled signal
3643
+ ... t0, t1 = T0 * np.arange(n0), T1 * np.arange(n1) # time stamps
3644
+ ... x0 = np.zeros(n0) # input signal has two non-zero sample values
3645
+ ... x0[n0//2], x0[n0//2+1] = n0 // 2, -(n0 // 2)
3646
+ ...
3647
+ ... x1n = irfft(rfft(x0), n=n1) * n1 / n0 # naive resampling
3648
+ ... x1r = resample(x0, n1) # resample signal
3649
+ ...
3650
+ ... # Determine magnitude spectrum:
3651
+ ... x0_up = np.zeros_like(x1r) # upsampling without antialiasing filter
3652
+ ... x0_up[::n1 // n0] = x0
3653
+ ... X0, X0_up = (fftshift(fft(x_)) / n0 for x_ in (x0, x0_up))
3654
+ ... XX1 = (fftshift(fft(x_)) / n1 for x_ in (x1n, x1r))
3655
+ ... f0, f1 = fftshift(fftfreq(n0, T0)), fftshift(fftfreq(n1, T1)) # frequencies
3656
+ ... df = f0[1] - f0[0] # frequency resolution
3657
+ ...
3658
+ ... fig, (ax0, ax1) = plt.subplots(2, 1, layout='constrained', figsize=(5, 4))
3659
+ ... ax0.set_title(rf"Upsampling ${fac}\times$ from {n0} to {n1} samples")
3660
+ ... ax0.set(xlabel="Time $t$ in seconds", ylabel="Amplitude $x(t)$",
3661
+ ... xlim=(0, n1*T1))
3662
+ ... ax0.step(t0, x0, 'C2o-', where='post', alpha=.3, linewidth=2,
3663
+ ... label="$x_0(t)$ / $X_0(f)$")
3664
+ ... for x_, l_ in zip((x1n, x1r), ('C0--', 'C1-')):
3665
+ ... ax0.plot(t1, x_, l_, alpha=.5, label=None)
3666
+ ... ax0.grid()
3667
+ ... ax1.set(xlabel=rf"Frequency $f$ in hertz ($\Delta f = {df*1e3:.1f}\,$mHz)",
3668
+ ... ylabel="Magnitude $|X(f)|$", xlim=(-0.7, 0.7))
3669
+ ... ax1.axvspan(0.5/T0, f1[-1], color='gray', alpha=.2)
3670
+ ... ax1.axvspan(f1[0], -0.5/T0, color='gray', alpha=.2)
3671
+ ... ax1.plot(f1, abs(X0_up), 'C2-', f0, abs(X0), 'C2o', alpha=.3, linewidth=2)
3672
+ ... for X_, n_, l_ in zip(XX1, ("naive", "resample"), ('C0x--', 'C1.-')):
3673
+ ... ax1.plot(f1, abs(X_), l_, alpha=.5, label=n_)
3674
+ ... ax1.grid()
3675
+ ... fig.legend(loc='outside lower center', ncols=4)
3676
+ >>> plt.show()
3439
3677
 
3440
- runs significantly faster than
3678
+ The first figure shows that upsampling an odd number of samples produces identical
3679
+ results. The second figure illustrates that the signal produced with the naive
3680
+ approach (dashed blue line) from an even number of samples does not touch all
3681
+ original samples. This deviation is due to `resample` correctly treating unpaired
3682
+ frequency bins. I.e., the input `x1` has a bin pair ±0.5 Hz, whereas the output has
3683
+ only one unpaired bin at -0.5 Hz, which demands rescaling of that bin pair.
3684
+ Generally, special treatment is required if ``n_x != num`` and ``min(n_x, num)`` is
3685
+ even. If the bin values at `±m` are zero, obviously, no special treatment is
3686
+ needed. Consult the source code of `resample` for details.
3687
+
3688
+ The final example shows how to utilize `resample_poly` to speed up the
3689
+ down-sampling: The input signal a non-zero value at :math:`t=0` and is downsampled
3690
+ from 19937 to 128 samples. Since 19937 is prime, the FFT is expected to be slow. To
3691
+ speed matters up, `resample_poly` is used to downsample first by a factor of ``n0
3692
+ // n1 = 155`` and then pass the result to `resample`. Two parameterization of
3693
+ `resample_poly` are used: Passing ``padtype='wrap'`` treats the input as being
3694
+ periodic wheras the default parametrization performs zero-padding. The upper
3695
+ subplot shows the resulting signals over time whereas the lower subplot depicts the
3696
+ resulting one-sided magnitude spectra.
3441
3697
 
3442
- >>> f = signal.resample(y, num)
3698
+ >>> import matplotlib.pyplot as plt
3699
+ >>> import numpy as np
3700
+ >>> from scipy.fft import rfftfreq, rfft
3701
+ >>> from scipy.signal import resample, resample_poly
3702
+ ...
3703
+ >>> n0 = 19937 # number of input samples - prime
3704
+ >>> n1 = 128 # number of output samples - fast FFT length
3705
+ >>> T0, T1 = 1/n0, 1/n1 # sampling intervals
3706
+ >>> t0, t1 = np.arange(n0)*T0, np.arange(n1)*T1 # time stamps
3707
+ ...
3708
+ >>> x0 = np.zeros(n0) # Input has one non-zero sample
3709
+ >>> x0[0] = n0
3710
+ >>>
3711
+ >>> x1r = resample(x0, n1) # slow due to n0 being prime
3712
+ >>> # This is faster:
3713
+ >>> x1p = resample(resample_poly(x0, 1, n0 // n1, padtype='wrap'), n1) # periodic
3714
+ >>> x2p = resample(resample_poly(x0, 1, n0 // n1), n1) # with zero-padding
3715
+ ...
3716
+ >>> X0 = rfft(x0) / n0
3717
+ >>> X1r, X1p, X2p = rfft(x1r) / n1, rfft(x1p) / n1, rfft(x2p) / n1
3718
+ >>> f0, f1 = rfftfreq(n0, T0), rfftfreq(n1, T1)
3719
+ ...
3720
+ >>> fig, (ax0, ax1) = plt.subplots(2, 1, layout='constrained', figsize=(5, 4))
3721
+ >>> ax0.set_title(f"Dowsampled Impulse response (from {n0} to {n1} samples)")
3722
+ >>> ax0.set(xlabel="Time $t$ in seconds", ylabel="Amplitude $x(t)$", xlim=(-T1, 1))
3723
+ >>> for x_ in (x1r, x1p, x2p):
3724
+ ... ax0.plot(t1, x_, alpha=.5)
3725
+ >>> ax0.grid()
3726
+ >>> ax1.set(xlabel=rf"Frequency $f$ in hertz ($\Delta f = {f1[1]}\,$Hz)",
3727
+ ... ylabel="Magnitude $|X(f)|$", xlim=(0, 0.55/T1))
3728
+ >>> ax1.axvspan(0.5/T1, f0[-1], color='gray', alpha=.2)
3729
+ >>> ax1.plot(f1, abs(X1r), 'C0.-', alpha=.5, label="resample")
3730
+ >>> ax1.plot(f1, abs(X1p), 'C1.-', alpha=.5, label="resample_poly(padtype='wrap')")
3731
+ >>> ax1.plot(f1, abs(X2p), 'C2x-', alpha=.5, label="resample_poly")
3732
+ >>> ax1.grid()
3733
+ >>> fig.legend(loc='outside lower center', ncols=2)
3734
+ >>> plt.show()
3735
+
3736
+ The plots show that the results of the "pure" `resample` and the usage of the
3737
+ default parameters of `resample_poly` agree well. The periodic padding of
3738
+ `resample_poly` (``padtype='wrap'``) on the other hand produces significant
3739
+ deviations. This is caused by the disconiuity at the beginning of the signal, for
3740
+ which the default filter of `resample_poly` is not suited well. This example
3741
+ illustrates that for some use cases, adpating the `resample_poly` parameters may
3742
+ be beneficial. `resample` has a big advantage in this regard: It uses the ideal
3743
+ antialiasing filter with the maximum bandwidth by default.
3744
+
3745
+ Note that the doubled spectral magnitude at the Nyqist frequency of 64 Hz is due the
3746
+ even number of ``n1=128`` output samples, which requires a special treatment as
3747
+ discussed in the previous example.
3443
3748
  """
3444
-
3445
3749
  if domain not in ('time', 'freq'):
3446
- raise ValueError("Acceptable domain flags are 'time' or"
3447
- f" 'freq', not domain={domain}")
3448
-
3449
- x = np.asarray(x)
3450
- Nx = x.shape[axis]
3451
-
3452
- # Check if we can use faster real FFT
3453
- real_input = np.isrealobj(x)
3454
-
3455
- if domain == 'time':
3456
- # Forward transform
3457
- if real_input:
3458
- X = sp_fft.rfft(x, axis=axis)
3459
- else: # Full complex FFT
3460
- X = sp_fft.fft(x, axis=axis)
3461
- else: # domain == 'freq'
3462
- X = x
3463
-
3464
- # Apply window to spectrum
3465
- if window is not None:
3466
- if callable(window):
3467
- W = window(sp_fft.fftfreq(Nx))
3468
- elif isinstance(window, np.ndarray):
3469
- if window.shape != (Nx,):
3470
- raise ValueError('window must have the same length as data')
3471
- W = window
3472
- else:
3473
- W = sp_fft.ifftshift(get_window(window, Nx))
3474
-
3475
- newshape_W = [1] * x.ndim
3476
- newshape_W[axis] = X.shape[axis]
3477
- if real_input:
3478
- # Fold the window back on itself to mimic complex behavior
3479
- W_real = W.copy()
3480
- W_real[1:] += W_real[-1:0:-1]
3481
- W_real[1:] *= 0.5
3482
- X *= W_real[:newshape_W[axis]].reshape(newshape_W)
3483
- else:
3484
- X *= W.reshape(newshape_W)
3485
-
3486
- # Copy each half of the original spectrum to the output spectrum, either
3487
- # truncating high frequencies (downsampling) or zero-padding them
3488
- # (upsampling)
3489
-
3490
- # Placeholder array for output spectrum
3491
- newshape = list(x.shape)
3492
- if real_input:
3493
- newshape[axis] = num // 2 + 1
3750
+ raise ValueError(f"Parameter {domain=} not in ('time', 'freq')!")
3751
+
3752
+ xp = array_namespace(x, t)
3753
+ x = xp.asarray(x)
3754
+ if x.ndim > 1: # moving active axis to end allows to use `...` in indexing:
3755
+ x = xp.moveaxis(x, axis, -1)
3756
+ n_x = x.shape[-1] # number of samples along the time/frequency axis
3757
+ s_fac = n_x / num # scaling factor represents sample interval dilatation
3758
+ m = min(num, n_x) # number of relevant frequency bins
3759
+ m2 = m // 2 + 1 # number of relevant frequency bins of a one-sided FFT
3760
+
3761
+ if window is None: # Determine spectral windowing function:
3762
+ W = None
3763
+ elif callable(window):
3764
+ W = window(sp_fft.fftfreq(n_x))
3765
+ elif hasattr(window, 'shape'): # must be an array object
3766
+ if window.shape != (n_x,):
3767
+ raise ValueError(f"{window.shape=} != ({n_x},), i.e., window length " +
3768
+ "is not equal to number of frequency bins!")
3769
+ W = xp.asarray(window, copy=True) # prevent modifying the function parameters
3494
3770
  else:
3495
- newshape[axis] = num
3496
- Y = np.zeros(newshape, X.dtype)
3497
-
3498
- # Copy positive frequency components (and Nyquist, if present)
3499
- N = min(num, Nx)
3500
- nyq = N // 2 + 1 # Slice index that includes Nyquist if present
3501
- sl = [slice(None)] * x.ndim
3502
- sl[axis] = slice(0, nyq)
3503
- Y[tuple(sl)] = X[tuple(sl)]
3504
- if not real_input:
3505
- # Copy negative frequency components
3506
- if N > 2: # (slice expression doesn't collapse to empty array)
3507
- sl[axis] = slice(nyq - N, None)
3508
- Y[tuple(sl)] = X[tuple(sl)]
3509
-
3510
- # Split/join Nyquist component(s) if present
3511
- # So far we have set Y[+N/2]=X[+N/2]
3512
- if N % 2 == 0:
3513
- if num < Nx: # downsampling
3514
- if real_input:
3515
- sl[axis] = slice(N//2, N//2 + 1)
3516
- Y[tuple(sl)] *= 2.
3517
- else:
3518
- # select the component of Y at frequency +N/2,
3519
- # add the component of X at -N/2
3520
- sl[axis] = slice(-N//2, -N//2 + 1)
3521
- Y[tuple(sl)] += X[tuple(sl)]
3522
- elif Nx < num: # upsampling
3523
- # select the component at frequency +N/2 and halve it
3524
- sl[axis] = slice(N//2, N//2 + 1)
3525
- Y[tuple(sl)] *= 0.5
3526
- if not real_input:
3527
- temp = Y[tuple(sl)]
3528
- # set the component at -N/2 equal to the component at +N/2
3529
- sl[axis] = slice(num-N//2, num-N//2 + 1)
3530
- Y[tuple(sl)] = temp
3531
-
3532
- # Inverse transform
3533
- if real_input:
3534
- y = sp_fft.irfft(Y, num, axis=axis)
3535
- else:
3536
- y = sp_fft.ifft(Y, axis=axis, overwrite_x=True)
3537
-
3538
- y *= (float(num) / float(Nx))
3539
-
3540
- if t is None:
3541
- return y
3542
- else:
3543
- new_t = np.arange(0, num) * (t[1] - t[0]) * Nx / float(num) + t[0]
3544
- return y, new_t
3771
+ W = sp_fft.fftshift(get_window(window, n_x, xp=xp))
3772
+ W = xp.astype(W, xp_default_dtype(xp)) # get_window always returns float64
3773
+
3774
+ if domain == 'time' and not xp.isdtype(x.dtype, 'complex floating'): # use rfft():
3775
+ X = sp_fft.rfft(x)
3776
+ if W is not None: # fold window, i.e., W1[l] = (W[l] + W[-l]) / 2 for l > 0
3777
+ n_X = X.shape[-1]
3778
+ W[1:n_X] += xp.flip(W[-n_X+1:]) #W[:-n_X:-1]
3779
+ W[1:n_X] /= 2
3780
+ X *= W[:n_X] # apply window
3781
+ X = X[..., :m2] # extract relevant data
3782
+ if m % 2 == 0 and num != n_x: # Account for unpaired bin at m//2:
3783
+ X[..., m//2] *= 2 if num < n_x else 0.5
3784
+ x_r = sp_fft.irfft(X / s_fac, n=num, overwrite_x=True)
3785
+ else: # use standard two-sided FFT:
3786
+ X = sp_fft.fft(x) if domain == 'time' else x
3787
+ if W is not None:
3788
+ X = X * W # writing X *= W could modify parameter x
3789
+ Y = xp.zeros(X.shape[:-1] + (num,), dtype=X.dtype)
3790
+ Y[..., :m2] = X[..., :m2] # copy part up to Nyquist frequency
3791
+ if m2 < m: # == m > 2
3792
+ Y[..., m2-m:] = X[..., m2-m:] # copy negative frequency part
3793
+ if m % 2 == 0: # Account for unpaired bin at m//2:
3794
+ if num < n_x: # down-sampling: unite bin pair into one unpaired bin
3795
+ Y[..., -m//2] += X[..., -m//2]
3796
+ elif n_x < num: # up-sampling: split unpaired bin into bin pair
3797
+ Y[..., m//2] /= 2
3798
+ Y[..., num-m//2] = Y[..., m//2]
3799
+ x_r = sp_fft.ifft(Y / s_fac, n=num, overwrite_x=True)
3800
+
3801
+ if x_r.ndim > 1: # moving active axis back to original position:
3802
+ x_r = xp.moveaxis(x_r, -1, axis)
3803
+ if t is not None:
3804
+ return x_r, t[0] + (t[1] - t[0]) * s_fac * xp.arange(num)
3805
+ return x_r
3545
3806
 
3546
3807
 
3547
3808
  def resample_poly(x, up, down, axis=0, window=('kaiser', 5.0),
@@ -3668,7 +3929,9 @@ def resample_poly(x, up, down, axis=0, window=('kaiser', 5.0),
3668
3929
  >>> plt.show()
3669
3930
 
3670
3931
  """
3671
- x = np.asarray(x)
3932
+ xp = array_namespace(x)
3933
+
3934
+ x = xp.asarray(x)
3672
3935
  if up != int(up):
3673
3936
  raise ValueError("up must be an integer")
3674
3937
  if down != int(down):
@@ -3687,31 +3950,29 @@ def resample_poly(x, up, down, axis=0, window=('kaiser', 5.0),
3687
3950
  up //= g_
3688
3951
  down //= g_
3689
3952
  if up == down == 1:
3690
- return x.copy()
3953
+ return xp.asarray(x, copy=True)
3691
3954
  n_in = x.shape[axis]
3692
3955
  n_out = n_in * up
3693
3956
  n_out = n_out // down + bool(n_out % down)
3694
3957
 
3695
- if isinstance(window, (list | np.ndarray)):
3696
- window = np.array(window) # use array to force a copy (we modify it)
3958
+ if isinstance(window, list) or is_array_api_obj(window):
3959
+ window = xp.asarray(window, copy=True) # force a copy (we modify `window`)
3697
3960
  if window.ndim > 1:
3698
3961
  raise ValueError('window must be 1-D')
3699
- half_len = (window.size - 1) // 2
3962
+ half_len = (xp_size(window) - 1) // 2
3700
3963
  h = window
3701
3964
  else:
3702
3965
  # Design a linear-phase low-pass FIR filter
3703
3966
  max_rate = max(up, down)
3704
3967
  f_c = 1. / max_rate # cutoff of FIR filter (rel. to Nyquist)
3705
3968
  half_len = 10 * max_rate # reasonable cutoff for sinc-like function
3706
- if np.issubdtype(x.dtype, np.complexfloating):
3707
- h = firwin(2 * half_len + 1, f_c,
3708
- window=window).astype(x.dtype) # match dtype of x
3709
- elif np.issubdtype(x.dtype, np.floating):
3710
- h = firwin(2 * half_len + 1, f_c,
3711
- window=window).astype(x.dtype) # match dtype of x
3969
+ if xp.isdtype(x.dtype, ("real floating", "complex floating")):
3970
+ h = firwin(2 * half_len + 1, f_c, window=window)
3971
+ h = xp.asarray(h, dtype=x.dtype) # match dtype of x
3712
3972
  else:
3713
- h = firwin(2 * half_len + 1, f_c,
3714
- window=window)
3973
+ h = firwin(2 * half_len + 1, f_c, window=window)
3974
+ h = xp.asarray(h)
3975
+
3715
3976
  h *= up
3716
3977
 
3717
3978
  # Zero-pad our filter to put the output samples at the center
@@ -3719,16 +3980,20 @@ def resample_poly(x, up, down, axis=0, window=('kaiser', 5.0),
3719
3980
  n_post_pad = 0
3720
3981
  n_pre_remove = (half_len + n_pre_pad) // down
3721
3982
  # We should rarely need to do this given our filter lengths...
3722
- while _output_len(len(h) + n_pre_pad + n_post_pad, n_in,
3983
+ while _output_len(h.shape[0] + n_pre_pad + n_post_pad, n_in,
3723
3984
  up, down) < n_out + n_pre_remove:
3724
3985
  n_post_pad += 1
3725
- h = np.concatenate((np.zeros(n_pre_pad, dtype=h.dtype), h,
3726
- np.zeros(n_post_pad, dtype=h.dtype)))
3986
+ h = xp.concat((xp.zeros(n_pre_pad, dtype=h.dtype), h,
3987
+ xp.zeros(n_post_pad, dtype=h.dtype)))
3727
3988
  n_pre_remove_end = n_pre_remove + n_out
3728
3989
 
3990
+ # XXX consider using stats.quantile, which is natively Array API compatible
3991
+ def _median(x, *args, **kwds):
3992
+ return xp.asarray(np.median(np.asarray(x), *args, **kwds))
3993
+
3729
3994
  # Remove background depending on the padtype option
3730
- funcs = {'mean': np.mean, 'median': np.median,
3731
- 'minimum': np.amin, 'maximum': np.amax}
3995
+ funcs = {'mean': xp.mean, 'median': _median,
3996
+ 'minimum': xp.min, 'maximum': xp.max}
3732
3997
  upfirdn_kwargs = {'mode': 'constant', 'cval': 0}
3733
3998
  if padtype in funcs:
3734
3999
  background_values = funcs[padtype](x, axis=axis, keepdims=True)
@@ -3759,6 +4024,15 @@ def resample_poly(x, up, down, axis=0, window=('kaiser', 5.0),
3759
4024
  return y_keep
3760
4025
 
3761
4026
 
4027
+ def _angle(z, xp):
4028
+ """np.angle replacement
4029
+ """
4030
+ # XXX: https://github.com/data-apis/array-api/issues/595
4031
+ zimag = xp.imag(z) if xp.isdtype(z.dtype, 'complex floating') else 0.
4032
+ a = xp.atan2(zimag, xp.real(z))
4033
+ return a
4034
+
4035
+
3762
4036
  def vectorstrength(events, period):
3763
4037
  '''
3764
4038
  Determine the vector strength of the events corresponding to the given
@@ -3806,8 +4080,13 @@ def vectorstrength(events, period):
3806
4080
  fixed. Biol Cybern. 2013 Aug;107(4):491-94.
3807
4081
  :doi:`10.1007/s00422-013-0560-8`.
3808
4082
  '''
3809
- events = np.asarray(events)
3810
- period = np.asarray(period)
4083
+ xp = array_namespace(events, period)
4084
+
4085
+ events = xp.asarray(events)
4086
+ period = xp.asarray(period)
4087
+ if xp.isdtype(period.dtype, 'integral'):
4088
+ period = xp.astype(period, xp.float64)
4089
+
3811
4090
  if events.ndim > 1:
3812
4091
  raise ValueError('events cannot have dimensions more than 1')
3813
4092
  if period.ndim > 1:
@@ -3816,19 +4095,20 @@ def vectorstrength(events, period):
3816
4095
  # we need to know later if period was originally a scalar
3817
4096
  scalarperiod = not period.ndim
3818
4097
 
3819
- events = np.atleast_2d(events)
3820
- period = np.atleast_2d(period)
3821
- if (period <= 0).any():
4098
+ events = xpx.atleast_nd(events, ndim=2, xp=xp)
4099
+ period = xpx.atleast_nd(period, ndim=2, xp=xp)
4100
+ if xp.any(period <= 0):
3822
4101
  raise ValueError('periods must be positive')
3823
4102
 
3824
4103
  # this converts the times to vectors
3825
- vectors = np.exp(np.dot(2j*np.pi/period.T, events))
4104
+ events_ = xp.astype(events, period.dtype)
4105
+ vectors = xp.exp(2j * (xp.pi / period.T @ events_))
3826
4106
 
3827
4107
  # the vector strength is just the magnitude of the mean of the vectors
3828
4108
  # the vector phase is the angle of the mean of the vectors
3829
- vectormean = np.mean(vectors, axis=1)
3830
- strength = abs(vectormean)
3831
- phase = np.angle(vectormean)
4109
+ vectormean = xp.mean(vectors, axis=1)
4110
+ strength = xp.abs(vectormean)
4111
+ phase = _angle(vectormean, xp)
3832
4112
 
3833
4113
  # if the original period was a scalar, return scalars
3834
4114
  if scalarperiod:
@@ -3919,16 +4199,24 @@ def detrend(data: np.ndarray, axis: int = -1,
3919
4199
  """
3920
4200
  if type not in ['linear', 'l', 'constant', 'c']:
3921
4201
  raise ValueError("Trend type must be 'linear' or 'constant'.")
4202
+
4203
+ # XXX simplify when data-apis/array-api-compat#147 is available
4204
+ if isinstance(bp, int):
4205
+ xp = array_namespace(data)
4206
+ else:
4207
+ xp = array_namespace(data, bp)
4208
+
3922
4209
  data = np.asarray(data)
3923
4210
  dtype = data.dtype.char
3924
4211
  if dtype not in 'dfDF':
3925
4212
  dtype = 'd'
3926
4213
  if type in ['constant', 'c']:
3927
4214
  ret = data - np.mean(data, axis, keepdims=True)
3928
- return ret
4215
+ return xp.asarray(ret)
3929
4216
  else:
3930
4217
  dshape = data.shape
3931
4218
  N = dshape[axis]
4219
+ bp = np.asarray(bp)
3932
4220
  bp = np.sort(np.unique(np.concatenate(np.atleast_1d(0, bp, N))))
3933
4221
  if np.any(bp > N):
3934
4222
  raise ValueError("Breakpoints must be less than length "
@@ -3961,7 +4249,7 @@ def detrend(data: np.ndarray, axis: int = -1,
3961
4249
  # Put data back in original shape.
3962
4250
  newdata = newdata.reshape(newdata_shape)
3963
4251
  ret = np.moveaxis(newdata, 0, axis)
3964
- return ret
4252
+ return xp.asarray(ret)
3965
4253
 
3966
4254
 
3967
4255
  def lfilter_zi(b, a):
@@ -4046,6 +4334,7 @@ def lfilter_zi(b, a):
4046
4334
  transient until the input drops from 0.5 to 0.0.
4047
4335
 
4048
4336
  """
4337
+ xp = array_namespace(b, a)
4049
4338
 
4050
4339
  # FIXME: Can this function be replaced with an appropriate
4051
4340
  # use of lfiltic? For example, when b,a = butter(N,Wn),
@@ -4055,16 +4344,16 @@ def lfilter_zi(b, a):
4055
4344
  # We could use scipy.signal.normalize, but it uses warnings in
4056
4345
  # cases where a ValueError is more appropriate, and it allows
4057
4346
  # b to be 2D.
4058
- b = np.atleast_1d(b)
4347
+ b = xpx.atleast_nd(xp.asarray(b), ndim=1, xp=xp)
4059
4348
  if b.ndim != 1:
4060
4349
  raise ValueError("Numerator b must be 1-D.")
4061
- a = np.atleast_1d(a)
4350
+ a = xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp)
4062
4351
  if a.ndim != 1:
4063
4352
  raise ValueError("Denominator a must be 1-D.")
4064
4353
 
4065
- while len(a) > 1 and a[0] == 0.0:
4354
+ while a.shape[0] > 1 and a[0] == 0.0:
4066
4355
  a = a[1:]
4067
- if a.size < 1:
4356
+ if xp_size(a) < 1:
4068
4357
  raise ValueError("There must be at least one nonzero `a` coefficient.")
4069
4358
 
4070
4359
  if a[0] != 1.0:
@@ -4072,18 +4361,20 @@ def lfilter_zi(b, a):
4072
4361
  b = b / a[0]
4073
4362
  a = a / a[0]
4074
4363
 
4075
- n = max(len(a), len(b))
4364
+ n = max(a.shape[0], b.shape[0])
4076
4365
 
4077
4366
  # Pad a or b with zeros so they are the same length.
4078
- if len(a) < n:
4079
- a = np.r_[a, np.zeros(n - len(a), dtype=a.dtype)]
4080
- elif len(b) < n:
4081
- b = np.r_[b, np.zeros(n - len(b), dtype=b.dtype)]
4082
-
4083
- IminusA = np.eye(n - 1, dtype=np.result_type(a, b)) - linalg.companion(a).T
4367
+ if a.shape[0] < n:
4368
+ a = xp.concat((a, xp.zeros(n - a.shape[0], dtype=a.dtype)))
4369
+ elif b.shape[0] < n:
4370
+ b = xp.concat((b, xp.zeros(n - b.shape[0], dtype=b.dtype)))
4371
+
4372
+ dt = xp.result_type(a, b)
4373
+ IminusA = np.eye(n - 1) - linalg.companion(a).T
4374
+ IminusA = xp.asarray(IminusA, dtype=dt)
4084
4375
  B = b[1:] - a[1:] * b[0]
4085
4376
  # Solve zi = A*zi + B
4086
- zi = np.linalg.solve(IminusA, B)
4377
+ zi = xp.linalg.solve(IminusA, B)
4087
4378
 
4088
4379
  # For future reference: we could also use the following
4089
4380
  # explicit formulas to solve the linear system:
@@ -4154,24 +4445,26 @@ def sosfilt_zi(sos):
4154
4445
  >>> plt.show()
4155
4446
 
4156
4447
  """
4157
- sos = np.asarray(sos)
4448
+ xp = array_namespace(sos)
4449
+
4450
+ sos = xp.asarray(sos)
4158
4451
  if sos.ndim != 2 or sos.shape[1] != 6:
4159
4452
  raise ValueError('sos must be shape (n_sections, 6)')
4160
4453
 
4161
- if sos.dtype.kind in 'bui':
4162
- sos = sos.astype(np.float64)
4454
+ if xp.isdtype(sos.dtype, ("integral", "bool")):
4455
+ sos = xp.astype(sos, xp.float64)
4163
4456
 
4164
4457
  n_sections = sos.shape[0]
4165
- zi = np.empty((n_sections, 2), dtype=sos.dtype)
4458
+ zi = xp.empty((n_sections, 2), dtype=sos.dtype)
4166
4459
  scale = 1.0
4167
4460
  for section in range(n_sections):
4168
4461
  b = sos[section, :3]
4169
4462
  a = sos[section, 3:]
4170
- zi[section] = scale * lfilter_zi(b, a)
4463
+ zi[section, ...] = scale * lfilter_zi(b, a)
4171
4464
  # If H(z) = B(z)/A(z) is this section's transfer function, then
4172
4465
  # b.sum()/a.sum() is H(1), the gain at omega=0. That's the steady
4173
4466
  # state value of this section's step response.
4174
- scale *= b.sum() / a.sum()
4467
+ scale *= xp.sum(b) / xp.sum(a)
4175
4468
 
4176
4469
  return zi
4177
4470
 
@@ -4513,6 +4806,8 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
4513
4806
  2.875334415008979e-10
4514
4807
 
4515
4808
  """
4809
+ xp = array_namespace(b, a, x)
4810
+
4516
4811
  b = np.atleast_1d(b)
4517
4812
  a = np.atleast_1d(a)
4518
4813
  x = np.asarray(x)
@@ -4522,7 +4817,7 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
4522
4817
 
4523
4818
  if method == "gust":
4524
4819
  y, z1, z2 = _filtfilt_gust(b, a, x, axis=axis, irlen=irlen)
4525
- return y
4820
+ return xp.asarray(y)
4526
4821
 
4527
4822
  # method == "pad"
4528
4823
  edge, ext = _validate_pad(padtype, padlen, x, axis,
@@ -4553,8 +4848,10 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
4553
4848
  if edge > 0:
4554
4849
  # Slice the actual signal from the extended signal.
4555
4850
  y = axis_slice(y, start=edge, stop=-edge, axis=axis)
4851
+ if is_torch(xp):
4852
+ y = y.copy() # pytorch/pytorch#59786 : no negative strides in pytorch
4556
4853
 
4557
- return y
4854
+ return xp.asarray(y)
4558
4855
 
4559
4856
 
4560
4857
  def _validate_pad(padtype, padlen, x, axis, ntaps):
@@ -4574,8 +4871,10 @@ def _validate_pad(padtype, padlen, x, axis, ntaps):
4574
4871
 
4575
4872
  # x's 'axis' dimension must be bigger than edge.
4576
4873
  if x.shape[axis] <= edge:
4577
- raise ValueError("The length of the input vector x must be greater "
4578
- "than padlen, which is %d." % edge)
4874
+ raise ValueError(
4875
+ f"The length of the input vector x must be greater than padlen, "
4876
+ f"which is {edge}."
4877
+ )
4579
4878
 
4580
4879
  if padtype is not None and edge > 0:
4581
4880
  # Make an extension of length `edge` at each
@@ -4668,10 +4967,17 @@ def sosfilt(sos, x, axis=-1, zi=None):
4668
4967
  >>> plt.show()
4669
4968
 
4670
4969
  """
4671
- _reject_objects(sos, 'sosfilt')
4672
- _reject_objects(x, 'sosfilt')
4673
- if zi is not None:
4674
- _reject_objects(zi, 'sosfilt')
4970
+ try:
4971
+ xp = array_namespace(sos, x, zi)
4972
+ except TypeError:
4973
+ # either in1 or in2 are object arrays
4974
+ xp = np_compat
4975
+
4976
+ if is_numpy(xp):
4977
+ _reject_objects(sos, 'sosfilt')
4978
+ _reject_objects(x, 'sosfilt')
4979
+ if zi is not None:
4980
+ _reject_objects(zi, 'sosfilt')
4675
4981
 
4676
4982
  x = _validate_x(x)
4677
4983
  sos, n_sections = _validate_sos(sos)
@@ -4685,19 +4991,26 @@ def sosfilt(sos, x, axis=-1, zi=None):
4685
4991
  if dtype.char not in 'fdgFDGO':
4686
4992
  raise NotImplementedError(f"input type '{dtype}' not supported")
4687
4993
  if zi is not None:
4688
- zi = np.array(zi, dtype) # make a copy so that we can operate in place
4994
+ zi = np.asarray(zi, dtype=dtype)
4995
+
4996
+ # make a copy so that we can operate in place
4997
+ # NB: 1. use xp_copy to paper over numpy 1/2 copy= keyword
4998
+ # 2. make sure the copied zi remains a numpy array
4999
+ zi = xp_copy(zi, xp=array_namespace(zi))
4689
5000
  if zi.shape != x_zi_shape:
4690
- raise ValueError('Invalid zi shape. With axis=%r, an input with '
4691
- 'shape %r, and an sos array with %d sections, zi '
4692
- 'must have shape %r, got %r.' %
4693
- (axis, x.shape, n_sections, x_zi_shape, zi.shape))
5001
+ raise ValueError(
5002
+ f"Invalid zi shape. With axis={axis!r}, "
5003
+ f"an input with shape {x.shape!r}, "
5004
+ f"and an sos array with {n_sections} sections, zi must have "
5005
+ f"shape {x_zi_shape!r}, got {zi.shape!r}."
5006
+ )
4694
5007
  return_zi = True
4695
5008
  else:
4696
5009
  zi = np.zeros(x_zi_shape, dtype=dtype)
4697
5010
  return_zi = False
4698
5011
  axis = axis % x.ndim # make positive
4699
5012
  x = np.moveaxis(x, axis, -1)
4700
- zi = np.moveaxis(zi, [0, axis + 1], [-2, -1])
5013
+ zi = np.moveaxis(zi, (0, axis + 1), (-2, -1))
4701
5014
  x_shape, zi_shape = x.shape, zi.shape
4702
5015
  x = np.reshape(x, (-1, x.shape[-1]))
4703
5016
  x = np.array(x, dtype, order='C') # make a copy, can modify in place
@@ -4708,10 +5021,10 @@ def sosfilt(sos, x, axis=-1, zi=None):
4708
5021
  x = np.moveaxis(x, -1, axis)
4709
5022
  if return_zi:
4710
5023
  zi.shape = zi_shape
4711
- zi = np.moveaxis(zi, [-2, -1], [0, axis + 1])
4712
- out = (x, zi)
5024
+ zi = np.moveaxis(zi, (-2, -1), (0, axis + 1))
5025
+ out = (xp.asarray(x), xp.asarray(zi))
4713
5026
  else:
4714
- out = x
5027
+ out = xp.asarray(x)
4715
5028
  return out
4716
5029
 
4717
5030
 
@@ -4804,6 +5117,8 @@ def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
4804
5117
  >>> plt.show()
4805
5118
 
4806
5119
  """
5120
+ xp = array_namespace(sos, x)
5121
+
4807
5122
  sos, n_sections = _validate_sos(sos)
4808
5123
  x = _validate_x(x)
4809
5124
 
@@ -4825,7 +5140,7 @@ def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
4825
5140
  y = axis_reverse(y, axis=axis)
4826
5141
  if edge > 0:
4827
5142
  y = axis_slice(y, start=edge, stop=-edge, axis=axis)
4828
- return y
5143
+ return xp.asarray(y)
4829
5144
 
4830
5145
 
4831
5146
  def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True):
@@ -4838,11 +5153,12 @@ def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True):
4838
5153
  Parameters
4839
5154
  ----------
4840
5155
  x : array_like
4841
- The signal to be downsampled, as an N-dimensional array.
5156
+ The input signal made up of equidistant samples. If `x` is a multidimensional
5157
+ array, the parameter `axis` specifies the time axis.
4842
5158
  q : int
4843
- The downsampling factor. When using IIR downsampling, it is recommended
4844
- to call `decimate` multiple times for downsampling factors higher than
4845
- 13.
5159
+ The downsampling factor, which is a postive integer. When using IIR
5160
+ downsampling, it is recommended to call `decimate` multiple times for
5161
+ downsampling factors higher than 13.
4846
5162
  n : int, optional
4847
5163
  The order of the filter (1 less than the length for 'fir'). Defaults to
4848
5164
  8 for 'iir' and 20 times the downsampling factor for 'fir'.
@@ -4871,6 +5187,10 @@ def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True):
4871
5187
 
4872
5188
  Notes
4873
5189
  -----
5190
+ For non-integer downsampling factors, `~scipy.signal.resample` can be used. Consult
5191
+ the `scipy.interpolate` module for methods of resampling signals with non-constant
5192
+ sampling intervals.
5193
+
4874
5194
  The ``zero_phase`` keyword was added in 0.18.0.
4875
5195
  The possibility to use instances of ``dlti`` as ``ftype`` was added in
4876
5196
  0.18.0.