scipy 1.15.3__cp312-cp312-macosx_14_0_arm64.whl → 1.16.0rc2__cp312-cp312-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (628) hide show
  1. scipy/__config__.py +4 -4
  2. scipy/__init__.py +3 -6
  3. scipy/_cyutility.cpython-312-darwin.so +0 -0
  4. scipy/_lib/_array_api.py +486 -161
  5. scipy/_lib/_array_api_compat_vendor.py +9 -0
  6. scipy/_lib/_bunch.py +4 -0
  7. scipy/_lib/_ccallback_c.cpython-312-darwin.so +0 -0
  8. scipy/_lib/_docscrape.py +1 -1
  9. scipy/_lib/_elementwise_iterative_method.py +15 -26
  10. scipy/_lib/_sparse.py +41 -0
  11. scipy/_lib/_test_deprecation_call.cpython-312-darwin.so +0 -0
  12. scipy/_lib/_test_deprecation_def.cpython-312-darwin.so +0 -0
  13. scipy/_lib/_testutils.py +6 -2
  14. scipy/_lib/_util.py +222 -125
  15. scipy/_lib/array_api_compat/__init__.py +4 -4
  16. scipy/_lib/array_api_compat/_internal.py +19 -6
  17. scipy/_lib/array_api_compat/common/__init__.py +1 -1
  18. scipy/_lib/array_api_compat/common/_aliases.py +365 -193
  19. scipy/_lib/array_api_compat/common/_fft.py +94 -64
  20. scipy/_lib/array_api_compat/common/_helpers.py +413 -180
  21. scipy/_lib/array_api_compat/common/_linalg.py +116 -40
  22. scipy/_lib/array_api_compat/common/_typing.py +179 -10
  23. scipy/_lib/array_api_compat/cupy/__init__.py +1 -4
  24. scipy/_lib/array_api_compat/cupy/_aliases.py +61 -41
  25. scipy/_lib/array_api_compat/cupy/_info.py +16 -6
  26. scipy/_lib/array_api_compat/cupy/_typing.py +24 -39
  27. scipy/_lib/array_api_compat/dask/array/__init__.py +6 -3
  28. scipy/_lib/array_api_compat/dask/array/_aliases.py +267 -108
  29. scipy/_lib/array_api_compat/dask/array/_info.py +105 -34
  30. scipy/_lib/array_api_compat/dask/array/fft.py +5 -8
  31. scipy/_lib/array_api_compat/dask/array/linalg.py +21 -22
  32. scipy/_lib/array_api_compat/numpy/__init__.py +13 -15
  33. scipy/_lib/array_api_compat/numpy/_aliases.py +98 -49
  34. scipy/_lib/array_api_compat/numpy/_info.py +36 -16
  35. scipy/_lib/array_api_compat/numpy/_typing.py +27 -43
  36. scipy/_lib/array_api_compat/numpy/fft.py +11 -5
  37. scipy/_lib/array_api_compat/numpy/linalg.py +75 -22
  38. scipy/_lib/array_api_compat/torch/__init__.py +3 -5
  39. scipy/_lib/array_api_compat/torch/_aliases.py +262 -159
  40. scipy/_lib/array_api_compat/torch/_info.py +27 -16
  41. scipy/_lib/array_api_compat/torch/_typing.py +3 -0
  42. scipy/_lib/array_api_compat/torch/fft.py +17 -18
  43. scipy/_lib/array_api_compat/torch/linalg.py +16 -16
  44. scipy/_lib/array_api_extra/__init__.py +26 -3
  45. scipy/_lib/array_api_extra/_delegation.py +171 -0
  46. scipy/_lib/array_api_extra/_lib/__init__.py +1 -0
  47. scipy/_lib/array_api_extra/_lib/_at.py +463 -0
  48. scipy/_lib/array_api_extra/_lib/_backends.py +46 -0
  49. scipy/_lib/array_api_extra/_lib/_funcs.py +937 -0
  50. scipy/_lib/array_api_extra/_lib/_lazy.py +357 -0
  51. scipy/_lib/array_api_extra/_lib/_testing.py +278 -0
  52. scipy/_lib/array_api_extra/_lib/_utils/__init__.py +1 -0
  53. scipy/_lib/array_api_extra/_lib/_utils/_compat.py +74 -0
  54. scipy/_lib/array_api_extra/_lib/_utils/_compat.pyi +45 -0
  55. scipy/_lib/array_api_extra/_lib/_utils/_helpers.py +559 -0
  56. scipy/_lib/array_api_extra/_lib/_utils/_typing.py +10 -0
  57. scipy/_lib/array_api_extra/_lib/_utils/_typing.pyi +105 -0
  58. scipy/_lib/array_api_extra/testing.py +359 -0
  59. scipy/_lib/decorator.py +2 -2
  60. scipy/_lib/doccer.py +1 -7
  61. scipy/_lib/messagestream.cpython-312-darwin.so +0 -0
  62. scipy/_lib/pyprima/__init__.py +212 -0
  63. scipy/_lib/pyprima/cobyla/__init__.py +0 -0
  64. scipy/_lib/pyprima/cobyla/cobyla.py +559 -0
  65. scipy/_lib/pyprima/cobyla/cobylb.py +714 -0
  66. scipy/_lib/pyprima/cobyla/geometry.py +226 -0
  67. scipy/_lib/pyprima/cobyla/initialize.py +215 -0
  68. scipy/_lib/pyprima/cobyla/trustregion.py +492 -0
  69. scipy/_lib/pyprima/cobyla/update.py +289 -0
  70. scipy/_lib/pyprima/common/__init__.py +0 -0
  71. scipy/_lib/pyprima/common/_bounds.py +34 -0
  72. scipy/_lib/pyprima/common/_linear_constraints.py +46 -0
  73. scipy/_lib/pyprima/common/_nonlinear_constraints.py +54 -0
  74. scipy/_lib/pyprima/common/_project.py +173 -0
  75. scipy/_lib/pyprima/common/checkbreak.py +93 -0
  76. scipy/_lib/pyprima/common/consts.py +47 -0
  77. scipy/_lib/pyprima/common/evaluate.py +99 -0
  78. scipy/_lib/pyprima/common/history.py +38 -0
  79. scipy/_lib/pyprima/common/infos.py +30 -0
  80. scipy/_lib/pyprima/common/linalg.py +435 -0
  81. scipy/_lib/pyprima/common/message.py +290 -0
  82. scipy/_lib/pyprima/common/powalg.py +131 -0
  83. scipy/_lib/pyprima/common/preproc.py +277 -0
  84. scipy/_lib/pyprima/common/present.py +5 -0
  85. scipy/_lib/pyprima/common/ratio.py +54 -0
  86. scipy/_lib/pyprima/common/redrho.py +47 -0
  87. scipy/_lib/pyprima/common/selectx.py +296 -0
  88. scipy/_lib/tests/test__util.py +105 -121
  89. scipy/_lib/tests/test_array_api.py +166 -35
  90. scipy/_lib/tests/test_bunch.py +7 -0
  91. scipy/_lib/tests/test_ccallback.py +2 -10
  92. scipy/_lib/tests/test_public_api.py +13 -0
  93. scipy/cluster/_hierarchy.cpython-312-darwin.so +0 -0
  94. scipy/cluster/_optimal_leaf_ordering.cpython-312-darwin.so +0 -0
  95. scipy/cluster/_vq.cpython-312-darwin.so +0 -0
  96. scipy/cluster/hierarchy.py +393 -223
  97. scipy/cluster/tests/test_hierarchy.py +273 -335
  98. scipy/cluster/tests/test_vq.py +45 -61
  99. scipy/cluster/vq.py +39 -35
  100. scipy/conftest.py +263 -157
  101. scipy/constants/_constants.py +4 -1
  102. scipy/constants/tests/test_codata.py +2 -2
  103. scipy/constants/tests/test_constants.py +11 -18
  104. scipy/datasets/_download_all.py +15 -1
  105. scipy/datasets/_fetchers.py +7 -1
  106. scipy/datasets/_utils.py +1 -1
  107. scipy/differentiate/_differentiate.py +25 -25
  108. scipy/differentiate/tests/test_differentiate.py +24 -25
  109. scipy/fft/_basic.py +20 -0
  110. scipy/fft/_helper.py +3 -34
  111. scipy/fft/_pocketfft/helper.py +29 -1
  112. scipy/fft/_pocketfft/tests/test_basic.py +2 -4
  113. scipy/fft/_pocketfft/tests/test_real_transforms.py +4 -4
  114. scipy/fft/_realtransforms.py +13 -0
  115. scipy/fft/tests/test_basic.py +27 -25
  116. scipy/fft/tests/test_fftlog.py +16 -7
  117. scipy/fft/tests/test_helper.py +18 -34
  118. scipy/fft/tests/test_real_transforms.py +8 -10
  119. scipy/fftpack/convolve.cpython-312-darwin.so +0 -0
  120. scipy/fftpack/tests/test_basic.py +2 -4
  121. scipy/fftpack/tests/test_real_transforms.py +8 -9
  122. scipy/integrate/_bvp.py +9 -3
  123. scipy/integrate/_cubature.py +3 -2
  124. scipy/integrate/_dop.cpython-312-darwin.so +0 -0
  125. scipy/integrate/_lsoda.cpython-312-darwin.so +0 -0
  126. scipy/integrate/_ode.py +9 -2
  127. scipy/integrate/_odepack.cpython-312-darwin.so +0 -0
  128. scipy/integrate/_quad_vec.py +21 -29
  129. scipy/integrate/_quadpack.cpython-312-darwin.so +0 -0
  130. scipy/integrate/_quadpack_py.py +11 -7
  131. scipy/integrate/_quadrature.py +3 -3
  132. scipy/integrate/_rules/_base.py +2 -2
  133. scipy/integrate/_tanhsinh.py +48 -47
  134. scipy/integrate/_test_odeint_banded.cpython-312-darwin.so +0 -0
  135. scipy/integrate/_vode.cpython-312-darwin.so +0 -0
  136. scipy/integrate/tests/test__quad_vec.py +0 -6
  137. scipy/integrate/tests/test_banded_ode_solvers.py +85 -0
  138. scipy/integrate/tests/test_cubature.py +21 -35
  139. scipy/integrate/tests/test_quadrature.py +6 -8
  140. scipy/integrate/tests/test_tanhsinh.py +56 -48
  141. scipy/interpolate/__init__.py +70 -58
  142. scipy/interpolate/_bary_rational.py +22 -22
  143. scipy/interpolate/_bsplines.py +119 -66
  144. scipy/interpolate/_cubic.py +65 -50
  145. scipy/interpolate/_dfitpack.cpython-312-darwin.so +0 -0
  146. scipy/interpolate/_dierckx.cpython-312-darwin.so +0 -0
  147. scipy/interpolate/_fitpack.cpython-312-darwin.so +0 -0
  148. scipy/interpolate/_fitpack2.py +9 -6
  149. scipy/interpolate/_fitpack_impl.py +32 -26
  150. scipy/interpolate/_fitpack_repro.py +23 -19
  151. scipy/interpolate/_interpnd.cpython-312-darwin.so +0 -0
  152. scipy/interpolate/_interpolate.py +30 -12
  153. scipy/interpolate/_ndbspline.py +13 -18
  154. scipy/interpolate/_ndgriddata.py +5 -8
  155. scipy/interpolate/_polyint.py +95 -31
  156. scipy/interpolate/_ppoly.cpython-312-darwin.so +0 -0
  157. scipy/interpolate/_rbf.py +2 -2
  158. scipy/interpolate/_rbfinterp.py +1 -1
  159. scipy/interpolate/_rbfinterp_pythran.cpython-312-darwin.so +0 -0
  160. scipy/interpolate/_rgi.py +31 -26
  161. scipy/interpolate/_rgi_cython.cpython-312-darwin.so +0 -0
  162. scipy/interpolate/dfitpack.py +0 -20
  163. scipy/interpolate/interpnd.py +1 -2
  164. scipy/interpolate/tests/test_bary_rational.py +2 -2
  165. scipy/interpolate/tests/test_bsplines.py +97 -1
  166. scipy/interpolate/tests/test_fitpack2.py +39 -1
  167. scipy/interpolate/tests/test_interpnd.py +32 -20
  168. scipy/interpolate/tests/test_interpolate.py +48 -4
  169. scipy/interpolate/tests/test_rgi.py +2 -1
  170. scipy/io/_fast_matrix_market/__init__.py +2 -0
  171. scipy/io/_harwell_boeing/_fortran_format_parser.py +19 -16
  172. scipy/io/_harwell_boeing/hb.py +7 -11
  173. scipy/io/_idl.py +5 -7
  174. scipy/io/_netcdf.py +15 -5
  175. scipy/io/_test_fortran.cpython-312-darwin.so +0 -0
  176. scipy/io/arff/tests/test_arffread.py +3 -3
  177. scipy/io/matlab/__init__.py +5 -3
  178. scipy/io/matlab/_mio.py +4 -1
  179. scipy/io/matlab/_mio5.py +19 -13
  180. scipy/io/matlab/_mio5_utils.cpython-312-darwin.so +0 -0
  181. scipy/io/matlab/_mio_utils.cpython-312-darwin.so +0 -0
  182. scipy/io/matlab/_miobase.py +4 -1
  183. scipy/io/matlab/_streams.cpython-312-darwin.so +0 -0
  184. scipy/io/matlab/tests/test_mio.py +46 -18
  185. scipy/io/matlab/tests/test_mio_funcs.py +1 -1
  186. scipy/io/tests/test_mmio.py +7 -1
  187. scipy/io/tests/test_wavfile.py +41 -0
  188. scipy/io/wavfile.py +57 -10
  189. scipy/linalg/_basic.py +113 -86
  190. scipy/linalg/_cythonized_array_utils.cpython-312-darwin.so +0 -0
  191. scipy/linalg/_decomp.py +22 -9
  192. scipy/linalg/_decomp_cholesky.py +28 -13
  193. scipy/linalg/_decomp_cossin.py +45 -30
  194. scipy/linalg/_decomp_interpolative.cpython-312-darwin.so +0 -0
  195. scipy/linalg/_decomp_ldl.py +4 -1
  196. scipy/linalg/_decomp_lu.py +18 -6
  197. scipy/linalg/_decomp_lu_cython.cpython-312-darwin.so +0 -0
  198. scipy/linalg/_decomp_polar.py +2 -0
  199. scipy/linalg/_decomp_qr.py +6 -2
  200. scipy/linalg/_decomp_qz.py +3 -0
  201. scipy/linalg/_decomp_schur.py +3 -1
  202. scipy/linalg/_decomp_svd.py +13 -2
  203. scipy/linalg/_decomp_update.cpython-312-darwin.so +0 -0
  204. scipy/linalg/_expm_frechet.py +4 -0
  205. scipy/linalg/_fblas.cpython-312-darwin.so +0 -0
  206. scipy/linalg/_flapack.cpython-312-darwin.so +0 -0
  207. scipy/linalg/_linalg_pythran.cpython-312-darwin.so +0 -0
  208. scipy/linalg/_matfuncs.py +187 -4
  209. scipy/linalg/_matfuncs_expm.cpython-312-darwin.so +0 -0
  210. scipy/linalg/_matfuncs_schur_sqrtm.cpython-312-darwin.so +0 -0
  211. scipy/linalg/_matfuncs_sqrtm.py +1 -99
  212. scipy/linalg/_matfuncs_sqrtm_triu.cpython-312-darwin.so +0 -0
  213. scipy/linalg/_procrustes.py +2 -0
  214. scipy/linalg/_sketches.py +17 -6
  215. scipy/linalg/_solve_toeplitz.cpython-312-darwin.so +0 -0
  216. scipy/linalg/_solvers.py +7 -2
  217. scipy/linalg/_special_matrices.py +26 -36
  218. scipy/linalg/cython_blas.cpython-312-darwin.so +0 -0
  219. scipy/linalg/cython_lapack.cpython-312-darwin.so +0 -0
  220. scipy/linalg/lapack.py +22 -2
  221. scipy/linalg/tests/_cython_examples/meson.build +7 -0
  222. scipy/linalg/tests/test_basic.py +31 -16
  223. scipy/linalg/tests/test_batch.py +588 -0
  224. scipy/linalg/tests/test_cythonized_array_utils.py +0 -2
  225. scipy/linalg/tests/test_decomp.py +40 -3
  226. scipy/linalg/tests/test_decomp_cossin.py +14 -0
  227. scipy/linalg/tests/test_decomp_ldl.py +1 -1
  228. scipy/linalg/tests/test_lapack.py +115 -7
  229. scipy/linalg/tests/test_matfuncs.py +157 -102
  230. scipy/linalg/tests/test_procrustes.py +0 -7
  231. scipy/linalg/tests/test_solve_toeplitz.py +1 -1
  232. scipy/linalg/tests/test_special_matrices.py +1 -5
  233. scipy/ndimage/__init__.py +1 -0
  234. scipy/ndimage/_cytest.cpython-312-darwin.so +0 -0
  235. scipy/ndimage/_delegators.py +8 -2
  236. scipy/ndimage/_filters.py +453 -5
  237. scipy/ndimage/_interpolation.py +36 -6
  238. scipy/ndimage/_measurements.py +4 -2
  239. scipy/ndimage/_morphology.py +5 -0
  240. scipy/ndimage/_nd_image.cpython-312-darwin.so +0 -0
  241. scipy/ndimage/_ni_docstrings.py +5 -1
  242. scipy/ndimage/_ni_label.cpython-312-darwin.so +0 -0
  243. scipy/ndimage/_ni_support.py +1 -5
  244. scipy/ndimage/_rank_filter_1d.cpython-312-darwin.so +0 -0
  245. scipy/ndimage/_support_alternative_backends.py +18 -6
  246. scipy/ndimage/tests/test_filters.py +370 -259
  247. scipy/ndimage/tests/test_fourier.py +7 -9
  248. scipy/ndimage/tests/test_interpolation.py +68 -61
  249. scipy/ndimage/tests/test_measurements.py +18 -35
  250. scipy/ndimage/tests/test_morphology.py +143 -131
  251. scipy/ndimage/tests/test_splines.py +1 -3
  252. scipy/odr/__odrpack.cpython-312-darwin.so +0 -0
  253. scipy/optimize/_basinhopping.py +13 -7
  254. scipy/optimize/_bglu_dense.cpython-312-darwin.so +0 -0
  255. scipy/optimize/_bracket.py +17 -24
  256. scipy/optimize/_chandrupatla.py +9 -10
  257. scipy/optimize/_cobyla_py.py +104 -123
  258. scipy/optimize/_constraints.py +14 -10
  259. scipy/optimize/_differentiable_functions.py +371 -230
  260. scipy/optimize/_differentialevolution.py +4 -3
  261. scipy/optimize/_direct.cpython-312-darwin.so +0 -0
  262. scipy/optimize/_dual_annealing.py +1 -1
  263. scipy/optimize/_elementwise.py +1 -4
  264. scipy/optimize/_group_columns.cpython-312-darwin.so +0 -0
  265. scipy/optimize/_lbfgsb.cpython-312-darwin.so +0 -0
  266. scipy/optimize/_lbfgsb_py.py +57 -16
  267. scipy/optimize/_linprog_doc.py +2 -2
  268. scipy/optimize/_linprog_highs.py +2 -2
  269. scipy/optimize/_linprog_ip.py +25 -10
  270. scipy/optimize/_linprog_util.py +14 -16
  271. scipy/optimize/_lsap.cpython-312-darwin.so +0 -0
  272. scipy/optimize/_lsq/common.py +3 -3
  273. scipy/optimize/_lsq/dogbox.py +16 -2
  274. scipy/optimize/_lsq/givens_elimination.cpython-312-darwin.so +0 -0
  275. scipy/optimize/_lsq/least_squares.py +198 -126
  276. scipy/optimize/_lsq/lsq_linear.py +6 -6
  277. scipy/optimize/_lsq/trf.py +35 -8
  278. scipy/optimize/_milp.py +3 -1
  279. scipy/optimize/_minimize.py +105 -36
  280. scipy/optimize/_minpack.cpython-312-darwin.so +0 -0
  281. scipy/optimize/_minpack_py.py +21 -14
  282. scipy/optimize/_moduleTNC.cpython-312-darwin.so +0 -0
  283. scipy/optimize/_nnls.py +20 -21
  284. scipy/optimize/_nonlin.py +34 -3
  285. scipy/optimize/_numdiff.py +288 -110
  286. scipy/optimize/_optimize.py +86 -48
  287. scipy/optimize/_pava_pybind.cpython-312-darwin.so +0 -0
  288. scipy/optimize/_remove_redundancy.py +5 -5
  289. scipy/optimize/_root_scalar.py +1 -1
  290. scipy/optimize/_shgo.py +6 -0
  291. scipy/optimize/_shgo_lib/_complex.py +1 -1
  292. scipy/optimize/_slsqp_py.py +216 -124
  293. scipy/optimize/_slsqplib.cpython-312-darwin.so +0 -0
  294. scipy/optimize/_spectral.py +1 -1
  295. scipy/optimize/_tnc.py +8 -1
  296. scipy/optimize/_trlib/_trlib.cpython-312-darwin.so +0 -0
  297. scipy/optimize/_trustregion.py +20 -6
  298. scipy/optimize/_trustregion_constr/canonical_constraint.py +7 -7
  299. scipy/optimize/_trustregion_constr/equality_constrained_sqp.py +1 -1
  300. scipy/optimize/_trustregion_constr/minimize_trustregion_constr.py +11 -3
  301. scipy/optimize/_trustregion_constr/projections.py +12 -8
  302. scipy/optimize/_trustregion_constr/qp_subproblem.py +9 -9
  303. scipy/optimize/_trustregion_constr/tests/test_projections.py +7 -7
  304. scipy/optimize/_trustregion_constr/tests/test_qp_subproblem.py +77 -77
  305. scipy/optimize/_trustregion_constr/tr_interior_point.py +5 -5
  306. scipy/optimize/_trustregion_exact.py +0 -1
  307. scipy/optimize/_zeros.cpython-312-darwin.so +0 -0
  308. scipy/optimize/_zeros_py.py +97 -17
  309. scipy/optimize/cython_optimize/_zeros.cpython-312-darwin.so +0 -0
  310. scipy/optimize/slsqp.py +0 -1
  311. scipy/optimize/tests/test__basinhopping.py +1 -1
  312. scipy/optimize/tests/test__differential_evolution.py +4 -4
  313. scipy/optimize/tests/test__linprog_clean_inputs.py +5 -3
  314. scipy/optimize/tests/test__numdiff.py +66 -22
  315. scipy/optimize/tests/test__remove_redundancy.py +2 -2
  316. scipy/optimize/tests/test__shgo.py +9 -1
  317. scipy/optimize/tests/test_bracket.py +36 -46
  318. scipy/optimize/tests/test_chandrupatla.py +133 -135
  319. scipy/optimize/tests/test_cobyla.py +74 -45
  320. scipy/optimize/tests/test_constraints.py +1 -1
  321. scipy/optimize/tests/test_differentiable_functions.py +226 -6
  322. scipy/optimize/tests/test_lbfgsb_hessinv.py +22 -0
  323. scipy/optimize/tests/test_least_squares.py +125 -13
  324. scipy/optimize/tests/test_linear_assignment.py +3 -3
  325. scipy/optimize/tests/test_linprog.py +3 -3
  326. scipy/optimize/tests/test_lsq_linear.py +6 -6
  327. scipy/optimize/tests/test_minimize_constrained.py +2 -2
  328. scipy/optimize/tests/test_minpack.py +4 -4
  329. scipy/optimize/tests/test_nnls.py +43 -3
  330. scipy/optimize/tests/test_nonlin.py +36 -0
  331. scipy/optimize/tests/test_optimize.py +95 -17
  332. scipy/optimize/tests/test_slsqp.py +36 -4
  333. scipy/optimize/tests/test_zeros.py +34 -1
  334. scipy/signal/__init__.py +12 -23
  335. scipy/signal/_delegators.py +568 -0
  336. scipy/signal/_filter_design.py +459 -241
  337. scipy/signal/_fir_filter_design.py +262 -90
  338. scipy/signal/_lti_conversion.py +3 -2
  339. scipy/signal/_ltisys.py +118 -91
  340. scipy/signal/_max_len_seq_inner.cpython-312-darwin.so +0 -0
  341. scipy/signal/_peak_finding_utils.cpython-312-darwin.so +0 -0
  342. scipy/signal/_polyutils.py +172 -0
  343. scipy/signal/_short_time_fft.py +519 -70
  344. scipy/signal/_signal_api.py +30 -0
  345. scipy/signal/_signaltools.py +719 -399
  346. scipy/signal/_sigtools.cpython-312-darwin.so +0 -0
  347. scipy/signal/_sosfilt.cpython-312-darwin.so +0 -0
  348. scipy/signal/_spectral_py.py +230 -50
  349. scipy/signal/_spline.cpython-312-darwin.so +0 -0
  350. scipy/signal/_spline_filters.py +108 -68
  351. scipy/signal/_support_alternative_backends.py +73 -0
  352. scipy/signal/_upfirdn.py +4 -1
  353. scipy/signal/_upfirdn_apply.cpython-312-darwin.so +0 -0
  354. scipy/signal/_waveforms.py +2 -11
  355. scipy/signal/_wavelets.py +1 -1
  356. scipy/signal/fir_filter_design.py +1 -0
  357. scipy/signal/spline.py +4 -11
  358. scipy/signal/tests/_scipy_spectral_test_shim.py +2 -171
  359. scipy/signal/tests/test_bsplines.py +114 -79
  360. scipy/signal/tests/test_cont2discrete.py +9 -2
  361. scipy/signal/tests/test_filter_design.py +721 -481
  362. scipy/signal/tests/test_fir_filter_design.py +332 -140
  363. scipy/signal/tests/test_savitzky_golay.py +4 -3
  364. scipy/signal/tests/test_short_time_fft.py +221 -3
  365. scipy/signal/tests/test_signaltools.py +2144 -1348
  366. scipy/signal/tests/test_spectral.py +50 -6
  367. scipy/signal/tests/test_splines.py +161 -96
  368. scipy/signal/tests/test_upfirdn.py +84 -50
  369. scipy/signal/tests/test_waveforms.py +20 -0
  370. scipy/signal/tests/test_windows.py +607 -466
  371. scipy/signal/windows/_windows.py +287 -148
  372. scipy/sparse/__init__.py +23 -4
  373. scipy/sparse/_base.py +270 -108
  374. scipy/sparse/_bsr.py +7 -4
  375. scipy/sparse/_compressed.py +59 -231
  376. scipy/sparse/_construct.py +90 -38
  377. scipy/sparse/_coo.py +115 -181
  378. scipy/sparse/_csc.py +4 -4
  379. scipy/sparse/_csparsetools.cpython-312-darwin.so +0 -0
  380. scipy/sparse/_csr.py +2 -2
  381. scipy/sparse/_data.py +48 -48
  382. scipy/sparse/_dia.py +105 -18
  383. scipy/sparse/_dok.py +0 -23
  384. scipy/sparse/_index.py +4 -4
  385. scipy/sparse/_matrix.py +23 -0
  386. scipy/sparse/_sparsetools.cpython-312-darwin.so +0 -0
  387. scipy/sparse/_sputils.py +37 -22
  388. scipy/sparse/base.py +0 -9
  389. scipy/sparse/bsr.py +0 -14
  390. scipy/sparse/compressed.py +0 -23
  391. scipy/sparse/construct.py +0 -6
  392. scipy/sparse/coo.py +0 -14
  393. scipy/sparse/csc.py +0 -3
  394. scipy/sparse/csgraph/_flow.cpython-312-darwin.so +0 -0
  395. scipy/sparse/csgraph/_matching.cpython-312-darwin.so +0 -0
  396. scipy/sparse/csgraph/_min_spanning_tree.cpython-312-darwin.so +0 -0
  397. scipy/sparse/csgraph/_reordering.cpython-312-darwin.so +0 -0
  398. scipy/sparse/csgraph/_shortest_path.cpython-312-darwin.so +0 -0
  399. scipy/sparse/csgraph/_tools.cpython-312-darwin.so +0 -0
  400. scipy/sparse/csgraph/_traversal.cpython-312-darwin.so +0 -0
  401. scipy/sparse/csgraph/tests/test_matching.py +14 -2
  402. scipy/sparse/csgraph/tests/test_pydata_sparse.py +4 -1
  403. scipy/sparse/csgraph/tests/test_shortest_path.py +83 -27
  404. scipy/sparse/csr.py +0 -5
  405. scipy/sparse/data.py +1 -6
  406. scipy/sparse/dia.py +0 -7
  407. scipy/sparse/dok.py +0 -10
  408. scipy/sparse/linalg/_dsolve/_superlu.cpython-312-darwin.so +0 -0
  409. scipy/sparse/linalg/_dsolve/linsolve.py +9 -0
  410. scipy/sparse/linalg/_dsolve/tests/test_linsolve.py +35 -28
  411. scipy/sparse/linalg/_eigen/arpack/_arpack.cpython-312-darwin.so +0 -0
  412. scipy/sparse/linalg/_eigen/arpack/arpack.py +23 -17
  413. scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py +6 -6
  414. scipy/sparse/linalg/_interface.py +17 -18
  415. scipy/sparse/linalg/_isolve/_gcrotmk.py +4 -4
  416. scipy/sparse/linalg/_isolve/iterative.py +51 -45
  417. scipy/sparse/linalg/_isolve/lgmres.py +6 -6
  418. scipy/sparse/linalg/_isolve/minres.py +5 -5
  419. scipy/sparse/linalg/_isolve/tfqmr.py +7 -7
  420. scipy/sparse/linalg/_isolve/utils.py +2 -8
  421. scipy/sparse/linalg/_matfuncs.py +1 -1
  422. scipy/sparse/linalg/_norm.py +1 -1
  423. scipy/sparse/linalg/_propack/_cpropack.cpython-312-darwin.so +0 -0
  424. scipy/sparse/linalg/_propack/_dpropack.cpython-312-darwin.so +0 -0
  425. scipy/sparse/linalg/_propack/_spropack.cpython-312-darwin.so +0 -0
  426. scipy/sparse/linalg/_propack/_zpropack.cpython-312-darwin.so +0 -0
  427. scipy/sparse/linalg/_special_sparse_arrays.py +39 -38
  428. scipy/sparse/linalg/tests/test_pydata_sparse.py +14 -0
  429. scipy/sparse/tests/test_arithmetic1d.py +5 -2
  430. scipy/sparse/tests/test_base.py +214 -42
  431. scipy/sparse/tests/test_common1d.py +7 -7
  432. scipy/sparse/tests/test_construct.py +1 -1
  433. scipy/sparse/tests/test_coo.py +272 -4
  434. scipy/sparse/tests/test_sparsetools.py +5 -0
  435. scipy/sparse/tests/test_sputils.py +36 -7
  436. scipy/spatial/_ckdtree.cpython-312-darwin.so +0 -0
  437. scipy/spatial/_distance_pybind.cpython-312-darwin.so +0 -0
  438. scipy/spatial/_distance_wrap.cpython-312-darwin.so +0 -0
  439. scipy/spatial/_hausdorff.cpython-312-darwin.so +0 -0
  440. scipy/spatial/_qhull.cpython-312-darwin.so +0 -0
  441. scipy/spatial/_voronoi.cpython-312-darwin.so +0 -0
  442. scipy/spatial/distance.py +49 -42
  443. scipy/spatial/tests/test_distance.py +15 -1
  444. scipy/spatial/tests/test_kdtree.py +1 -0
  445. scipy/spatial/tests/test_qhull.py +7 -2
  446. scipy/spatial/transform/__init__.py +5 -3
  447. scipy/spatial/transform/_rigid_transform.cpython-312-darwin.so +0 -0
  448. scipy/spatial/transform/_rotation.cpython-312-darwin.so +0 -0
  449. scipy/spatial/transform/tests/test_rigid_transform.py +1221 -0
  450. scipy/spatial/transform/tests/test_rotation.py +1213 -832
  451. scipy/spatial/transform/tests/test_rotation_groups.py +3 -3
  452. scipy/spatial/transform/tests/test_rotation_spline.py +29 -8
  453. scipy/special/__init__.py +1 -47
  454. scipy/special/_add_newdocs.py +34 -772
  455. scipy/special/_basic.py +22 -25
  456. scipy/special/_comb.cpython-312-darwin.so +0 -0
  457. scipy/special/_ellip_harm_2.cpython-312-darwin.so +0 -0
  458. scipy/special/_gufuncs.cpython-312-darwin.so +0 -0
  459. scipy/special/_logsumexp.py +67 -58
  460. scipy/special/_orthogonal.pyi +1 -1
  461. scipy/special/_specfun.cpython-312-darwin.so +0 -0
  462. scipy/special/_special_ufuncs.cpython-312-darwin.so +0 -0
  463. scipy/special/_spherical_bessel.py +4 -4
  464. scipy/special/_support_alternative_backends.py +212 -119
  465. scipy/special/_test_internal.cpython-312-darwin.so +0 -0
  466. scipy/special/_testutils.py +4 -4
  467. scipy/special/_ufuncs.cpython-312-darwin.so +0 -0
  468. scipy/special/_ufuncs.pyi +1 -0
  469. scipy/special/_ufuncs.pyx +215 -1400
  470. scipy/special/_ufuncs_cxx.cpython-312-darwin.so +0 -0
  471. scipy/special/_ufuncs_cxx.pxd +2 -15
  472. scipy/special/_ufuncs_cxx.pyx +5 -44
  473. scipy/special/_ufuncs_cxx_defs.h +2 -16
  474. scipy/special/_ufuncs_defs.h +0 -8
  475. scipy/special/cython_special.cpython-312-darwin.so +0 -0
  476. scipy/special/cython_special.pxd +1 -1
  477. scipy/special/tests/_cython_examples/meson.build +10 -1
  478. scipy/special/tests/test_basic.py +153 -20
  479. scipy/special/tests/test_boost_ufuncs.py +3 -0
  480. scipy/special/tests/test_cdflib.py +35 -11
  481. scipy/special/tests/test_gammainc.py +16 -0
  482. scipy/special/tests/test_hyp2f1.py +2 -2
  483. scipy/special/tests/test_log1mexp.py +85 -0
  484. scipy/special/tests/test_logsumexp.py +206 -64
  485. scipy/special/tests/test_mpmath.py +1 -0
  486. scipy/special/tests/test_nan_inputs.py +1 -1
  487. scipy/special/tests/test_orthogonal.py +17 -18
  488. scipy/special/tests/test_sf_error.py +3 -2
  489. scipy/special/tests/test_sph_harm.py +6 -7
  490. scipy/special/tests/test_support_alternative_backends.py +211 -76
  491. scipy/stats/__init__.py +4 -1
  492. scipy/stats/_ansari_swilk_statistics.cpython-312-darwin.so +0 -0
  493. scipy/stats/_axis_nan_policy.py +5 -12
  494. scipy/stats/_biasedurn.cpython-312-darwin.so +0 -0
  495. scipy/stats/_continued_fraction.py +387 -0
  496. scipy/stats/_continuous_distns.py +277 -310
  497. scipy/stats/_correlation.py +1 -1
  498. scipy/stats/_covariance.py +6 -3
  499. scipy/stats/_discrete_distns.py +39 -32
  500. scipy/stats/_distn_infrastructure.py +39 -12
  501. scipy/stats/_distribution_infrastructure.py +900 -238
  502. scipy/stats/_entropy.py +9 -10
  503. scipy/{_lib → stats}/_finite_differences.py +1 -1
  504. scipy/stats/_hypotests.py +83 -50
  505. scipy/stats/_kde.py +53 -49
  506. scipy/stats/_ksstats.py +1 -1
  507. scipy/stats/_levy_stable/__init__.py +7 -15
  508. scipy/stats/_levy_stable/levyst.cpython-312-darwin.so +0 -0
  509. scipy/stats/_morestats.py +118 -73
  510. scipy/stats/_mstats_basic.py +13 -17
  511. scipy/stats/_mstats_extras.py +8 -8
  512. scipy/stats/_multivariate.py +89 -113
  513. scipy/stats/_new_distributions.py +97 -20
  514. scipy/stats/_page_trend_test.py +12 -5
  515. scipy/stats/_probability_distribution.py +265 -43
  516. scipy/stats/_qmc.py +14 -9
  517. scipy/stats/_qmc_cy.cpython-312-darwin.so +0 -0
  518. scipy/stats/_qmvnt.py +16 -95
  519. scipy/stats/_qmvnt_cy.cpython-312-darwin.so +0 -0
  520. scipy/stats/_quantile.py +335 -0
  521. scipy/stats/_rcont/rcont.cpython-312-darwin.so +0 -0
  522. scipy/stats/_resampling.py +4 -29
  523. scipy/stats/_sampling.py +1 -1
  524. scipy/stats/_sobol.cpython-312-darwin.so +0 -0
  525. scipy/stats/_stats.cpython-312-darwin.so +0 -0
  526. scipy/stats/_stats_mstats_common.py +21 -2
  527. scipy/stats/_stats_py.py +550 -476
  528. scipy/stats/_stats_pythran.cpython-312-darwin.so +0 -0
  529. scipy/stats/_unuran/unuran_wrapper.cpython-312-darwin.so +0 -0
  530. scipy/stats/_unuran/unuran_wrapper.pyi +2 -1
  531. scipy/stats/_variation.py +6 -8
  532. scipy/stats/_wilcoxon.py +13 -7
  533. scipy/stats/tests/common_tests.py +6 -4
  534. scipy/stats/tests/test_axis_nan_policy.py +62 -24
  535. scipy/stats/tests/test_continued_fraction.py +173 -0
  536. scipy/stats/tests/test_continuous.py +379 -60
  537. scipy/stats/tests/test_continuous_basic.py +18 -12
  538. scipy/stats/tests/test_discrete_basic.py +14 -8
  539. scipy/stats/tests/test_discrete_distns.py +16 -16
  540. scipy/stats/tests/test_distributions.py +95 -75
  541. scipy/stats/tests/test_entropy.py +40 -48
  542. scipy/stats/tests/test_fit.py +4 -3
  543. scipy/stats/tests/test_hypotests.py +153 -24
  544. scipy/stats/tests/test_kdeoth.py +109 -41
  545. scipy/stats/tests/test_marray.py +289 -0
  546. scipy/stats/tests/test_morestats.py +79 -47
  547. scipy/stats/tests/test_mstats_basic.py +3 -3
  548. scipy/stats/tests/test_multivariate.py +434 -83
  549. scipy/stats/tests/test_qmc.py +13 -10
  550. scipy/stats/tests/test_quantile.py +199 -0
  551. scipy/stats/tests/test_rank.py +119 -112
  552. scipy/stats/tests/test_resampling.py +47 -56
  553. scipy/stats/tests/test_sampling.py +9 -4
  554. scipy/stats/tests/test_stats.py +799 -939
  555. scipy/stats/tests/test_variation.py +8 -6
  556. scipy/version.py +2 -2
  557. {scipy-1.15.3.dist-info → scipy-1.16.0rc2.dist-info}/LICENSE.txt +4 -4
  558. {scipy-1.15.3.dist-info → scipy-1.16.0rc2.dist-info}/METADATA +11 -11
  559. {scipy-1.15.3.dist-info → scipy-1.16.0rc2.dist-info}/RECORD +560 -567
  560. scipy-1.16.0rc2.dist-info/WHEEL +6 -0
  561. scipy/_lib/array_api_extra/_funcs.py +0 -484
  562. scipy/_lib/array_api_extra/_typing.py +0 -8
  563. scipy/interpolate/_bspl.cpython-312-darwin.so +0 -0
  564. scipy/optimize/_cobyla.cpython-312-darwin.so +0 -0
  565. scipy/optimize/_cython_nnls.cpython-312-darwin.so +0 -0
  566. scipy/optimize/_slsqp.cpython-312-darwin.so +0 -0
  567. scipy/spatial/qhull_src/COPYING.txt +0 -38
  568. scipy/special/libsf_error_state.dylib +0 -0
  569. scipy/special/tests/test_log_softmax.py +0 -109
  570. scipy/special/tests/test_xsf_cuda.py +0 -114
  571. scipy/special/xsf/binom.h +0 -89
  572. scipy/special/xsf/cdflib.h +0 -100
  573. scipy/special/xsf/cephes/airy.h +0 -307
  574. scipy/special/xsf/cephes/besselpoly.h +0 -51
  575. scipy/special/xsf/cephes/beta.h +0 -257
  576. scipy/special/xsf/cephes/cbrt.h +0 -131
  577. scipy/special/xsf/cephes/chbevl.h +0 -85
  578. scipy/special/xsf/cephes/chdtr.h +0 -193
  579. scipy/special/xsf/cephes/const.h +0 -87
  580. scipy/special/xsf/cephes/ellie.h +0 -293
  581. scipy/special/xsf/cephes/ellik.h +0 -251
  582. scipy/special/xsf/cephes/ellpe.h +0 -107
  583. scipy/special/xsf/cephes/ellpk.h +0 -117
  584. scipy/special/xsf/cephes/expn.h +0 -260
  585. scipy/special/xsf/cephes/gamma.h +0 -398
  586. scipy/special/xsf/cephes/hyp2f1.h +0 -596
  587. scipy/special/xsf/cephes/hyperg.h +0 -361
  588. scipy/special/xsf/cephes/i0.h +0 -149
  589. scipy/special/xsf/cephes/i1.h +0 -158
  590. scipy/special/xsf/cephes/igam.h +0 -421
  591. scipy/special/xsf/cephes/igam_asymp_coeff.h +0 -195
  592. scipy/special/xsf/cephes/igami.h +0 -313
  593. scipy/special/xsf/cephes/j0.h +0 -225
  594. scipy/special/xsf/cephes/j1.h +0 -198
  595. scipy/special/xsf/cephes/jv.h +0 -715
  596. scipy/special/xsf/cephes/k0.h +0 -164
  597. scipy/special/xsf/cephes/k1.h +0 -163
  598. scipy/special/xsf/cephes/kn.h +0 -243
  599. scipy/special/xsf/cephes/lanczos.h +0 -112
  600. scipy/special/xsf/cephes/ndtr.h +0 -275
  601. scipy/special/xsf/cephes/poch.h +0 -85
  602. scipy/special/xsf/cephes/polevl.h +0 -167
  603. scipy/special/xsf/cephes/psi.h +0 -194
  604. scipy/special/xsf/cephes/rgamma.h +0 -111
  605. scipy/special/xsf/cephes/scipy_iv.h +0 -811
  606. scipy/special/xsf/cephes/shichi.h +0 -248
  607. scipy/special/xsf/cephes/sici.h +0 -224
  608. scipy/special/xsf/cephes/sindg.h +0 -221
  609. scipy/special/xsf/cephes/tandg.h +0 -139
  610. scipy/special/xsf/cephes/trig.h +0 -58
  611. scipy/special/xsf/cephes/unity.h +0 -186
  612. scipy/special/xsf/cephes/zeta.h +0 -172
  613. scipy/special/xsf/config.h +0 -304
  614. scipy/special/xsf/digamma.h +0 -205
  615. scipy/special/xsf/error.h +0 -57
  616. scipy/special/xsf/evalpoly.h +0 -47
  617. scipy/special/xsf/expint.h +0 -266
  618. scipy/special/xsf/hyp2f1.h +0 -694
  619. scipy/special/xsf/iv_ratio.h +0 -173
  620. scipy/special/xsf/lambertw.h +0 -150
  621. scipy/special/xsf/loggamma.h +0 -163
  622. scipy/special/xsf/sici.h +0 -200
  623. scipy/special/xsf/tools.h +0 -427
  624. scipy/special/xsf/trig.h +0 -164
  625. scipy/special/xsf/wright_bessel.h +0 -843
  626. scipy/special/xsf/zlog1.h +0 -35
  627. scipy/stats/_mvn.cpython-312-darwin.so +0 -0
  628. scipy-1.15.3.dist-info/WHEEL +0 -4
@@ -5,34 +5,97 @@ Functions which start with an underscore are for internal use only but helpers
5
5
  that are in __all__ are intended as additional helper functions for use by end
6
6
  users of the compat library.
7
7
  """
8
+
8
9
  from __future__ import annotations
9
10
 
10
- from typing import TYPE_CHECKING
11
+ import inspect
12
+ import math
13
+ import sys
14
+ import warnings
15
+ from collections.abc import Collection, Hashable
16
+ from functools import lru_cache
17
+ from typing import (
18
+ TYPE_CHECKING,
19
+ Any,
20
+ Final,
21
+ Literal,
22
+ SupportsIndex,
23
+ TypeAlias,
24
+ TypeGuard,
25
+ TypeVar,
26
+ cast,
27
+ overload,
28
+ )
29
+
30
+ from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace
11
31
 
12
32
  if TYPE_CHECKING:
13
- from typing import Optional, Union, Any
14
- from ._typing import Array, Device
15
33
 
16
- import sys
17
- import math
18
- import inspect
19
- import warnings
34
+ import dask.array as da
35
+ import jax
36
+ import ndonnx as ndx
37
+ import numpy as np
38
+ import numpy.typing as npt
39
+ import sparse # pyright: ignore[reportMissingTypeStubs]
40
+ import torch
41
+
42
+ # TODO: import from typing (requires Python >=3.13)
43
+ from typing_extensions import TypeIs, TypeVar
44
+
45
+ _SizeT = TypeVar("_SizeT", bound = int | None)
20
46
 
21
- def _is_jax_zero_gradient_array(x):
47
+ _ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
48
+ _CupyArray: TypeAlias = Any # cupy has no py.typed
49
+
50
+ _ArrayApiObj: TypeAlias = (
51
+ npt.NDArray[Any]
52
+ | da.Array
53
+ | jax.Array
54
+ | ndx.Array
55
+ | sparse.SparseArray
56
+ | torch.Tensor
57
+ | SupportsArrayNamespace[Any]
58
+ | _CupyArray
59
+ )
60
+
61
+ _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
62
+ _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
63
+
64
+
65
+ @lru_cache(100)
66
+ def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
67
+ try:
68
+ mod = sys.modules[modname]
69
+ except KeyError:
70
+ return False
71
+ parent_cls = getattr(mod, clsname)
72
+ return issubclass(cls, parent_cls)
73
+
74
+
75
+ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
22
76
  """Return True if `x` is a zero-gradient array.
23
77
 
24
78
  These arrays are a design quirk of Jax that may one day be removed.
25
79
  See https://github.com/google/jax/issues/20620.
26
80
  """
27
- if 'numpy' not in sys.modules or 'jax' not in sys.modules:
81
+ # Fast exit
82
+ try:
83
+ dtype = x.dtype # type: ignore[attr-defined]
84
+ except AttributeError:
85
+ return False
86
+ cls = cast(Hashable, type(dtype))
87
+ if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"):
88
+ return False
89
+
90
+ if "jax" not in sys.modules:
28
91
  return False
29
92
 
30
- import numpy as np
31
93
  import jax
94
+ # jax.float0 is a np.dtype([('float0', 'V')])
95
+ return dtype == jax.float0
32
96
 
33
- return isinstance(x, np.ndarray) and x.dtype == jax.float0
34
97
 
35
- def is_numpy_array(x):
98
+ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
36
99
  """
37
100
  Return True if `x` is a NumPy array.
38
101
 
@@ -53,17 +116,15 @@ def is_numpy_array(x):
53
116
  is_jax_array
54
117
  is_pydata_sparse_array
55
118
  """
56
- # Avoid importing NumPy if it isn't already
57
- if 'numpy' not in sys.modules:
58
- return False
59
-
60
- import numpy as np
61
-
62
119
  # TODO: Should we reject ndarray subclasses?
63
- return (isinstance(x, (np.ndarray, np.generic))
64
- and not _is_jax_zero_gradient_array(x))
120
+ cls = cast(Hashable, type(x))
121
+ return (
122
+ _issubclass_fast(cls, "numpy", "ndarray")
123
+ or _issubclass_fast(cls, "numpy", "generic")
124
+ ) and not _is_jax_zero_gradient_array(x)
65
125
 
66
- def is_cupy_array(x):
126
+
127
+ def is_cupy_array(x: object) -> bool:
67
128
  """
68
129
  Return True if `x` is a CuPy array.
69
130
 
@@ -84,16 +145,11 @@ def is_cupy_array(x):
84
145
  is_jax_array
85
146
  is_pydata_sparse_array
86
147
  """
87
- # Avoid importing CuPy if it isn't already
88
- if 'cupy' not in sys.modules:
89
- return False
90
-
91
- import cupy as cp
148
+ cls = cast(Hashable, type(x))
149
+ return _issubclass_fast(cls, "cupy", "ndarray")
92
150
 
93
- # TODO: Should we reject ndarray subclasses?
94
- return isinstance(x, (cp.ndarray, cp.generic))
95
151
 
96
- def is_torch_array(x):
152
+ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
97
153
  """
98
154
  Return True if `x` is a PyTorch tensor.
99
155
 
@@ -111,16 +167,11 @@ def is_torch_array(x):
111
167
  is_jax_array
112
168
  is_pydata_sparse_array
113
169
  """
114
- # Avoid importing torch if it isn't already
115
- if 'torch' not in sys.modules:
116
- return False
170
+ cls = cast(Hashable, type(x))
171
+ return _issubclass_fast(cls, "torch", "Tensor")
117
172
 
118
- import torch
119
173
 
120
- # TODO: Should we reject ndarray subclasses?
121
- return isinstance(x, torch.Tensor)
122
-
123
- def is_ndonnx_array(x):
174
+ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
124
175
  """
125
176
  Return True if `x` is a ndonnx Array.
126
177
 
@@ -139,15 +190,11 @@ def is_ndonnx_array(x):
139
190
  is_jax_array
140
191
  is_pydata_sparse_array
141
192
  """
142
- # Avoid importing torch if it isn't already
143
- if 'ndonnx' not in sys.modules:
144
- return False
145
-
146
- import ndonnx as ndx
193
+ cls = cast(Hashable, type(x))
194
+ return _issubclass_fast(cls, "ndonnx", "Array")
147
195
 
148
- return isinstance(x, ndx.Array)
149
196
 
150
- def is_dask_array(x):
197
+ def is_dask_array(x: object) -> TypeIs[da.Array]:
151
198
  """
152
199
  Return True if `x` is a dask.array Array.
153
200
 
@@ -166,15 +213,11 @@ def is_dask_array(x):
166
213
  is_jax_array
167
214
  is_pydata_sparse_array
168
215
  """
169
- # Avoid importing dask if it isn't already
170
- if 'dask.array' not in sys.modules:
171
- return False
172
-
173
- import dask.array
216
+ cls = cast(Hashable, type(x))
217
+ return _issubclass_fast(cls, "dask.array", "Array")
174
218
 
175
- return isinstance(x, dask.array.Array)
176
219
 
177
- def is_jax_array(x):
220
+ def is_jax_array(x: object) -> TypeIs[jax.Array]:
178
221
  """
179
222
  Return True if `x` is a JAX array.
180
223
 
@@ -194,15 +237,11 @@ def is_jax_array(x):
194
237
  is_dask_array
195
238
  is_pydata_sparse_array
196
239
  """
197
- # Avoid importing jax if it isn't already
198
- if 'jax' not in sys.modules:
199
- return False
240
+ cls = cast(Hashable, type(x))
241
+ return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
200
242
 
201
- import jax
202
-
203
- return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204
243
 
205
- def is_pydata_sparse_array(x) -> bool:
244
+ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
206
245
  """
207
246
  Return True if `x` is an array from the `sparse` package.
208
247
 
@@ -222,16 +261,12 @@ def is_pydata_sparse_array(x) -> bool:
222
261
  is_dask_array
223
262
  is_jax_array
224
263
  """
225
- # Avoid importing jax if it isn't already
226
- if 'sparse' not in sys.modules:
227
- return False
228
-
229
- import sparse
230
-
231
264
  # TODO: Account for other backends.
232
- return isinstance(x, sparse.SparseArray)
265
+ cls = cast(Hashable, type(x))
266
+ return _issubclass_fast(cls, "sparse", "SparseArray")
267
+
233
268
 
234
- def is_array_api_obj(x):
269
+ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
235
270
  """
236
271
  Return True if `x` is an array API compatible array object.
237
272
 
@@ -246,19 +281,34 @@ def is_array_api_obj(x):
246
281
  is_dask_array
247
282
  is_jax_array
248
283
  """
249
- return is_numpy_array(x) \
250
- or is_cupy_array(x) \
251
- or is_torch_array(x) \
252
- or is_dask_array(x) \
253
- or is_jax_array(x) \
254
- or is_pydata_sparse_array(x) \
255
- or hasattr(x, '__array_namespace__')
256
-
257
- def _compat_module_name():
258
- assert __name__.endswith('.common._helpers')
259
- return __name__.removesuffix('.common._helpers')
260
-
261
- def is_numpy_namespace(xp) -> bool:
284
+ return (
285
+ hasattr(x, '__array_namespace__')
286
+ or _is_array_api_cls(cast(Hashable, type(x)))
287
+ )
288
+
289
+
290
+ @lru_cache(100)
291
+ def _is_array_api_cls(cls: type) -> bool:
292
+ return (
293
+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
294
+ _issubclass_fast(cls, "numpy", "ndarray")
295
+ or _issubclass_fast(cls, "numpy", "generic")
296
+ or _issubclass_fast(cls, "cupy", "ndarray")
297
+ or _issubclass_fast(cls, "torch", "Tensor")
298
+ or _issubclass_fast(cls, "dask.array", "Array")
299
+ or _issubclass_fast(cls, "sparse", "SparseArray")
300
+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
301
+ or _issubclass_fast(cls, "jax", "Array")
302
+ )
303
+
304
+
305
+ def _compat_module_name() -> str:
306
+ assert __name__.endswith(".common._helpers")
307
+ return __name__.removesuffix(".common._helpers")
308
+
309
+
310
+ @lru_cache(100)
311
+ def is_numpy_namespace(xp: Namespace) -> bool:
262
312
  """
263
313
  Returns True if `xp` is a NumPy namespace.
264
314
 
@@ -276,9 +326,11 @@ def is_numpy_namespace(xp) -> bool:
276
326
  is_pydata_sparse_namespace
277
327
  is_array_api_strict_namespace
278
328
  """
279
- return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
329
+ return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}
280
330
 
281
- def is_cupy_namespace(xp) -> bool:
331
+
332
+ @lru_cache(100)
333
+ def is_cupy_namespace(xp: Namespace) -> bool:
282
334
  """
283
335
  Returns True if `xp` is a CuPy namespace.
284
336
 
@@ -296,9 +348,11 @@ def is_cupy_namespace(xp) -> bool:
296
348
  is_pydata_sparse_namespace
297
349
  is_array_api_strict_namespace
298
350
  """
299
- return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
351
+ return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}
352
+
300
353
 
301
- def is_torch_namespace(xp) -> bool:
354
+ @lru_cache(100)
355
+ def is_torch_namespace(xp: Namespace) -> bool:
302
356
  """
303
357
  Returns True if `xp` is a PyTorch namespace.
304
358
 
@@ -316,10 +370,10 @@ def is_torch_namespace(xp) -> bool:
316
370
  is_pydata_sparse_namespace
317
371
  is_array_api_strict_namespace
318
372
  """
319
- return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
373
+ return xp.__name__ in {"torch", _compat_module_name() + ".torch"}
320
374
 
321
375
 
322
- def is_ndonnx_namespace(xp):
376
+ def is_ndonnx_namespace(xp: Namespace) -> bool:
323
377
  """
324
378
  Returns True if `xp` is an NDONNX namespace.
325
379
 
@@ -335,9 +389,11 @@ def is_ndonnx_namespace(xp):
335
389
  is_pydata_sparse_namespace
336
390
  is_array_api_strict_namespace
337
391
  """
338
- return xp.__name__ == 'ndonnx'
392
+ return xp.__name__ == "ndonnx"
339
393
 
340
- def is_dask_namespace(xp):
394
+
395
+ @lru_cache(100)
396
+ def is_dask_namespace(xp: Namespace) -> bool:
341
397
  """
342
398
  Returns True if `xp` is a Dask namespace.
343
399
 
@@ -355,9 +411,10 @@ def is_dask_namespace(xp):
355
411
  is_pydata_sparse_namespace
356
412
  is_array_api_strict_namespace
357
413
  """
358
- return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
414
+ return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"}
415
+
359
416
 
360
- def is_jax_namespace(xp):
417
+ def is_jax_namespace(xp: Namespace) -> bool:
361
418
  """
362
419
  Returns True if `xp` is a JAX namespace.
363
420
 
@@ -376,9 +433,10 @@ def is_jax_namespace(xp):
376
433
  is_pydata_sparse_namespace
377
434
  is_array_api_strict_namespace
378
435
  """
379
- return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
436
+ return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"}
380
437
 
381
- def is_pydata_sparse_namespace(xp):
438
+
439
+ def is_pydata_sparse_namespace(xp: Namespace) -> bool:
382
440
  """
383
441
  Returns True if `xp` is a pydata/sparse namespace.
384
442
 
@@ -394,9 +452,10 @@ def is_pydata_sparse_namespace(xp):
394
452
  is_jax_namespace
395
453
  is_array_api_strict_namespace
396
454
  """
397
- return xp.__name__ == 'sparse'
455
+ return xp.__name__ == "sparse"
456
+
398
457
 
399
- def is_array_api_strict_namespace(xp):
458
+ def is_array_api_strict_namespace(xp: Namespace) -> bool:
400
459
  """
401
460
  Returns True if `xp` is an array-api-strict namespace.
402
461
 
@@ -412,27 +471,37 @@ def is_array_api_strict_namespace(xp):
412
471
  is_jax_namespace
413
472
  is_pydata_sparse_namespace
414
473
  """
415
- return xp.__name__ == 'array_api_strict'
474
+ return xp.__name__ == "array_api_strict"
475
+
416
476
 
417
- def _check_api_version(api_version):
418
- if api_version in ['2021.12', '2022.12']:
419
- warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
420
- elif api_version is not None and api_version not in ['2021.12', '2022.12',
421
- '2023.12']:
422
- raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
477
+ def _check_api_version(api_version: str | None) -> None:
478
+ if api_version in _API_VERSIONS_OLD:
479
+ warnings.warn(
480
+ f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12"
481
+ )
482
+ elif api_version is not None and api_version not in _API_VERSIONS:
483
+ raise ValueError(
484
+ "Only the 2024.12 version of the array API specification is currently supported"
485
+ )
423
486
 
424
- def array_namespace(*xs, api_version=None, use_compat=None):
487
+
488
+ def array_namespace(
489
+ *xs: Array | complex | None,
490
+ api_version: str | None = None,
491
+ use_compat: bool | None = None,
492
+ ) -> Namespace:
425
493
  """
426
494
  Get the array API compatible namespace for the arrays `xs`.
427
495
 
428
496
  Parameters
429
497
  ----------
430
498
  xs: arrays
431
- one or more arrays.
499
+ one or more arrays. xs can also be Python scalars (bool, int, float,
500
+ complex, or None), which are ignored.
432
501
 
433
502
  api_version: str
434
503
  The newest version of the spec that you need support for (currently
435
- the compat library wrapped APIs support v2023.12).
504
+ the compat library wrapped APIs support v2024.12).
436
505
 
437
506
  use_compat: bool or None
438
507
  If None (the default), the native namespace will be returned if it is
@@ -489,11 +558,13 @@ def array_namespace(*xs, api_version=None, use_compat=None):
489
558
 
490
559
  _use_compat = use_compat in [None, True]
491
560
 
492
- namespaces = set()
561
+ namespaces: set[Namespace] = set()
493
562
  for x in xs:
494
563
  if is_numpy_array(x):
495
- from .. import numpy as numpy_namespace
496
564
  import numpy as np
565
+
566
+ from .. import numpy as numpy_namespace
567
+
497
568
  if use_compat is True:
498
569
  _check_api_version(api_version)
499
570
  namespaces.add(numpy_namespace)
@@ -507,25 +578,31 @@ def array_namespace(*xs, api_version=None, use_compat=None):
507
578
  if _use_compat:
508
579
  _check_api_version(api_version)
509
580
  from .. import cupy as cupy_namespace
581
+
510
582
  namespaces.add(cupy_namespace)
511
583
  else:
512
- import cupy as cp
584
+ import cupy as cp # pyright: ignore[reportMissingTypeStubs]
585
+
513
586
  namespaces.add(cp)
514
587
  elif is_torch_array(x):
515
588
  if _use_compat:
516
589
  _check_api_version(api_version)
517
590
  from .. import torch as torch_namespace
591
+
518
592
  namespaces.add(torch_namespace)
519
593
  else:
520
594
  import torch
595
+
521
596
  namespaces.add(torch)
522
597
  elif is_dask_array(x):
523
598
  if _use_compat:
524
599
  _check_api_version(api_version)
525
600
  from ..dask import array as dask_namespace
601
+
526
602
  namespaces.add(dask_namespace)
527
603
  else:
528
604
  import dask.array as da
605
+
529
606
  namespaces.add(da)
530
607
  elif is_jax_array(x):
531
608
  if use_compat is True:
@@ -537,24 +614,30 @@ def array_namespace(*xs, api_version=None, use_compat=None):
537
614
  # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
538
615
  # For older JAX versions, it is available via jax.experimental.array_api.
539
616
  import jax.numpy
617
+
540
618
  if hasattr(jax.numpy, "__array_api_version__"):
541
619
  jnp = jax.numpy
542
620
  else:
543
- import jax.experimental.array_api as jnp
621
+ import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports]
544
622
  namespaces.add(jnp)
545
623
  elif is_pydata_sparse_array(x):
546
624
  if use_compat is True:
547
625
  _check_api_version(api_version)
548
626
  raise ValueError("`sparse` does not have an array-api-compat wrapper")
549
627
  else:
550
- import sparse
628
+ import sparse # pyright: ignore[reportMissingTypeStubs]
551
629
  # `sparse` is already an array namespace. We do not have a wrapper
552
630
  # submodule for it.
553
631
  namespaces.add(sparse)
554
- elif hasattr(x, '__array_namespace__'):
632
+ elif hasattr(x, "__array_namespace__"):
555
633
  if use_compat is True:
556
- raise ValueError("The given array does not have an array-api-compat wrapper")
634
+ raise ValueError(
635
+ "The given array does not have an array-api-compat wrapper"
636
+ )
637
+ x = cast("SupportsArrayNamespace[Any]", x)
557
638
  namespaces.add(x.__array_namespace__(api_version=api_version))
639
+ elif isinstance(x, (bool, int, float, complex, type(None))):
640
+ continue
558
641
  else:
559
642
  # TODO: Support Python scalars?
560
643
  raise TypeError(f"{type(x).__name__} is not a supported array type")
@@ -565,34 +648,55 @@ def array_namespace(*xs, api_version=None, use_compat=None):
565
648
  if len(namespaces) != 1:
566
649
  raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")
567
650
 
568
- xp, = namespaces
651
+ (xp,) = namespaces
569
652
 
570
653
  return xp
571
654
 
655
+
572
656
  # backwards compatibility alias
573
657
  get_namespace = array_namespace
574
658
 
575
- def _check_device(xp, device):
576
- if xp == sys.modules.get('numpy'):
577
- if device not in ["cpu", None]:
659
+
660
+ def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
661
+ """
662
+ Validate dummy device on device-less array backends.
663
+
664
+ Notes
665
+ -----
666
+ This function is also invoked by CuPy, which does have multiple devices
667
+ if there are multiple GPUs available.
668
+ However, CuPy multi-device support is currently impossible
669
+ without using the global device or a context manager:
670
+
671
+ https://github.com/data-apis/array-api-compat/pull/293
672
+ """
673
+ if bare_xp is sys.modules.get("numpy"):
674
+ if device not in ("cpu", None):
578
675
  raise ValueError(f"Unsupported device for NumPy: {device!r}")
579
676
 
677
+ elif bare_xp is sys.modules.get("dask.array"):
678
+ if device not in ("cpu", _DASK_DEVICE, None):
679
+ raise ValueError(f"Unsupported device for Dask: {device!r}")
680
+
681
+
580
682
  # Placeholder object to represent the dask device
581
683
  # when the array backend is not the CPU.
582
684
  # (since it is not easy to tell which device a dask array is on)
583
685
  class _dask_device:
584
- def __repr__(self):
686
+ def __repr__(self) -> Literal["DASK_DEVICE"]:
585
687
  return "DASK_DEVICE"
586
688
 
689
+
587
690
  _DASK_DEVICE = _dask_device()
588
691
 
692
+
589
693
  # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
590
694
  # or cupy.ndarray. They are not included in array objects of this library
591
695
  # because this library just reuses the respective ndarray classes without
592
696
  # wrapping or subclassing them. These helper functions can be used instead of
593
697
  # the wrapper functions for libraries that need to support both NumPy/CuPy and
594
698
  # other libraries that use devices.
595
- def device(x: Array, /) -> Device:
699
+ def device(x: _ArrayApiObj, /) -> Device:
596
700
  """
597
701
  Hardware device the array data resides on.
598
702
 
@@ -627,86 +731,86 @@ def device(x: Array, /) -> Device:
627
731
  if is_numpy_array(x):
628
732
  return "cpu"
629
733
  elif is_dask_array(x):
630
- # Peek at the metadata of the jax array to determine type
631
- try:
632
- import numpy as np
633
- if isinstance(x._meta, np.ndarray):
634
- # Must be on CPU since backed by numpy
635
- return "cpu"
636
- except ImportError:
637
- pass
734
+ # Peek at the metadata of the Dask array to determine type
735
+ if is_numpy_array(x._meta): # pyright: ignore
736
+ # Must be on CPU since backed by numpy
737
+ return "cpu"
638
738
  return _DASK_DEVICE
639
739
  elif is_jax_array(x):
640
- # JAX has .device() as a method, but it is being deprecated so that it
641
- # can become a property, in accordance with the standard. In order for
642
- # this function to not break when JAX makes the flip, we check for
643
- # both here.
644
- if inspect.ismethod(x.device):
645
- return x.device()
740
+ # FIXME Jitted JAX arrays do not have a device attribute
741
+ # https://github.com/jax-ml/jax/issues/26000
742
+ # Return None in this case. Note that this workaround breaks
743
+ # the standard and will result in new arrays being created on the
744
+ # default device instead of the same device as the input array(s).
745
+ x_device = getattr(x, "device", None)
746
+ # Older JAX releases had .device() as a method, which has been replaced
747
+ # with a property in accordance with the standard.
748
+ if inspect.ismethod(x_device):
749
+ return x_device()
646
750
  else:
647
- return x.device
751
+ return x_device
648
752
  elif is_pydata_sparse_array(x):
649
753
  # `sparse` will gain `.device`, so check for this first.
650
- x_device = getattr(x, 'device', None)
754
+ x_device = getattr(x, "device", None)
651
755
  if x_device is not None:
652
756
  return x_device
653
757
  # Everything but DOK has this attr.
654
758
  try:
655
- inner = x.data
759
+ inner = x.data # pyright: ignore
656
760
  except AttributeError:
657
761
  return "cpu"
658
762
  # Return the device of the constituent array
659
- return device(inner)
660
- return x.device
763
+ return device(inner) # pyright: ignore
764
+ return x.device # pyright: ignore
765
+
661
766
 
662
767
  # Prevent shadowing, used below
663
768
  _device = device
664
769
 
770
+
665
771
  # Based on cupy.array_api.Array.to_device
666
- def _cupy_to_device(x, device, /, stream=None):
772
+ def _cupy_to_device(
773
+ x: _CupyArray,
774
+ device: Device,
775
+ /,
776
+ stream: int | Any | None = None,
777
+ ) -> _CupyArray:
667
778
  import cupy as cp
668
- from cupy.cuda import Device as _Device
669
- from cupy.cuda import stream as stream_module
670
- from cupy_backends.cuda.api import runtime
671
779
 
672
- if device == x.device:
673
- return x
674
- elif device == "cpu":
780
+ if device == "cpu":
675
781
  # allowing us to use `to_device(x, "cpu")`
676
782
  # is useful for portable test swapping between
677
783
  # host and device backends
678
784
  return x.get()
679
- elif not isinstance(device, _Device):
680
- raise ValueError(f"Unsupported device {device!r}")
681
- else:
682
- # see cupy/cupy#5985 for the reason how we handle device/stream here
683
- prev_device = runtime.getDevice()
684
- prev_stream: stream_module.Stream = None
685
- if stream is not None:
686
- prev_stream = stream_module.get_current_stream()
687
- # stream can be an int as specified in __dlpack__, or a CuPy stream
688
- if isinstance(stream, int):
689
- stream = cp.cuda.ExternalStream(stream)
690
- elif isinstance(stream, cp.cuda.Stream):
691
- pass
692
- else:
693
- raise ValueError('the input stream is not recognized')
694
- stream.use()
695
- try:
696
- runtime.setDevice(device.id)
697
- arr = x.copy()
698
- finally:
699
- runtime.setDevice(prev_device)
700
- if stream is not None:
701
- prev_stream.use()
702
- return arr
703
-
704
- def _torch_to_device(x, device, /, stream=None):
785
+ if not isinstance(device, cp.cuda.Device):
786
+ raise TypeError(f"Unsupported device type {device!r}")
787
+
788
+ if stream is None:
789
+ with device:
790
+ return cp.asarray(x)
791
+
792
+ # stream can be an int as specified in __dlpack__, or a CuPy stream
793
+ if isinstance(stream, int):
794
+ stream = cp.cuda.ExternalStream(stream)
795
+ elif not isinstance(stream, cp.cuda.Stream):
796
+ raise TypeError(f"Unsupported stream type {stream!r}")
797
+
798
+ with device, stream:
799
+ return cp.asarray(x)
800
+
801
+
802
+ def _torch_to_device(
803
+ x: torch.Tensor,
804
+ device: torch.device | str | int,
805
+ /,
806
+ stream: None = None,
807
+ ) -> torch.Tensor:
705
808
  if stream is not None:
706
809
  raise NotImplementedError
707
810
  return x.to(device)
708
811
 
709
- def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
812
+
813
+ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array:
710
814
  """
711
815
  Copy the array from the device on which it currently resides to the specified ``device``.
712
816
 
@@ -726,7 +830,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
726
830
  a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
727
831
  section of the array API specification).
728
832
 
729
- stream: Optional[Union[int, Any]]
833
+ stream: int | Any | None
730
834
  stream object to use during copy. In addition to the types supported
731
835
  in ``array.__dlpack__``, implementations may choose to support any
732
836
  library-specific stream object with the caveat that any code using
@@ -758,45 +862,169 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
758
862
  if is_numpy_array(x):
759
863
  if stream is not None:
760
864
  raise ValueError("The stream argument to to_device() is not supported")
761
- if device == 'cpu':
865
+ if device == "cpu":
762
866
  return x
763
867
  raise ValueError(f"Unsupported device {device!r}")
764
868
  elif is_cupy_array(x):
765
869
  # cupy does not yet have to_device
766
870
  return _cupy_to_device(x, device, stream=stream)
767
871
  elif is_torch_array(x):
768
- return _torch_to_device(x, device, stream=stream)
872
+ return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType]
769
873
  elif is_dask_array(x):
770
874
  if stream is not None:
771
875
  raise ValueError("The stream argument to to_device() is not supported")
772
876
  # TODO: What if our array is on the GPU already?
773
- if device == 'cpu':
877
+ if device == "cpu":
774
878
  return x
775
879
  raise ValueError(f"Unsupported device {device!r}")
776
880
  elif is_jax_array(x):
777
881
  if not hasattr(x, "__array_namespace__"):
778
- # In JAX v0.4.31 and older, this import adds to_device method to x.
779
- import jax.experimental.array_api # noqa: F401
882
+ # In JAX v0.4.31 and older, this import adds to_device method to x...
883
+ import jax.experimental.array_api # noqa: F401 # pyright: ignore
884
+
885
+ # ... but only on eager JAX. It won't work inside jax.jit.
886
+ if not hasattr(x, "to_device"):
887
+ return x
780
888
  return x.to_device(device, stream=stream)
781
889
  elif is_pydata_sparse_array(x) and device == _device(x):
782
890
  # Perform trivial check to return the same array if
783
891
  # device is same instead of err-ing.
784
892
  return x
785
- return x.to_device(device, stream=stream)
893
+ return x.to_device(device, stream=stream) # pyright: ignore
786
894
 
787
- def size(x):
895
+
896
+ @overload
897
+ def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
898
+ @overload
899
+ def size(x: HasShape[Collection[None]]) -> None: ...
900
+ @overload
901
+ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
902
+ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
788
903
  """
789
904
  Return the total number of elements of x.
790
905
 
791
906
  This is equivalent to `x.size` according to the `standard
792
907
  <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
908
+
793
909
  This helper is included because PyTorch defines `size` in an
794
910
  :external+torch:meth:`incompatible way <torch.Tensor.size>`.
795
-
911
+ It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
912
+ the standard requires None.
796
913
  """
914
+ # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
797
915
  if None in x.shape:
798
916
  return None
799
- return math.prod(x.shape)
917
+ out = math.prod(cast("Collection[SupportsIndex]", x.shape))
918
+ # dask.array.Array.shape can contain NaN
919
+ return None if math.isnan(out) else out
920
+
921
+
922
+ @lru_cache(100)
923
+ def _is_writeable_cls(cls: type) -> bool | None:
924
+ if (
925
+ _issubclass_fast(cls, "numpy", "generic")
926
+ or _issubclass_fast(cls, "jax", "Array")
927
+ or _issubclass_fast(cls, "sparse", "SparseArray")
928
+ ):
929
+ return False
930
+ if _is_array_api_cls(cls):
931
+ return True
932
+ return None
933
+
934
+
935
+ def is_writeable_array(x: object) -> bool:
936
+ """
937
+ Return False if ``x.__setitem__`` is expected to raise; True otherwise.
938
+ Return False if `x` is not an array API compatible object.
939
+
940
+ Warning
941
+ -------
942
+ As there is no standard way to check if an array is writeable without actually
943
+ writing to it, this function blindly returns True for all unknown array types.
944
+ """
945
+ cls = cast(Hashable, type(x))
946
+ if _issubclass_fast(cls, "numpy", "ndarray"):
947
+ return cast("npt.NDArray", x).flags.writeable
948
+ res = _is_writeable_cls(cls)
949
+ if res is not None:
950
+ return res
951
+ return hasattr(x, '__array_namespace__')
952
+
953
+
954
+ @lru_cache(100)
955
+ def _is_lazy_cls(cls: type) -> bool | None:
956
+ if (
957
+ _issubclass_fast(cls, "numpy", "ndarray")
958
+ or _issubclass_fast(cls, "numpy", "generic")
959
+ or _issubclass_fast(cls, "cupy", "ndarray")
960
+ or _issubclass_fast(cls, "torch", "Tensor")
961
+ or _issubclass_fast(cls, "sparse", "SparseArray")
962
+ ):
963
+ return False
964
+ if (
965
+ _issubclass_fast(cls, "jax", "Array")
966
+ or _issubclass_fast(cls, "dask.array", "Array")
967
+ or _issubclass_fast(cls, "ndonnx", "Array")
968
+ ):
969
+ return True
970
+ return None
971
+
972
+
973
+ def is_lazy_array(x: object) -> bool:
974
+ """Return True if x is potentially a future or it may be otherwise impossible or
975
+ expensive to eagerly read its contents, regardless of their size, e.g. by
976
+ calling ``bool(x)`` or ``float(x)``.
977
+
978
+ Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
979
+ cheap as long as the array has the right dtype and size.
980
+
981
+ Note
982
+ ----
983
+ This function errs on the side of caution for array types that may or may not be
984
+ lazy, e.g. JAX arrays, by always returning True for them.
985
+ """
986
+ # **JAX note:** while it is possible to determine if you're inside or outside
987
+ # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
988
+ # as we do below for unknown arrays, this is not recommended by JAX best practices.
989
+
990
+ # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
991
+ # This behaviour, while impossible to change without breaking backwards
992
+ # compatibility, is highly detrimental to performance as the whole graph will end
993
+ # up being computed multiple times.
994
+
995
+ # Note: skipping reclassification of JAX zero gradient arrays, as one will
996
+ # exclusively get them once they leave a jax.grad JIT context.
997
+ cls = cast(Hashable, type(x))
998
+ res = _is_lazy_cls(cls)
999
+ if res is not None:
1000
+ return res
1001
+
1002
+ if not hasattr(x, "__array_namespace__"):
1003
+ return False
1004
+
1005
+ # Unknown Array API compatible object. Note that this test may have dire consequences
1006
+ # in terms of performance, e.g. for a lazy object that eagerly computes the graph
1007
+ # on __bool__ (dask is one such example, which however is special-cased above).
1008
+
1009
+ # Select a single point of the array
1010
+ s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
1011
+ if s is None:
1012
+ return True
1013
+ xp = array_namespace(x)
1014
+ if s > 1:
1015
+ x = xp.reshape(x, (-1,))[0]
1016
+ # Cast to dtype=bool and deal with size 0 arrays
1017
+ x = xp.any(x)
1018
+
1019
+ try:
1020
+ bool(x)
1021
+ return False
1022
+ # The Array API standard dictactes that __bool__ should raise TypeError if the
1023
+ # output cannot be defined.
1024
+ # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
1025
+ except Exception:
1026
+ return True
1027
+
800
1028
 
801
1029
  __all__ = [
802
1030
  "array_namespace",
@@ -818,8 +1046,13 @@ __all__ = [
818
1046
  "is_ndonnx_namespace",
819
1047
  "is_pydata_sparse_array",
820
1048
  "is_pydata_sparse_namespace",
1049
+ "is_writeable_array",
1050
+ "is_lazy_array",
821
1051
  "size",
822
1052
  "to_device",
823
1053
  ]
824
1054
 
825
- _all_ignore = ['sys', 'math', 'inspect', 'warnings']
1055
+ _all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']
1056
+
1057
+ def __dir__() -> list[str]:
1058
+ return __all__