scipy 1.15.2__cp312-cp312-musllinux_1_2_aarch64.whl → 1.16.0rc1__cp312-cp312-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (649):
  1. scipy/__config__.py +11 -11
  2. scipy/__init__.py +3 -6
  3. scipy/_cyutility.cpython-312-aarch64-linux-musl.so +0 -0
  4. scipy/_lib/_array_api.py +497 -161
  5. scipy/_lib/_array_api_compat_vendor.py +9 -0
  6. scipy/_lib/_bunch.py +4 -0
  7. scipy/_lib/_ccallback_c.cpython-312-aarch64-linux-musl.so +0 -0
  8. scipy/_lib/_docscrape.py +1 -1
  9. scipy/_lib/_elementwise_iterative_method.py +15 -26
  10. scipy/_lib/_fpumode.cpython-312-aarch64-linux-musl.so +0 -0
  11. scipy/_lib/_sparse.py +41 -0
  12. scipy/_lib/_test_ccallback.cpython-312-aarch64-linux-musl.so +0 -0
  13. scipy/_lib/_test_deprecation_call.cpython-312-aarch64-linux-musl.so +0 -0
  14. scipy/_lib/_test_deprecation_def.cpython-312-aarch64-linux-musl.so +0 -0
  15. scipy/_lib/_testutils.py +6 -2
  16. scipy/_lib/_uarray/_uarray.cpython-312-aarch64-linux-musl.so +0 -0
  17. scipy/_lib/_util.py +222 -125
  18. scipy/_lib/array_api_compat/__init__.py +4 -4
  19. scipy/_lib/array_api_compat/_internal.py +19 -6
  20. scipy/_lib/array_api_compat/common/__init__.py +1 -1
  21. scipy/_lib/array_api_compat/common/_aliases.py +365 -193
  22. scipy/_lib/array_api_compat/common/_fft.py +94 -64
  23. scipy/_lib/array_api_compat/common/_helpers.py +413 -180
  24. scipy/_lib/array_api_compat/common/_linalg.py +116 -40
  25. scipy/_lib/array_api_compat/common/_typing.py +179 -10
  26. scipy/_lib/array_api_compat/cupy/__init__.py +1 -4
  27. scipy/_lib/array_api_compat/cupy/_aliases.py +61 -41
  28. scipy/_lib/array_api_compat/cupy/_info.py +16 -6
  29. scipy/_lib/array_api_compat/cupy/_typing.py +24 -39
  30. scipy/_lib/array_api_compat/dask/array/__init__.py +6 -3
  31. scipy/_lib/array_api_compat/dask/array/_aliases.py +267 -108
  32. scipy/_lib/array_api_compat/dask/array/_info.py +105 -34
  33. scipy/_lib/array_api_compat/dask/array/fft.py +5 -8
  34. scipy/_lib/array_api_compat/dask/array/linalg.py +21 -22
  35. scipy/_lib/array_api_compat/numpy/__init__.py +13 -15
  36. scipy/_lib/array_api_compat/numpy/_aliases.py +98 -49
  37. scipy/_lib/array_api_compat/numpy/_info.py +36 -16
  38. scipy/_lib/array_api_compat/numpy/_typing.py +27 -43
  39. scipy/_lib/array_api_compat/numpy/fft.py +11 -5
  40. scipy/_lib/array_api_compat/numpy/linalg.py +75 -22
  41. scipy/_lib/array_api_compat/torch/__init__.py +3 -5
  42. scipy/_lib/array_api_compat/torch/_aliases.py +262 -159
  43. scipy/_lib/array_api_compat/torch/_info.py +27 -16
  44. scipy/_lib/array_api_compat/torch/_typing.py +3 -0
  45. scipy/_lib/array_api_compat/torch/fft.py +17 -18
  46. scipy/_lib/array_api_compat/torch/linalg.py +16 -16
  47. scipy/_lib/array_api_extra/__init__.py +26 -3
  48. scipy/_lib/array_api_extra/_delegation.py +171 -0
  49. scipy/_lib/array_api_extra/_lib/__init__.py +1 -0
  50. scipy/_lib/array_api_extra/_lib/_at.py +463 -0
  51. scipy/_lib/array_api_extra/_lib/_backends.py +46 -0
  52. scipy/_lib/array_api_extra/_lib/_funcs.py +937 -0
  53. scipy/_lib/array_api_extra/_lib/_lazy.py +357 -0
  54. scipy/_lib/array_api_extra/_lib/_testing.py +278 -0
  55. scipy/_lib/array_api_extra/_lib/_utils/__init__.py +1 -0
  56. scipy/_lib/array_api_extra/_lib/_utils/_compat.py +74 -0
  57. scipy/_lib/array_api_extra/_lib/_utils/_compat.pyi +45 -0
  58. scipy/_lib/array_api_extra/_lib/_utils/_helpers.py +559 -0
  59. scipy/_lib/array_api_extra/_lib/_utils/_typing.py +10 -0
  60. scipy/_lib/array_api_extra/_lib/_utils/_typing.pyi +105 -0
  61. scipy/_lib/array_api_extra/testing.py +359 -0
  62. scipy/_lib/decorator.py +2 -2
  63. scipy/_lib/doccer.py +1 -7
  64. scipy/_lib/messagestream.cpython-312-aarch64-linux-musl.so +0 -0
  65. scipy/_lib/pyprima/__init__.py +212 -0
  66. scipy/_lib/pyprima/cobyla/__init__.py +0 -0
  67. scipy/_lib/pyprima/cobyla/cobyla.py +559 -0
  68. scipy/_lib/pyprima/cobyla/cobylb.py +714 -0
  69. scipy/_lib/pyprima/cobyla/geometry.py +226 -0
  70. scipy/_lib/pyprima/cobyla/initialize.py +215 -0
  71. scipy/_lib/pyprima/cobyla/trustregion.py +492 -0
  72. scipy/_lib/pyprima/cobyla/update.py +289 -0
  73. scipy/_lib/pyprima/common/__init__.py +0 -0
  74. scipy/_lib/pyprima/common/_bounds.py +34 -0
  75. scipy/_lib/pyprima/common/_linear_constraints.py +46 -0
  76. scipy/_lib/pyprima/common/_nonlinear_constraints.py +54 -0
  77. scipy/_lib/pyprima/common/_project.py +173 -0
  78. scipy/_lib/pyprima/common/checkbreak.py +93 -0
  79. scipy/_lib/pyprima/common/consts.py +47 -0
  80. scipy/_lib/pyprima/common/evaluate.py +99 -0
  81. scipy/_lib/pyprima/common/history.py +38 -0
  82. scipy/_lib/pyprima/common/infos.py +30 -0
  83. scipy/_lib/pyprima/common/linalg.py +435 -0
  84. scipy/_lib/pyprima/common/message.py +290 -0
  85. scipy/_lib/pyprima/common/powalg.py +131 -0
  86. scipy/_lib/pyprima/common/preproc.py +277 -0
  87. scipy/_lib/pyprima/common/present.py +5 -0
  88. scipy/_lib/pyprima/common/ratio.py +54 -0
  89. scipy/_lib/pyprima/common/redrho.py +47 -0
  90. scipy/_lib/pyprima/common/selectx.py +296 -0
  91. scipy/_lib/tests/test__util.py +105 -121
  92. scipy/_lib/tests/test_array_api.py +169 -34
  93. scipy/_lib/tests/test_bunch.py +7 -0
  94. scipy/_lib/tests/test_ccallback.py +2 -10
  95. scipy/_lib/tests/test_public_api.py +13 -0
  96. scipy/cluster/_hierarchy.cpython-312-aarch64-linux-musl.so +0 -0
  97. scipy/cluster/_optimal_leaf_ordering.cpython-312-aarch64-linux-musl.so +0 -0
  98. scipy/cluster/_vq.cpython-312-aarch64-linux-musl.so +0 -0
  99. scipy/cluster/hierarchy.py +393 -223
  100. scipy/cluster/tests/test_hierarchy.py +273 -335
  101. scipy/cluster/tests/test_vq.py +45 -61
  102. scipy/cluster/vq.py +39 -35
  103. scipy/conftest.py +263 -157
  104. scipy/constants/_constants.py +4 -1
  105. scipy/constants/tests/test_codata.py +2 -2
  106. scipy/constants/tests/test_constants.py +11 -18
  107. scipy/datasets/_download_all.py +15 -1
  108. scipy/datasets/_fetchers.py +7 -1
  109. scipy/datasets/_utils.py +1 -1
  110. scipy/differentiate/_differentiate.py +25 -25
  111. scipy/differentiate/tests/test_differentiate.py +24 -25
  112. scipy/fft/_basic.py +20 -0
  113. scipy/fft/_helper.py +3 -34
  114. scipy/fft/_pocketfft/helper.py +29 -1
  115. scipy/fft/_pocketfft/pypocketfft.cpython-312-aarch64-linux-musl.so +0 -0
  116. scipy/fft/_pocketfft/tests/test_basic.py +2 -4
  117. scipy/fft/_pocketfft/tests/test_real_transforms.py +4 -4
  118. scipy/fft/_realtransforms.py +13 -0
  119. scipy/fft/tests/test_basic.py +27 -25
  120. scipy/fft/tests/test_fftlog.py +16 -7
  121. scipy/fft/tests/test_helper.py +18 -34
  122. scipy/fft/tests/test_real_transforms.py +8 -10
  123. scipy/fftpack/convolve.cpython-312-aarch64-linux-musl.so +0 -0
  124. scipy/fftpack/tests/test_basic.py +2 -4
  125. scipy/fftpack/tests/test_real_transforms.py +8 -9
  126. scipy/integrate/_bvp.py +9 -3
  127. scipy/integrate/_cubature.py +3 -2
  128. scipy/integrate/_dop.cpython-312-aarch64-linux-musl.so +0 -0
  129. scipy/integrate/_ivp/common.py +3 -3
  130. scipy/integrate/_ivp/ivp.py +9 -2
  131. scipy/integrate/_ivp/tests/test_ivp.py +19 -0
  132. scipy/integrate/_lsoda.cpython-312-aarch64-linux-musl.so +0 -0
  133. scipy/integrate/_ode.py +9 -2
  134. scipy/integrate/_odepack.cpython-312-aarch64-linux-musl.so +0 -0
  135. scipy/integrate/_quad_vec.py +21 -29
  136. scipy/integrate/_quadpack.cpython-312-aarch64-linux-musl.so +0 -0
  137. scipy/integrate/_quadpack_py.py +11 -7
  138. scipy/integrate/_quadrature.py +3 -3
  139. scipy/integrate/_rules/_base.py +2 -2
  140. scipy/integrate/_tanhsinh.py +57 -54
  141. scipy/integrate/_test_multivariate.cpython-312-aarch64-linux-musl.so +0 -0
  142. scipy/integrate/_test_odeint_banded.cpython-312-aarch64-linux-musl.so +0 -0
  143. scipy/integrate/_vode.cpython-312-aarch64-linux-musl.so +0 -0
  144. scipy/integrate/tests/test__quad_vec.py +0 -6
  145. scipy/integrate/tests/test_banded_ode_solvers.py +85 -0
  146. scipy/integrate/tests/test_cubature.py +21 -35
  147. scipy/integrate/tests/test_quadrature.py +6 -8
  148. scipy/integrate/tests/test_tanhsinh.py +61 -43
  149. scipy/interpolate/__init__.py +70 -58
  150. scipy/interpolate/_bary_rational.py +22 -22
  151. scipy/interpolate/_bsplines.py +119 -66
  152. scipy/interpolate/_cubic.py +65 -50
  153. scipy/interpolate/_dfitpack.cpython-312-aarch64-linux-musl.so +0 -0
  154. scipy/interpolate/_dierckx.cpython-312-aarch64-linux-musl.so +0 -0
  155. scipy/interpolate/_fitpack.cpython-312-aarch64-linux-musl.so +0 -0
  156. scipy/interpolate/_fitpack2.py +9 -6
  157. scipy/interpolate/_fitpack_impl.py +32 -26
  158. scipy/interpolate/_fitpack_repro.py +23 -19
  159. scipy/interpolate/_interpnd.cpython-312-aarch64-linux-musl.so +0 -0
  160. scipy/interpolate/_interpolate.py +30 -12
  161. scipy/interpolate/_ndbspline.py +13 -18
  162. scipy/interpolate/_ndgriddata.py +5 -8
  163. scipy/interpolate/_polyint.py +95 -31
  164. scipy/interpolate/_ppoly.cpython-312-aarch64-linux-musl.so +0 -0
  165. scipy/interpolate/_rbf.py +2 -2
  166. scipy/interpolate/_rbfinterp.py +1 -1
  167. scipy/interpolate/_rbfinterp_pythran.cpython-312-aarch64-linux-musl.so +0 -0
  168. scipy/interpolate/_rgi.py +31 -26
  169. scipy/interpolate/_rgi_cython.cpython-312-aarch64-linux-musl.so +0 -0
  170. scipy/interpolate/dfitpack.py +0 -20
  171. scipy/interpolate/interpnd.py +1 -2
  172. scipy/interpolate/tests/test_bary_rational.py +2 -2
  173. scipy/interpolate/tests/test_bsplines.py +97 -1
  174. scipy/interpolate/tests/test_fitpack2.py +39 -1
  175. scipy/interpolate/tests/test_interpnd.py +32 -20
  176. scipy/interpolate/tests/test_interpolate.py +48 -4
  177. scipy/interpolate/tests/test_rgi.py +2 -1
  178. scipy/io/_fast_matrix_market/__init__.py +2 -0
  179. scipy/io/_fast_matrix_market/_fmm_core.cpython-312-aarch64-linux-musl.so +0 -0
  180. scipy/io/_harwell_boeing/_fortran_format_parser.py +19 -16
  181. scipy/io/_harwell_boeing/hb.py +7 -11
  182. scipy/io/_idl.py +5 -7
  183. scipy/io/_netcdf.py +15 -5
  184. scipy/io/_test_fortran.cpython-312-aarch64-linux-musl.so +0 -0
  185. scipy/io/arff/tests/test_arffread.py +3 -3
  186. scipy/io/matlab/__init__.py +5 -3
  187. scipy/io/matlab/_mio.py +4 -1
  188. scipy/io/matlab/_mio5.py +19 -13
  189. scipy/io/matlab/_mio5_utils.cpython-312-aarch64-linux-musl.so +0 -0
  190. scipy/io/matlab/_mio_utils.cpython-312-aarch64-linux-musl.so +0 -0
  191. scipy/io/matlab/_miobase.py +4 -1
  192. scipy/io/matlab/_streams.cpython-312-aarch64-linux-musl.so +0 -0
  193. scipy/io/matlab/tests/test_mio.py +46 -18
  194. scipy/io/matlab/tests/test_mio_funcs.py +1 -1
  195. scipy/io/tests/test_mmio.py +7 -1
  196. scipy/io/tests/test_wavfile.py +41 -0
  197. scipy/io/wavfile.py +57 -10
  198. scipy/linalg/_basic.py +113 -86
  199. scipy/linalg/_cythonized_array_utils.cpython-312-aarch64-linux-musl.so +0 -0
  200. scipy/linalg/_decomp.py +22 -9
  201. scipy/linalg/_decomp_cholesky.py +28 -13
  202. scipy/linalg/_decomp_cossin.py +45 -30
  203. scipy/linalg/_decomp_interpolative.cpython-312-aarch64-linux-musl.so +0 -0
  204. scipy/linalg/_decomp_ldl.py +4 -1
  205. scipy/linalg/_decomp_lu.py +18 -6
  206. scipy/linalg/_decomp_lu_cython.cpython-312-aarch64-linux-musl.so +0 -0
  207. scipy/linalg/_decomp_polar.py +2 -0
  208. scipy/linalg/_decomp_qr.py +6 -2
  209. scipy/linalg/_decomp_qz.py +3 -0
  210. scipy/linalg/_decomp_schur.py +3 -1
  211. scipy/linalg/_decomp_svd.py +13 -2
  212. scipy/linalg/_decomp_update.cpython-312-aarch64-linux-musl.so +0 -0
  213. scipy/linalg/_expm_frechet.py +4 -0
  214. scipy/linalg/_fblas.cpython-312-aarch64-linux-musl.so +0 -0
  215. scipy/linalg/_flapack.cpython-312-aarch64-linux-musl.so +0 -0
  216. scipy/linalg/_linalg_pythran.cpython-312-aarch64-linux-musl.so +0 -0
  217. scipy/linalg/_matfuncs.py +187 -4
  218. scipy/linalg/_matfuncs_expm.cpython-312-aarch64-linux-musl.so +0 -0
  219. scipy/linalg/_matfuncs_schur_sqrtm.cpython-312-aarch64-linux-musl.so +0 -0
  220. scipy/linalg/_matfuncs_sqrtm.py +1 -99
  221. scipy/linalg/_matfuncs_sqrtm_triu.cpython-312-aarch64-linux-musl.so +0 -0
  222. scipy/linalg/_procrustes.py +2 -0
  223. scipy/linalg/_sketches.py +17 -6
  224. scipy/linalg/_solve_toeplitz.cpython-312-aarch64-linux-musl.so +0 -0
  225. scipy/linalg/_solvers.py +7 -2
  226. scipy/linalg/_special_matrices.py +26 -36
  227. scipy/linalg/cython_blas.cpython-312-aarch64-linux-musl.so +0 -0
  228. scipy/linalg/cython_lapack.cpython-312-aarch64-linux-musl.so +0 -0
  229. scipy/linalg/lapack.py +22 -2
  230. scipy/linalg/tests/_cython_examples/meson.build +7 -0
  231. scipy/linalg/tests/test_basic.py +31 -16
  232. scipy/linalg/tests/test_batch.py +588 -0
  233. scipy/linalg/tests/test_cythonized_array_utils.py +0 -2
  234. scipy/linalg/tests/test_decomp.py +40 -3
  235. scipy/linalg/tests/test_decomp_cossin.py +14 -0
  236. scipy/linalg/tests/test_decomp_ldl.py +1 -1
  237. scipy/linalg/tests/test_interpolative.py +17 -0
  238. scipy/linalg/tests/test_lapack.py +115 -7
  239. scipy/linalg/tests/test_matfuncs.py +157 -102
  240. scipy/linalg/tests/test_procrustes.py +0 -7
  241. scipy/linalg/tests/test_solve_toeplitz.py +1 -1
  242. scipy/linalg/tests/test_special_matrices.py +1 -5
  243. scipy/ndimage/__init__.py +1 -0
  244. scipy/ndimage/_ctest.cpython-312-aarch64-linux-musl.so +0 -0
  245. scipy/ndimage/_cytest.cpython-312-aarch64-linux-musl.so +0 -0
  246. scipy/ndimage/_delegators.py +8 -2
  247. scipy/ndimage/_filters.py +433 -5
  248. scipy/ndimage/_interpolation.py +36 -6
  249. scipy/ndimage/_measurements.py +4 -2
  250. scipy/ndimage/_morphology.py +5 -0
  251. scipy/ndimage/_nd_image.cpython-312-aarch64-linux-musl.so +0 -0
  252. scipy/ndimage/_ndimage_api.py +2 -1
  253. scipy/ndimage/_ni_docstrings.py +5 -1
  254. scipy/ndimage/_ni_label.cpython-312-aarch64-linux-musl.so +0 -0
  255. scipy/ndimage/_ni_support.py +1 -5
  256. scipy/ndimage/_rank_filter_1d.cpython-312-aarch64-linux-musl.so +0 -0
  257. scipy/ndimage/_support_alternative_backends.py +18 -6
  258. scipy/ndimage/tests/test_filters.py +351 -259
  259. scipy/ndimage/tests/test_fourier.py +7 -9
  260. scipy/ndimage/tests/test_interpolation.py +68 -61
  261. scipy/ndimage/tests/test_measurements.py +18 -35
  262. scipy/ndimage/tests/test_morphology.py +143 -131
  263. scipy/ndimage/tests/test_splines.py +1 -3
  264. scipy/odr/__odrpack.cpython-312-aarch64-linux-musl.so +0 -0
  265. scipy/optimize/_basinhopping.py +13 -7
  266. scipy/optimize/_bglu_dense.cpython-312-aarch64-linux-musl.so +0 -0
  267. scipy/optimize/_bracket.py +46 -26
  268. scipy/optimize/_chandrupatla.py +9 -10
  269. scipy/optimize/_cobyla_py.py +104 -123
  270. scipy/optimize/_constraints.py +14 -10
  271. scipy/optimize/_differentiable_functions.py +371 -230
  272. scipy/optimize/_differentialevolution.py +4 -3
  273. scipy/optimize/_direct.cpython-312-aarch64-linux-musl.so +0 -0
  274. scipy/optimize/_dual_annealing.py +1 -1
  275. scipy/optimize/_elementwise.py +1 -4
  276. scipy/optimize/_group_columns.cpython-312-aarch64-linux-musl.so +0 -0
  277. scipy/optimize/_highspy/_core.cpython-312-aarch64-linux-musl.so +0 -0
  278. scipy/optimize/_highspy/_highs_options.cpython-312-aarch64-linux-musl.so +0 -0
  279. scipy/optimize/_highspy/_highs_wrapper.py +6 -4
  280. scipy/optimize/_lbfgsb.cpython-312-aarch64-linux-musl.so +0 -0
  281. scipy/optimize/_lbfgsb_py.py +57 -16
  282. scipy/optimize/_linprog_doc.py +2 -2
  283. scipy/optimize/_linprog_highs.py +11 -11
  284. scipy/optimize/_linprog_ip.py +25 -10
  285. scipy/optimize/_linprog_util.py +18 -19
  286. scipy/optimize/_lsap.cpython-312-aarch64-linux-musl.so +0 -0
  287. scipy/optimize/_lsq/common.py +3 -3
  288. scipy/optimize/_lsq/dogbox.py +16 -2
  289. scipy/optimize/_lsq/givens_elimination.cpython-312-aarch64-linux-musl.so +0 -0
  290. scipy/optimize/_lsq/least_squares.py +198 -126
  291. scipy/optimize/_lsq/lsq_linear.py +6 -6
  292. scipy/optimize/_lsq/trf.py +35 -8
  293. scipy/optimize/_milp.py +3 -1
  294. scipy/optimize/_minimize.py +105 -36
  295. scipy/optimize/_minpack.cpython-312-aarch64-linux-musl.so +0 -0
  296. scipy/optimize/_minpack_py.py +21 -14
  297. scipy/optimize/_moduleTNC.cpython-312-aarch64-linux-musl.so +0 -0
  298. scipy/optimize/_nnls.py +20 -21
  299. scipy/optimize/_nonlin.py +34 -3
  300. scipy/optimize/_numdiff.py +288 -110
  301. scipy/optimize/_optimize.py +86 -48
  302. scipy/optimize/_pava_pybind.cpython-312-aarch64-linux-musl.so +0 -0
  303. scipy/optimize/_remove_redundancy.py +5 -5
  304. scipy/optimize/_root_scalar.py +1 -1
  305. scipy/optimize/_shgo.py +6 -0
  306. scipy/optimize/_shgo_lib/_complex.py +1 -1
  307. scipy/optimize/_slsqp_py.py +216 -124
  308. scipy/optimize/_slsqplib.cpython-312-aarch64-linux-musl.so +0 -0
  309. scipy/optimize/_spectral.py +1 -1
  310. scipy/optimize/_tnc.py +8 -1
  311. scipy/optimize/_trlib/_trlib.cpython-312-aarch64-linux-musl.so +0 -0
  312. scipy/optimize/_trustregion.py +20 -6
  313. scipy/optimize/_trustregion_constr/canonical_constraint.py +7 -7
  314. scipy/optimize/_trustregion_constr/equality_constrained_sqp.py +1 -1
  315. scipy/optimize/_trustregion_constr/minimize_trustregion_constr.py +11 -3
  316. scipy/optimize/_trustregion_constr/projections.py +12 -8
  317. scipy/optimize/_trustregion_constr/qp_subproblem.py +9 -9
  318. scipy/optimize/_trustregion_constr/tests/test_projections.py +7 -7
  319. scipy/optimize/_trustregion_constr/tests/test_qp_subproblem.py +77 -77
  320. scipy/optimize/_trustregion_constr/tr_interior_point.py +5 -5
  321. scipy/optimize/_trustregion_exact.py +0 -1
  322. scipy/optimize/_zeros.cpython-312-aarch64-linux-musl.so +0 -0
  323. scipy/optimize/_zeros_py.py +97 -17
  324. scipy/optimize/cython_optimize/_zeros.cpython-312-aarch64-linux-musl.so +0 -0
  325. scipy/optimize/slsqp.py +0 -1
  326. scipy/optimize/tests/test__basinhopping.py +1 -1
  327. scipy/optimize/tests/test__differential_evolution.py +4 -4
  328. scipy/optimize/tests/test__linprog_clean_inputs.py +5 -3
  329. scipy/optimize/tests/test__numdiff.py +66 -22
  330. scipy/optimize/tests/test__remove_redundancy.py +2 -2
  331. scipy/optimize/tests/test__shgo.py +9 -1
  332. scipy/optimize/tests/test_bracket.py +71 -46
  333. scipy/optimize/tests/test_chandrupatla.py +133 -135
  334. scipy/optimize/tests/test_cobyla.py +74 -45
  335. scipy/optimize/tests/test_constraints.py +1 -1
  336. scipy/optimize/tests/test_differentiable_functions.py +226 -6
  337. scipy/optimize/tests/test_lbfgsb_hessinv.py +22 -0
  338. scipy/optimize/tests/test_least_squares.py +125 -13
  339. scipy/optimize/tests/test_linear_assignment.py +3 -3
  340. scipy/optimize/tests/test_linprog.py +3 -3
  341. scipy/optimize/tests/test_lsq_linear.py +5 -5
  342. scipy/optimize/tests/test_minimize_constrained.py +2 -2
  343. scipy/optimize/tests/test_minpack.py +4 -4
  344. scipy/optimize/tests/test_nnls.py +43 -3
  345. scipy/optimize/tests/test_nonlin.py +36 -0
  346. scipy/optimize/tests/test_optimize.py +95 -17
  347. scipy/optimize/tests/test_slsqp.py +36 -4
  348. scipy/optimize/tests/test_zeros.py +34 -1
  349. scipy/signal/__init__.py +12 -23
  350. scipy/signal/_delegators.py +568 -0
  351. scipy/signal/_filter_design.py +459 -241
  352. scipy/signal/_fir_filter_design.py +262 -90
  353. scipy/signal/_lti_conversion.py +3 -2
  354. scipy/signal/_ltisys.py +118 -91
  355. scipy/signal/_max_len_seq_inner.cpython-312-aarch64-linux-musl.so +0 -0
  356. scipy/signal/_peak_finding_utils.cpython-312-aarch64-linux-musl.so +0 -0
  357. scipy/signal/_polyutils.py +172 -0
  358. scipy/signal/_short_time_fft.py +553 -76
  359. scipy/signal/_signal_api.py +30 -0
  360. scipy/signal/_signaltools.py +719 -396
  361. scipy/signal/_sigtools.cpython-312-aarch64-linux-musl.so +0 -0
  362. scipy/signal/_sosfilt.cpython-312-aarch64-linux-musl.so +0 -0
  363. scipy/signal/_spectral_py.py +221 -50
  364. scipy/signal/_spline.cpython-312-aarch64-linux-musl.so +0 -0
  365. scipy/signal/_spline_filters.py +108 -68
  366. scipy/signal/_support_alternative_backends.py +73 -0
  367. scipy/signal/_upfirdn.py +4 -1
  368. scipy/signal/_upfirdn_apply.cpython-312-aarch64-linux-musl.so +0 -0
  369. scipy/signal/_waveforms.py +2 -11
  370. scipy/signal/_wavelets.py +1 -1
  371. scipy/signal/fir_filter_design.py +1 -0
  372. scipy/signal/spline.py +4 -11
  373. scipy/signal/tests/_scipy_spectral_test_shim.py +5 -182
  374. scipy/signal/tests/test_bsplines.py +114 -79
  375. scipy/signal/tests/test_cont2discrete.py +9 -2
  376. scipy/signal/tests/test_filter_design.py +721 -481
  377. scipy/signal/tests/test_fir_filter_design.py +332 -140
  378. scipy/signal/tests/test_savitzky_golay.py +4 -3
  379. scipy/signal/tests/test_short_time_fft.py +231 -5
  380. scipy/signal/tests/test_signaltools.py +2149 -1348
  381. scipy/signal/tests/test_spectral.py +19 -6
  382. scipy/signal/tests/test_splines.py +161 -96
  383. scipy/signal/tests/test_upfirdn.py +84 -50
  384. scipy/signal/tests/test_waveforms.py +20 -0
  385. scipy/signal/tests/test_windows.py +607 -466
  386. scipy/signal/windows/_windows.py +287 -148
  387. scipy/sparse/__init__.py +23 -4
  388. scipy/sparse/_base.py +269 -120
  389. scipy/sparse/_bsr.py +7 -4
  390. scipy/sparse/_compressed.py +59 -234
  391. scipy/sparse/_construct.py +90 -38
  392. scipy/sparse/_coo.py +115 -181
  393. scipy/sparse/_csc.py +4 -4
  394. scipy/sparse/_csparsetools.cpython-312-aarch64-linux-musl.so +0 -0
  395. scipy/sparse/_csr.py +2 -2
  396. scipy/sparse/_data.py +48 -48
  397. scipy/sparse/_dia.py +105 -21
  398. scipy/sparse/_dok.py +0 -23
  399. scipy/sparse/_index.py +4 -4
  400. scipy/sparse/_matrix.py +23 -0
  401. scipy/sparse/_sparsetools.cpython-312-aarch64-linux-musl.so +0 -0
  402. scipy/sparse/_sputils.py +37 -22
  403. scipy/sparse/base.py +0 -9
  404. scipy/sparse/bsr.py +0 -14
  405. scipy/sparse/compressed.py +0 -23
  406. scipy/sparse/construct.py +0 -6
  407. scipy/sparse/coo.py +0 -14
  408. scipy/sparse/csc.py +0 -3
  409. scipy/sparse/csgraph/_flow.cpython-312-aarch64-linux-musl.so +0 -0
  410. scipy/sparse/csgraph/_matching.cpython-312-aarch64-linux-musl.so +0 -0
  411. scipy/sparse/csgraph/_min_spanning_tree.cpython-312-aarch64-linux-musl.so +0 -0
  412. scipy/sparse/csgraph/_reordering.cpython-312-aarch64-linux-musl.so +0 -0
  413. scipy/sparse/csgraph/_shortest_path.cpython-312-aarch64-linux-musl.so +0 -0
  414. scipy/sparse/csgraph/_tools.cpython-312-aarch64-linux-musl.so +0 -0
  415. scipy/sparse/csgraph/_traversal.cpython-312-aarch64-linux-musl.so +0 -0
  416. scipy/sparse/csgraph/tests/test_matching.py +14 -2
  417. scipy/sparse/csgraph/tests/test_pydata_sparse.py +4 -1
  418. scipy/sparse/csgraph/tests/test_shortest_path.py +83 -27
  419. scipy/sparse/csr.py +0 -5
  420. scipy/sparse/data.py +1 -6
  421. scipy/sparse/dia.py +0 -7
  422. scipy/sparse/dok.py +0 -10
  423. scipy/sparse/linalg/_dsolve/_superlu.cpython-312-aarch64-linux-musl.so +0 -0
  424. scipy/sparse/linalg/_dsolve/linsolve.py +9 -0
  425. scipy/sparse/linalg/_dsolve/tests/test_linsolve.py +35 -28
  426. scipy/sparse/linalg/_eigen/arpack/_arpack.cpython-312-aarch64-linux-musl.so +0 -0
  427. scipy/sparse/linalg/_eigen/arpack/arpack.py +28 -20
  428. scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py +6 -6
  429. scipy/sparse/linalg/_expm_multiply.py +8 -3
  430. scipy/sparse/linalg/_interface.py +29 -26
  431. scipy/sparse/linalg/_isolve/_gcrotmk.py +6 -5
  432. scipy/sparse/linalg/_isolve/iterative.py +51 -45
  433. scipy/sparse/linalg/_isolve/lgmres.py +6 -6
  434. scipy/sparse/linalg/_isolve/minres.py +5 -5
  435. scipy/sparse/linalg/_isolve/tfqmr.py +7 -7
  436. scipy/sparse/linalg/_isolve/utils.py +2 -8
  437. scipy/sparse/linalg/_matfuncs.py +1 -1
  438. scipy/sparse/linalg/_norm.py +1 -1
  439. scipy/sparse/linalg/_propack/_cpropack.cpython-312-aarch64-linux-musl.so +0 -0
  440. scipy/sparse/linalg/_propack/_dpropack.cpython-312-aarch64-linux-musl.so +0 -0
  441. scipy/sparse/linalg/_propack/_spropack.cpython-312-aarch64-linux-musl.so +0 -0
  442. scipy/sparse/linalg/_propack/_zpropack.cpython-312-aarch64-linux-musl.so +0 -0
  443. scipy/sparse/linalg/_special_sparse_arrays.py +39 -38
  444. scipy/sparse/linalg/tests/test_expm_multiply.py +10 -0
  445. scipy/sparse/linalg/tests/test_interface.py +35 -0
  446. scipy/sparse/linalg/tests/test_pydata_sparse.py +18 -0
  447. scipy/sparse/tests/test_arithmetic1d.py +5 -2
  448. scipy/sparse/tests/test_base.py +217 -40
  449. scipy/sparse/tests/test_common1d.py +17 -12
  450. scipy/sparse/tests/test_construct.py +1 -1
  451. scipy/sparse/tests/test_coo.py +272 -4
  452. scipy/sparse/tests/test_sparsetools.py +5 -0
  453. scipy/sparse/tests/test_sputils.py +36 -7
  454. scipy/spatial/_ckdtree.cpython-312-aarch64-linux-musl.so +0 -0
  455. scipy/spatial/_distance_pybind.cpython-312-aarch64-linux-musl.so +0 -0
  456. scipy/spatial/_distance_wrap.cpython-312-aarch64-linux-musl.so +0 -0
  457. scipy/spatial/_hausdorff.cpython-312-aarch64-linux-musl.so +0 -0
  458. scipy/spatial/_qhull.cpython-312-aarch64-linux-musl.so +0 -0
  459. scipy/spatial/_voronoi.cpython-312-aarch64-linux-musl.so +0 -0
  460. scipy/spatial/distance.py +49 -42
  461. scipy/spatial/tests/test_distance.py +3 -1
  462. scipy/spatial/tests/test_kdtree.py +1 -0
  463. scipy/spatial/tests/test_qhull.py +106 -2
  464. scipy/spatial/transform/__init__.py +5 -3
  465. scipy/spatial/transform/_rigid_transform.cpython-312-aarch64-linux-musl.so +0 -0
  466. scipy/spatial/transform/_rotation.cpython-312-aarch64-linux-musl.so +0 -0
  467. scipy/spatial/transform/tests/test_rigid_transform.py +1221 -0
  468. scipy/spatial/transform/tests/test_rotation.py +1342 -790
  469. scipy/spatial/transform/tests/test_rotation_groups.py +3 -3
  470. scipy/spatial/transform/tests/test_rotation_spline.py +29 -8
  471. scipy/special/__init__.py +1 -47
  472. scipy/special/_add_newdocs.py +34 -772
  473. scipy/special/_basic.py +22 -25
  474. scipy/special/_comb.cpython-312-aarch64-linux-musl.so +0 -0
  475. scipy/special/_ellip_harm_2.cpython-312-aarch64-linux-musl.so +0 -0
  476. scipy/special/_gufuncs.cpython-312-aarch64-linux-musl.so +0 -0
  477. scipy/special/_logsumexp.py +83 -69
  478. scipy/special/_orthogonal.pyi +1 -1
  479. scipy/special/_specfun.cpython-312-aarch64-linux-musl.so +0 -0
  480. scipy/special/_special_ufuncs.cpython-312-aarch64-linux-musl.so +0 -0
  481. scipy/special/_spherical_bessel.py +4 -4
  482. scipy/special/_support_alternative_backends.py +212 -119
  483. scipy/special/_test_internal.cpython-312-aarch64-linux-musl.so +0 -0
  484. scipy/special/_testutils.py +4 -4
  485. scipy/special/_ufuncs.cpython-312-aarch64-linux-musl.so +0 -0
  486. scipy/special/_ufuncs.pyi +1 -0
  487. scipy/special/_ufuncs.pyx +215 -1400
  488. scipy/special/_ufuncs_cxx.cpython-312-aarch64-linux-musl.so +0 -0
  489. scipy/special/_ufuncs_cxx.pxd +2 -15
  490. scipy/special/_ufuncs_cxx.pyx +5 -44
  491. scipy/special/_ufuncs_cxx_defs.h +2 -16
  492. scipy/special/_ufuncs_defs.h +0 -8
  493. scipy/special/cython_special.cpython-312-aarch64-linux-musl.so +0 -0
  494. scipy/special/cython_special.pxd +1 -1
  495. scipy/special/tests/_cython_examples/meson.build +10 -1
  496. scipy/special/tests/test_basic.py +153 -20
  497. scipy/special/tests/test_boost_ufuncs.py +3 -0
  498. scipy/special/tests/test_cdflib.py +35 -11
  499. scipy/special/tests/test_gammainc.py +16 -0
  500. scipy/special/tests/test_hyp2f1.py +23 -2
  501. scipy/special/tests/test_log1mexp.py +85 -0
  502. scipy/special/tests/test_logsumexp.py +220 -64
  503. scipy/special/tests/test_mpmath.py +1 -0
  504. scipy/special/tests/test_nan_inputs.py +1 -1
  505. scipy/special/tests/test_orthogonal.py +17 -18
  506. scipy/special/tests/test_sf_error.py +3 -2
  507. scipy/special/tests/test_sph_harm.py +6 -7
  508. scipy/special/tests/test_support_alternative_backends.py +211 -76
  509. scipy/stats/__init__.py +4 -1
  510. scipy/stats/_ansari_swilk_statistics.cpython-312-aarch64-linux-musl.so +0 -0
  511. scipy/stats/_axis_nan_policy.py +4 -3
  512. scipy/stats/_biasedurn.cpython-312-aarch64-linux-musl.so +0 -0
  513. scipy/stats/_continued_fraction.py +387 -0
  514. scipy/stats/_continuous_distns.py +296 -319
  515. scipy/stats/_covariance.py +6 -3
  516. scipy/stats/_discrete_distns.py +39 -32
  517. scipy/stats/_distn_infrastructure.py +39 -12
  518. scipy/stats/_distribution_infrastructure.py +900 -238
  519. scipy/stats/_entropy.py +7 -8
  520. scipy/{_lib → stats}/_finite_differences.py +1 -1
  521. scipy/stats/_hypotests.py +82 -49
  522. scipy/stats/_kde.py +53 -49
  523. scipy/stats/_ksstats.py +1 -1
  524. scipy/stats/_levy_stable/__init__.py +7 -15
  525. scipy/stats/_levy_stable/levyst.cpython-312-aarch64-linux-musl.so +0 -0
  526. scipy/stats/_morestats.py +112 -67
  527. scipy/stats/_mstats_basic.py +13 -17
  528. scipy/stats/_mstats_extras.py +8 -8
  529. scipy/stats/_multivariate.py +89 -113
  530. scipy/stats/_new_distributions.py +97 -20
  531. scipy/stats/_page_trend_test.py +12 -5
  532. scipy/stats/_probability_distribution.py +265 -43
  533. scipy/stats/_qmc.py +14 -9
  534. scipy/stats/_qmc_cy.cpython-312-aarch64-linux-musl.so +0 -0
  535. scipy/stats/_qmvnt.py +16 -95
  536. scipy/stats/_qmvnt_cy.cpython-312-aarch64-linux-musl.so +0 -0
  537. scipy/stats/_quantile.py +335 -0
  538. scipy/stats/_rcont/rcont.cpython-312-aarch64-linux-musl.so +0 -0
  539. scipy/stats/_resampling.py +4 -29
  540. scipy/stats/_sampling.py +1 -1
  541. scipy/stats/_sobol.cpython-312-aarch64-linux-musl.so +0 -0
  542. scipy/stats/_stats.cpython-312-aarch64-linux-musl.so +0 -0
  543. scipy/stats/_stats_mstats_common.py +19 -2
  544. scipy/stats/_stats_py.py +534 -460
  545. scipy/stats/_stats_pythran.cpython-312-aarch64-linux-musl.so +0 -0
  546. scipy/stats/_unuran/unuran_wrapper.cpython-312-aarch64-linux-musl.so +0 -0
  547. scipy/stats/_unuran/unuran_wrapper.pyi +2 -1
  548. scipy/stats/_variation.py +5 -7
  549. scipy/stats/_wilcoxon.py +13 -7
  550. scipy/stats/tests/common_tests.py +6 -4
  551. scipy/stats/tests/test_axis_nan_policy.py +62 -24
  552. scipy/stats/tests/test_continued_fraction.py +173 -0
  553. scipy/stats/tests/test_continuous.py +379 -60
  554. scipy/stats/tests/test_continuous_basic.py +18 -12
  555. scipy/stats/tests/test_discrete_basic.py +14 -8
  556. scipy/stats/tests/test_discrete_distns.py +16 -16
  557. scipy/stats/tests/test_distributions.py +117 -75
  558. scipy/stats/tests/test_entropy.py +40 -48
  559. scipy/stats/tests/test_fit.py +4 -3
  560. scipy/stats/tests/test_hypotests.py +153 -24
  561. scipy/stats/tests/test_kdeoth.py +109 -41
  562. scipy/stats/tests/test_marray.py +289 -0
  563. scipy/stats/tests/test_morestats.py +79 -47
  564. scipy/stats/tests/test_mstats_basic.py +3 -3
  565. scipy/stats/tests/test_multivariate.py +434 -83
  566. scipy/stats/tests/test_qmc.py +13 -10
  567. scipy/stats/tests/test_quantile.py +199 -0
  568. scipy/stats/tests/test_rank.py +119 -112
  569. scipy/stats/tests/test_resampling.py +47 -56
  570. scipy/stats/tests/test_sampling.py +9 -4
  571. scipy/stats/tests/test_stats.py +799 -939
  572. scipy/stats/tests/test_variation.py +8 -6
  573. scipy/version.py +2 -2
  574. {scipy-1.15.2.dist-info → scipy-1.16.0rc1.dist-info}/LICENSE.txt +1 -1
  575. {scipy-1.15.2.dist-info → scipy-1.16.0rc1.dist-info}/METADATA +9 -9
  576. {scipy-1.15.2.dist-info → scipy-1.16.0rc1.dist-info}/RECORD +1316 -1323
  577. scipy.libs/libgcc_s-69c45f16.so.1 +0 -0
  578. scipy.libs/libgfortran-db0b6589.so.5.0.0 +0 -0
  579. scipy.libs/{libstdc++-1b614e01.so.6.0.32 → libstdc++-1f1a71be.so.6.0.33} +0 -0
  580. scipy/_lib/array_api_extra/_funcs.py +0 -484
  581. scipy/_lib/array_api_extra/_typing.py +0 -8
  582. scipy/interpolate/_bspl.cpython-312-aarch64-linux-musl.so +0 -0
  583. scipy/optimize/_cobyla.cpython-312-aarch64-linux-musl.so +0 -0
  584. scipy/optimize/_cython_nnls.cpython-312-aarch64-linux-musl.so +0 -0
  585. scipy/optimize/_slsqp.cpython-312-aarch64-linux-musl.so +0 -0
  586. scipy/spatial/qhull_src/COPYING.txt +0 -38
  587. scipy/special/libsf_error_state.so +0 -0
  588. scipy/special/tests/test_log_softmax.py +0 -109
  589. scipy/special/tests/test_xsf_cuda.py +0 -114
  590. scipy/special/xsf/binom.h +0 -89
  591. scipy/special/xsf/cdflib.h +0 -100
  592. scipy/special/xsf/cephes/airy.h +0 -307
  593. scipy/special/xsf/cephes/besselpoly.h +0 -51
  594. scipy/special/xsf/cephes/beta.h +0 -257
  595. scipy/special/xsf/cephes/cbrt.h +0 -131
  596. scipy/special/xsf/cephes/chbevl.h +0 -85
  597. scipy/special/xsf/cephes/chdtr.h +0 -193
  598. scipy/special/xsf/cephes/const.h +0 -87
  599. scipy/special/xsf/cephes/ellie.h +0 -293
  600. scipy/special/xsf/cephes/ellik.h +0 -251
  601. scipy/special/xsf/cephes/ellpe.h +0 -107
  602. scipy/special/xsf/cephes/ellpk.h +0 -117
  603. scipy/special/xsf/cephes/expn.h +0 -260
  604. scipy/special/xsf/cephes/gamma.h +0 -398
  605. scipy/special/xsf/cephes/hyp2f1.h +0 -596
  606. scipy/special/xsf/cephes/hyperg.h +0 -361
  607. scipy/special/xsf/cephes/i0.h +0 -149
  608. scipy/special/xsf/cephes/i1.h +0 -158
  609. scipy/special/xsf/cephes/igam.h +0 -421
  610. scipy/special/xsf/cephes/igam_asymp_coeff.h +0 -195
  611. scipy/special/xsf/cephes/igami.h +0 -313
  612. scipy/special/xsf/cephes/j0.h +0 -225
  613. scipy/special/xsf/cephes/j1.h +0 -198
  614. scipy/special/xsf/cephes/jv.h +0 -715
  615. scipy/special/xsf/cephes/k0.h +0 -164
  616. scipy/special/xsf/cephes/k1.h +0 -163
  617. scipy/special/xsf/cephes/kn.h +0 -243
  618. scipy/special/xsf/cephes/lanczos.h +0 -112
  619. scipy/special/xsf/cephes/ndtr.h +0 -275
  620. scipy/special/xsf/cephes/poch.h +0 -85
  621. scipy/special/xsf/cephes/polevl.h +0 -167
  622. scipy/special/xsf/cephes/psi.h +0 -194
  623. scipy/special/xsf/cephes/rgamma.h +0 -111
  624. scipy/special/xsf/cephes/scipy_iv.h +0 -811
  625. scipy/special/xsf/cephes/shichi.h +0 -248
  626. scipy/special/xsf/cephes/sici.h +0 -224
  627. scipy/special/xsf/cephes/sindg.h +0 -221
  628. scipy/special/xsf/cephes/tandg.h +0 -139
  629. scipy/special/xsf/cephes/trig.h +0 -58
  630. scipy/special/xsf/cephes/unity.h +0 -186
  631. scipy/special/xsf/cephes/zeta.h +0 -172
  632. scipy/special/xsf/config.h +0 -304
  633. scipy/special/xsf/digamma.h +0 -205
  634. scipy/special/xsf/error.h +0 -57
  635. scipy/special/xsf/evalpoly.h +0 -47
  636. scipy/special/xsf/expint.h +0 -266
  637. scipy/special/xsf/hyp2f1.h +0 -694
  638. scipy/special/xsf/iv_ratio.h +0 -173
  639. scipy/special/xsf/lambertw.h +0 -150
  640. scipy/special/xsf/loggamma.h +0 -163
  641. scipy/special/xsf/sici.h +0 -200
  642. scipy/special/xsf/tools.h +0 -427
  643. scipy/special/xsf/trig.h +0 -164
  644. scipy/special/xsf/wright_bessel.h +0 -843
  645. scipy/special/xsf/zlog1.h +0 -35
  646. scipy/stats/_mvn.cpython-312-aarch64-linux-musl.so +0 -0
  647. scipy.libs/libgcc_s-7393e603.so.1 +0 -0
  648. scipy.libs/libgfortran-eb933d8e.so.5.0.0 +0 -0
  649. {scipy-1.15.2.dist-info → scipy-1.16.0rc1.dist-info}/WHEEL +0 -0
@@ -1,15 +1,30 @@
1
+ import math
2
+
1
3
  import pytest
2
4
 
3
5
  import numpy as np
4
- from numpy.testing import assert_equal, assert_array_almost_equal
5
- from numpy.testing import assert_allclose
6
+ from numpy.testing import assert_equal
6
7
  from scipy.spatial.transform import Rotation, Slerp
7
8
  from scipy.stats import special_ortho_group
8
- from itertools import permutations
9
+ from itertools import permutations, product
10
+ from scipy._lib._array_api import (
11
+ xp_assert_equal,
12
+ is_numpy,
13
+ is_lazy_array,
14
+ xp_vector_norm,
15
+ xp_assert_close,
16
+ eager_warns,
17
+ xp_default_dtype
18
+ )
19
+ import scipy._lib.array_api_extra as xpx
9
20
 
10
21
  import pickle
11
22
  import copy
12
23
 
24
+
25
+ pytestmark = pytest.mark.skip_xp_backends(np_only=True)
26
+
27
+
13
28
  def basis_vec(axis):
14
29
  if axis == 'x':
15
30
  return [1, 0, 0]
@@ -18,81 +33,114 @@ def basis_vec(axis):
18
33
  elif axis == 'z':
19
34
  return [0, 0, 1]
20
35
 
21
- def test_generic_quat_matrix():
22
- x = np.array([[3, 4, 0, 0], [5, 12, 0, 0]])
36
+
37
+ def rotation_to_xp(r: Rotation, xp):
38
+ return Rotation.from_quat(xp.asarray(r.as_quat()))
39
+
40
+
41
+ def test_init_non_array():
42
+ Rotation((0, 0, 0, 1))
43
+ Rotation([0, 0, 0, 1])
44
+
45
+
46
+ def test_generic_quat_matrix(xp):
47
+ x = xp.asarray([[3.0, 4, 0, 0], [5, 12, 0, 0]])
23
48
  r = Rotation.from_quat(x)
24
- expected_quat = x / np.array([[5], [13]])
25
- assert_array_almost_equal(r.as_quat(), expected_quat)
49
+ expected_quat = x / xp.asarray([[5.0], [13.0]])
50
+ xp_assert_close(r.as_quat(), expected_quat)
26
51
 
27
52
 
28
- def test_from_single_1d_quaternion():
29
- x = np.array([3, 4, 0, 0])
53
+ def test_from_single_1d_quaternion(xp):
54
+ x = xp.asarray([3.0, 4, 0, 0])
30
55
  r = Rotation.from_quat(x)
31
56
  expected_quat = x / 5
32
- assert_array_almost_equal(r.as_quat(), expected_quat)
57
+ xp_assert_close(r.as_quat(), expected_quat)
33
58
 
34
59
 
35
- def test_from_single_2d_quaternion():
36
- x = np.array([[3, 4, 0, 0]])
60
+ def test_from_single_2d_quaternion(xp):
61
+ x = xp.asarray([[3.0, 4, 0, 0]])
37
62
  r = Rotation.from_quat(x)
38
63
  expected_quat = x / 5
39
- assert_array_almost_equal(r.as_quat(), expected_quat)
64
+ xp_assert_close(r.as_quat(), expected_quat)
40
65
 
41
66
 
42
- def test_from_quat_scalar_first():
67
+ def test_from_quat_scalar_first(xp):
43
68
  rng = np.random.RandomState(0)
44
69
 
45
- r = Rotation.from_quat([1, 0, 0, 0], scalar_first=True)
46
- assert_allclose(r.as_matrix(), np.eye(3), rtol=1e-15, atol=1e-16)
70
+ r = Rotation.from_quat(xp.asarray([1, 0, 0, 0]), scalar_first=True)
71
+ xp_assert_close(r.as_matrix(), xp.eye(3), rtol=1e-15, atol=1e-16)
47
72
 
48
- r = Rotation.from_quat(np.tile([1, 0, 0, 0], (10, 1)), scalar_first=True)
49
- assert_allclose(r.as_matrix(), np.tile(np.eye(3), (10, 1, 1)),
50
- rtol=1e-15, atol=1e-16)
73
+ q = xp.tile(xp.asarray([1, 0, 0, 0]), (10, 1))
74
+ r = Rotation.from_quat(q, scalar_first=True)
75
+ xp_assert_close(
76
+ r.as_matrix(), xp.tile(xp.eye(3), (10, 1, 1)), rtol=1e-15, atol=1e-16
77
+ )
51
78
 
52
- q = rng.randn(100, 4)
53
- q /= np.linalg.norm(q, axis=1)[:, None]
54
- for qi in q:
79
+ q = xp.asarray(rng.randn(100, 4))
80
+ q /= xp_vector_norm(q, axis=1)[:, None]
81
+ for i in range(q.shape[0]): # Array API conforming loop
82
+ qi = q[i, ...]
55
83
  r = Rotation.from_quat(qi, scalar_first=True)
56
- assert_allclose(np.roll(r.as_quat(), 1), qi, rtol=1e-15)
84
+ xp_assert_close(xp.roll(r.as_quat(), 1), qi, rtol=1e-15)
57
85
 
58
86
  r = Rotation.from_quat(q, scalar_first=True)
59
- assert_allclose(np.roll(r.as_quat(), 1, axis=1), q, rtol=1e-15)
87
+ xp_assert_close(xp.roll(r.as_quat(), 1, axis=1), q, rtol=1e-15)
60
88
 
61
89
 
62
- def test_as_quat_scalar_first():
90
+ def test_from_quat_array_like():
91
+ rng = np.random.default_rng(123)
92
+ # Single rotation
93
+ r_expected = Rotation.random(rng=rng)
94
+ r = Rotation.from_quat(r_expected.as_quat().tolist())
95
+ assert r_expected.approx_equal(r, atol=1e-12)
96
+
97
+ # Multiple rotations
98
+ r_expected = Rotation.random(3, rng=rng)
99
+ r = Rotation.from_quat(r_expected.as_quat().tolist())
100
+ assert np.all(r_expected.approx_equal(r, atol=1e-12))
101
+
102
+
103
+ def test_from_quat_int_dtype(xp):
104
+ r = Rotation.from_quat(xp.asarray([1, 0, 0, 0]))
105
+ assert r.as_quat().dtype == xp_default_dtype(xp)
106
+
107
+
108
+ def test_as_quat_scalar_first(xp):
63
109
  rng = np.random.RandomState(0)
64
110
 
65
- r = Rotation.from_euler('xyz', np.zeros(3))
66
- assert_allclose(r.as_quat(scalar_first=True), [1, 0, 0, 0],
111
+ r = Rotation.from_euler('xyz', xp.zeros(3))
112
+ xp_assert_close(r.as_quat(scalar_first=True), xp.asarray([1.0, 0, 0, 0]),
67
113
  rtol=1e-15, atol=1e-16)
68
114
 
69
- r = Rotation.from_euler('xyz', np.zeros((10, 3)))
70
- assert_allclose(r.as_quat(scalar_first=True),
71
- np.tile([1, 0, 0, 0], (10, 1)), rtol=1e-15, atol=1e-16)
115
+ r = Rotation.from_euler('xyz', xp.zeros((10, 3)))
116
+ xp_assert_close(r.as_quat(scalar_first=True),
117
+ xp.tile(xp.asarray([1.0, 0, 0, 0]), (10, 1)),
118
+ rtol=1e-15, atol=1e-16)
72
119
 
73
- q = rng.randn(100, 4)
74
- q /= np.linalg.norm(q, axis=1)[:, None]
75
- for qi in q:
120
+ q = xp.asarray(rng.randn(100, 4))
121
+ q /= xp_vector_norm(q, axis=1)[:, None]
122
+ for i in range(q.shape[0]): # Array API conforming loop
123
+ qi = q[i, ...]
76
124
  r = Rotation.from_quat(qi)
77
- assert_allclose(r.as_quat(scalar_first=True), np.roll(qi, 1),
125
+ xp_assert_close(r.as_quat(scalar_first=True), xp.roll(qi, 1),
78
126
  rtol=1e-15)
79
127
 
80
- assert_allclose(r.as_quat(canonical=True, scalar_first=True),
81
- np.roll(r.as_quat(canonical=True), 1),
128
+ xp_assert_close(r.as_quat(canonical=True, scalar_first=True),
129
+ xp.roll(r.as_quat(canonical=True), 1),
82
130
  rtol=1e-15)
83
131
 
84
132
  r = Rotation.from_quat(q)
85
- assert_allclose(r.as_quat(scalar_first=True), np.roll(q, 1, axis=1),
133
+ xp_assert_close(r.as_quat(scalar_first=True), xp.roll(q, 1, axis=1),
86
134
  rtol=1e-15)
87
135
 
88
- assert_allclose(r.as_quat(canonical=True, scalar_first=True),
89
- np.roll(r.as_quat(canonical=True), 1, axis=1), rtol=1e-15)
136
+ xp_assert_close(r.as_quat(canonical=True, scalar_first=True),
137
+ xp.roll(r.as_quat(canonical=True), 1, axis=1), rtol=1e-15)
90
138
 
91
139
 
92
- def test_from_square_quat_matrix():
140
+ def test_from_square_quat_matrix(xp):
93
141
  # Ensure proper norm array broadcasting
94
- x = np.array([
95
- [3, 0, 0, 4],
142
+ x = xp.asarray([
143
+ [3.0, 0, 0, 4],
96
144
  [5, 0, 12, 0],
97
145
  [0, 0, 0, 1],
98
146
  [-1, -1, -1, 1],
@@ -100,647 +148,724 @@ def test_from_square_quat_matrix():
100
148
  [-1, -1, -1, -1] # Check double cover
101
149
  ])
102
150
  r = Rotation.from_quat(x)
103
- expected_quat = x / np.array([[5], [13], [1], [2], [1], [2]])
104
- assert_array_almost_equal(r.as_quat(), expected_quat)
151
+ expected_quat = x / xp.asarray([[5.0], [13], [1], [2], [1], [2]])
152
+ xp_assert_close(r.as_quat(), expected_quat)
105
153
 
106
154
 
107
- def test_quat_double_to_canonical_single_cover():
108
- x = np.array([
109
- [-1, 0, 0, 0],
155
+ def test_quat_double_to_canonical_single_cover(xp):
156
+ x = xp.asarray([
157
+ [-1.0, 0, 0, 0],
110
158
  [0, -1, 0, 0],
111
159
  [0, 0, -1, 0],
112
160
  [0, 0, 0, -1],
113
161
  [-1, -1, -1, -1]
114
162
  ])
115
163
  r = Rotation.from_quat(x)
116
- expected_quat = np.abs(x) / np.linalg.norm(x, axis=1)[:, None]
117
- assert_allclose(r.as_quat(canonical=True), expected_quat)
164
+ expected_quat = xp.abs(x) / xp_vector_norm(x, axis=1)[:, None]
165
+ xp_assert_close(r.as_quat(canonical=True), expected_quat)
118
166
 
119
167
 
120
- def test_quat_double_cover():
168
+ def test_quat_double_cover(xp):
121
169
  # See the Rotation.from_quat() docstring for scope of the quaternion
122
170
  # double cover property.
123
171
  # Check from_quat and as_quat(canonical=False)
124
- q = np.array([0, 0, 0, -1])
172
+ q = xp.asarray([0.0, 0, 0, -1])
125
173
  r = Rotation.from_quat(q)
126
- assert_equal(q, r.as_quat(canonical=False))
127
-
174
+ xp_assert_equal(q, r.as_quat(canonical=False))
128
175
  # Check composition and inverse
129
- q = np.array([1, 0, 0, 1])/np.sqrt(2) # 90 deg rotation about x
176
+ q = xp.asarray([1.0, 0, 0, 1])/math.sqrt(2) # 90 deg rotation about x
130
177
  r = Rotation.from_quat(q)
131
178
  r3 = r*r*r
132
- assert_allclose(r.as_quat(canonical=False)*np.sqrt(2),
133
- [1, 0, 0, 1])
134
- assert_allclose(r.inv().as_quat(canonical=False)*np.sqrt(2),
135
- [-1, 0, 0, 1])
136
- assert_allclose(r3.as_quat(canonical=False)*np.sqrt(2),
137
- [1, 0, 0, -1])
138
- assert_allclose(r3.inv().as_quat(canonical=False)*np.sqrt(2),
139
- [-1, 0, 0, -1])
179
+ xp_assert_close(r.as_quat(canonical=False)*math.sqrt(2),
180
+ xp.asarray([1.0, 0, 0, 1]))
181
+ xp_assert_close(r.inv().as_quat(canonical=False)*math.sqrt(2),
182
+ xp.asarray([-1.0, 0, 0, 1]))
183
+ xp_assert_close(r3.as_quat(canonical=False)*math.sqrt(2),
184
+ xp.asarray([1.0, 0, 0, -1]))
185
+ xp_assert_close(r3.inv().as_quat(canonical=False)*math.sqrt(2),
186
+ xp.asarray([-1.0, 0, 0, -1]))
140
187
 
141
188
  # More sanity checks
142
- assert_allclose((r*r.inv()).as_quat(canonical=False),
143
- [0, 0, 0, 1], atol=2e-16)
144
- assert_allclose((r3*r3.inv()).as_quat(canonical=False),
145
- [0, 0, 0, 1], atol=2e-16)
146
- assert_allclose((r*r3).as_quat(canonical=False),
147
- [0, 0, 0, -1], atol=2e-16)
148
- assert_allclose((r.inv()*r3.inv()).as_quat(canonical=False),
149
- [0, 0, 0, -1], atol=2e-16)
189
+ xp_assert_close((r*r.inv()).as_quat(canonical=False),
190
+ xp.asarray([0.0, 0, 0, 1]), atol=2e-16)
191
+ xp_assert_close((r3*r3.inv()).as_quat(canonical=False),
192
+ xp.asarray([0.0, 0, 0, 1]), atol=2e-16)
193
+ xp_assert_close((r*r3).as_quat(canonical=False),
194
+ xp.asarray([0.0, 0, 0, -1]), atol=2e-16)
195
+ xp_assert_close((r.inv() * r3.inv()).as_quat(canonical=False),
196
+ xp.asarray([0.0, 0, 0, -1]), atol=2e-16)
150
197
 
151
198
 
152
- def test_from_quat_wrong_shape():
199
+ def test_from_quat_wrong_shape(xp):
153
200
  # Wrong shape 1d array
154
201
  with pytest.raises(ValueError, match='Expected `quat` to have shape'):
155
- Rotation.from_quat(np.array([1, 2, 3]))
202
+ Rotation.from_quat(xp.asarray([1, 2, 3]))
156
203
 
157
204
  # Wrong shape 2d array
158
205
  with pytest.raises(ValueError, match='Expected `quat` to have shape'):
159
- Rotation.from_quat(np.array([
206
+ Rotation.from_quat(xp.asarray([
160
207
  [1, 2, 3, 4, 5],
161
208
  [4, 5, 6, 7, 8]
162
209
  ]))
163
210
 
164
211
  # 3d array
165
212
  with pytest.raises(ValueError, match='Expected `quat` to have shape'):
166
- Rotation.from_quat(np.array([
213
+ Rotation.from_quat(xp.asarray([
167
214
  [[1, 2, 3, 4]],
168
215
  [[4, 5, 6, 7]]
169
216
  ]))
170
217
 
171
- # 0-length 2d array
172
- with pytest.raises(ValueError, match='Expected `quat` to have shape'):
173
- Rotation.from_quat(np.array([]).reshape((0, 4)))
174
-
175
218
 
176
- def test_zero_norms_from_quat():
177
- x = np.array([
219
+ def test_zero_norms_from_quat(xp):
220
+ x = xp.asarray([
178
221
  [3, 4, 0, 0],
179
222
  [0, 0, 0, 0],
180
223
  [5, 0, 12, 0]
181
224
  ])
182
- with pytest.raises(ValueError):
183
- Rotation.from_quat(x)
225
+ if is_lazy_array(x):
226
+ assert xp.all(xp.isnan(Rotation.from_quat(x).as_quat()[1, ...]))
227
+ else:
228
+ with pytest.raises(ValueError):
229
+ Rotation.from_quat(x)
184
230
 
185
231
 
186
- def test_as_matrix_single_1d_quaternion():
187
- quat = [0, 0, 0, 1]
232
+ def test_as_matrix_single_1d_quaternion(xp):
233
+ quat = xp.asarray([0, 0, 0, 1])
188
234
  mat = Rotation.from_quat(quat).as_matrix()
189
235
  # mat.shape == (3,3) due to 1d input
190
- assert_array_almost_equal(mat, np.eye(3))
236
+ xp_assert_close(mat, xp.eye(3))
191
237
 
192
238
 
193
- def test_as_matrix_single_2d_quaternion():
194
- quat = [[0, 0, 1, 1]]
239
+ def test_as_matrix_single_2d_quaternion(xp):
240
+ quat = xp.asarray([[0, 0, 1, 1]])
195
241
  mat = Rotation.from_quat(quat).as_matrix()
196
242
  assert_equal(mat.shape, (1, 3, 3))
197
- expected_mat = np.array([
198
- [0, -1, 0],
243
+ expected_mat = xp.asarray([
244
+ [0.0, -1, 0],
199
245
  [1, 0, 0],
200
246
  [0, 0, 1]
201
247
  ])
202
- assert_array_almost_equal(mat[0], expected_mat)
248
+ xp_assert_close(mat[0, ...], expected_mat)
203
249
 
204
250
 
205
- def test_as_matrix_from_square_input():
206
- quats = [
251
+ def test_as_matrix_from_square_input(xp):
252
+ quats = xp.asarray([
207
253
  [0, 0, 1, 1],
208
254
  [0, 1, 0, 1],
209
255
  [0, 0, 0, 1],
210
256
  [0, 0, 0, -1]
211
- ]
257
+ ])
212
258
  mat = Rotation.from_quat(quats).as_matrix()
213
259
  assert_equal(mat.shape, (4, 3, 3))
214
260
 
215
- expected0 = np.array([
216
- [0, -1, 0],
261
+ expected0 = xp.asarray([
262
+ [0.0, -1, 0],
217
263
  [1, 0, 0],
218
264
  [0, 0, 1]
219
265
  ])
220
- assert_array_almost_equal(mat[0], expected0)
266
+ xp_assert_close(mat[0, ...], expected0)
221
267
 
222
- expected1 = np.array([
223
- [0, 0, 1],
268
+ expected1 = xp.asarray([
269
+ [0.0, 0, 1],
224
270
  [0, 1, 0],
225
271
  [-1, 0, 0]
226
272
  ])
227
- assert_array_almost_equal(mat[1], expected1)
273
+ xp_assert_close(mat[1, ...], expected1)
228
274
 
229
- assert_array_almost_equal(mat[2], np.eye(3))
230
- assert_array_almost_equal(mat[3], np.eye(3))
275
+ xp_assert_close(mat[2, ...], xp.eye(3))
276
+ xp_assert_close(mat[3, ...], xp.eye(3))
231
277
 
232
278
 
233
- def test_as_matrix_from_generic_input():
234
- quats = [
279
+ def test_as_matrix_from_generic_input(xp):
280
+ quats = xp.asarray([
235
281
  [0, 0, 1, 1],
236
282
  [0, 1, 0, 1],
237
283
  [1, 2, 3, 4]
238
- ]
284
+ ])
239
285
  mat = Rotation.from_quat(quats).as_matrix()
240
286
  assert_equal(mat.shape, (3, 3, 3))
241
287
 
242
- expected0 = np.array([
243
- [0, -1, 0],
288
+ expected0 = xp.asarray([
289
+ [0.0, -1, 0],
244
290
  [1, 0, 0],
245
291
  [0, 0, 1]
246
292
  ])
247
- assert_array_almost_equal(mat[0], expected0)
293
+ xp_assert_close(mat[0, ...], expected0)
248
294
 
249
- expected1 = np.array([
250
- [0, 0, 1],
295
+ expected1 = xp.asarray([
296
+ [0.0, 0, 1],
251
297
  [0, 1, 0],
252
298
  [-1, 0, 0]
253
299
  ])
254
- assert_array_almost_equal(mat[1], expected1)
300
+ xp_assert_close(mat[1, ...], expected1)
255
301
 
256
- expected2 = np.array([
302
+ expected2 = xp.asarray([
257
303
  [0.4, -2, 2.2],
258
304
  [2.8, 1, 0.4],
259
305
  [-1, 2, 2]
260
306
  ]) / 3
261
- assert_array_almost_equal(mat[2], expected2)
307
+ xp_assert_close(mat[2, ...], expected2)
262
308
 
263
309
 
264
- def test_from_single_2d_matrix():
265
- mat = [
310
+ def test_from_single_2d_matrix(xp):
311
+ mat = xp.asarray([
266
312
  [0, 0, 1],
267
313
  [1, 0, 0],
268
314
  [0, 1, 0]
269
- ]
270
- expected_quat = [0.5, 0.5, 0.5, 0.5]
271
- assert_array_almost_equal(
272
- Rotation.from_matrix(mat).as_quat(),
273
- expected_quat)
315
+ ])
316
+ expected_quat = xp.asarray([0.5, 0.5, 0.5, 0.5])
317
+ xp_assert_close(Rotation.from_matrix(mat).as_quat(), expected_quat)
274
318
 
275
319
 
276
- def test_from_single_3d_matrix():
277
- mat = np.array([
320
+ def test_from_single_3d_matrix(xp):
321
+ mat = xp.asarray([[
278
322
  [0, 0, 1],
279
323
  [1, 0, 0],
280
- [0, 1, 0]
281
- ]).reshape((1, 3, 3))
282
- expected_quat = np.array([0.5, 0.5, 0.5, 0.5]).reshape((1, 4))
283
- assert_array_almost_equal(
284
- Rotation.from_matrix(mat).as_quat(),
285
- expected_quat)
324
+ [0, 1, 0],
325
+ ]])
326
+ expected_quat = xp.asarray([[0.5, 0.5, 0.5, 0.5]])
327
+ xp_assert_close(Rotation.from_matrix(mat).as_quat(), expected_quat)
286
328
 
287
329
 
288
- def test_from_matrix_calculation():
289
- expected_quat = np.array([1, 1, 6, 1]) / np.sqrt(39)
290
- mat = np.array([
330
+ def test_from_matrix_calculation(xp):
331
+ atol = 1e-8
332
+ expected_quat = xp.asarray([1.0, 1, 6, 1]) / math.sqrt(39)
333
+ mat = xp.asarray([
291
334
  [-0.8974359, -0.2564103, 0.3589744],
292
335
  [0.3589744, -0.8974359, 0.2564103],
293
336
  [0.2564103, 0.3589744, 0.8974359]
294
337
  ])
295
- assert_array_almost_equal(
296
- Rotation.from_matrix(mat).as_quat(),
297
- expected_quat)
298
- assert_array_almost_equal(
299
- Rotation.from_matrix(mat.reshape((1, 3, 3))).as_quat(),
300
- expected_quat.reshape((1, 4)))
338
+ xp_assert_close(Rotation.from_matrix(mat).as_quat(), expected_quat, atol=atol)
339
+ xp_assert_close(Rotation.from_matrix(xp.reshape(mat, (1, 3, 3))).as_quat(),
340
+ xp.reshape(expected_quat, (1, 4)),
341
+ atol=atol)
301
342
 
302
343
 
303
- def test_matrix_calculation_pipeline():
304
- mat = special_ortho_group.rvs(3, size=10, random_state=0)
305
- assert_array_almost_equal(Rotation.from_matrix(mat).as_matrix(), mat)
344
+ def test_matrix_calculation_pipeline(xp):
345
+ mat = xp.asarray(special_ortho_group.rvs(3, size=10, random_state=0))
346
+ xp_assert_close(Rotation.from_matrix(mat).as_matrix(), mat)
306
347
 
307
348
 
308
- def test_from_matrix_ortho_output():
349
+ def test_from_matrix_ortho_output(xp):
350
+ atol = 1e-12
309
351
  rnd = np.random.RandomState(0)
310
- mat = rnd.random_sample((100, 3, 3))
311
- dets = np.linalg.det(mat)
312
- for i in range(len(dets)):
352
+ mat = xp.asarray(rnd.random_sample((100, 3, 3)))
353
+ dets = xp.linalg.det(mat)
354
+ for i in range(dets.shape[0]):
313
355
  # Make sure we have a right-handed rotation matrix
314
356
  if dets[i] < 0:
315
- mat[i] = -mat[i]
357
+ mat = xpx.at(mat)[i, ...].set(-mat[i, ...])
316
358
  ortho_mat = Rotation.from_matrix(mat).as_matrix()
317
359
 
318
- mult_result = np.einsum('...ij,...jk->...ik', ortho_mat,
319
- ortho_mat.transpose((0, 2, 1)))
360
+ mult_result = xp.matmul(ortho_mat, xp.matrix_transpose(ortho_mat))
320
361
 
321
- eye3d = np.zeros((100, 3, 3))
322
- for i in range(3):
323
- eye3d[:, i, i] = 1.0
362
+ eye3d = xp.zeros((100, 3, 3)) + xp.eye(3)
363
+ xp_assert_close(mult_result, eye3d, atol=atol)
324
364
 
325
- assert_array_almost_equal(mult_result, eye3d)
326
365
 
327
-
328
- def test_from_matrix_normalize():
329
- mat = np.array([
366
+ def test_from_matrix_normalize(xp):
367
+ mat = xp.asarray([
330
368
  [1, 1, 0],
331
369
  [0, 1, 0],
332
370
  [0, 0, 1]])
333
- expected = np.array([[ 0.894427, 0.447214, 0.0],
334
- [-0.447214, 0.894427, 0.0],
335
- [ 0.0, 0.0, 1.0]])
336
- assert_allclose(Rotation.from_matrix(mat).as_matrix(), expected, atol=1e-6)
371
+ expected = xp.asarray([[ 0.894427, 0.447214, 0.0],
372
+ [-0.447214, 0.894427, 0.0],
373
+ [ 0.0, 0.0, 1.0]])
374
+ xp_assert_close(Rotation.from_matrix(mat).as_matrix(), expected, atol=1e-6)
337
375
 
338
- mat = np.array([
376
+ mat = xp.asarray([
339
377
  [0, -0.5, 0 ],
340
378
  [0.5, 0 , 0 ],
341
379
  [0, 0 , 0.5]])
342
- expected = np.array([[ 0, -1, 0],
343
- [ 1, 0, 0],
344
- [ 0, 0, 1]])
345
- assert_allclose(Rotation.from_matrix(mat).as_matrix(), expected, atol=1e-6)
346
-
347
-
348
- def test_from_matrix_non_positive_determinant():
349
- mat = np.eye(3)
350
- mat[0, 0] = 0
351
- with pytest.raises(ValueError, match='Non-positive determinant'):
352
- Rotation.from_matrix(mat)
353
-
354
- mat[0, 0] = -1
355
- with pytest.raises(ValueError, match='Non-positive determinant'):
356
- Rotation.from_matrix(mat)
380
+ expected = xp.asarray([[0.0, -1, 0],
381
+ [ 1, 0, 0],
382
+ [ 0, 0, 1]])
383
+ xp_assert_close(Rotation.from_matrix(mat).as_matrix(), expected, atol=1e-6)
384
+
385
+
386
+ def test_from_matrix_non_positive_determinant(xp):
387
+ mat = xp.eye(3)
388
+ mat = xpx.at(mat)[0, 0].set(0)
389
+ if is_lazy_array(mat):
390
+ assert xp.all(xp.isnan(Rotation.from_matrix(mat).as_matrix()))
391
+ else:
392
+ with pytest.raises(ValueError, match="Non-positive determinant"):
393
+ Rotation.from_matrix(mat)
394
+
395
+ mat = xpx.at(mat)[0, 0].set(-1)
396
+ if is_lazy_array(mat):
397
+ assert xp.all(xp.isnan(Rotation.from_matrix(mat).as_matrix()))
398
+ else:
399
+ with pytest.raises(ValueError, match="Non-positive determinant"):
400
+ Rotation.from_matrix(mat)
401
+
402
+
403
+ def test_from_matrix_array_like():
404
+ rng = np.random.default_rng(123)
405
+ # Single rotation
406
+ r_expected = Rotation.random(rng=rng)
407
+ r = Rotation.from_matrix(r_expected.as_matrix().tolist())
408
+ assert r_expected.approx_equal(r, atol=1e-12)
409
+
410
+ # Multiple rotations
411
+ r_expected = Rotation.random(3, rng=rng)
412
+ r = Rotation.from_matrix(r_expected.as_matrix().tolist())
413
+ assert np.all(r_expected.approx_equal(r, atol=1e-12))
414
+
415
+
416
+ def test_from_matrix_int_dtype(xp):
417
+ mat = xp.asarray([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
418
+ r = Rotation.from_matrix(mat)
419
+ assert r.as_quat().dtype == xp_default_dtype(xp)
357
420
 
358
421
 
359
- def test_from_1d_single_rotvec():
360
- rotvec = [1, 0, 0]
361
- expected_quat = np.array([0.4794255, 0, 0, 0.8775826])
422
+ def test_from_1d_single_rotvec(xp):
423
+ atol = 1e-7
424
+ rotvec = xp.asarray([1, 0, 0])
425
+ expected_quat = xp.asarray([0.4794255, 0, 0, 0.8775826])
362
426
  result = Rotation.from_rotvec(rotvec)
363
- assert_array_almost_equal(result.as_quat(), expected_quat)
427
+ xp_assert_close(result.as_quat(), expected_quat, atol=atol)
364
428
 
365
429
 
366
- def test_from_2d_single_rotvec():
367
- rotvec = [[1, 0, 0]]
368
- expected_quat = np.array([[0.4794255, 0, 0, 0.8775826]])
430
+ def test_from_2d_single_rotvec(xp):
431
+ atol = 1e-7
432
+ rotvec = xp.asarray([[1, 0, 0]])
433
+ expected_quat = xp.asarray([[0.4794255, 0, 0, 0.8775826]])
369
434
  result = Rotation.from_rotvec(rotvec)
370
- assert_array_almost_equal(result.as_quat(), expected_quat)
435
+ xp_assert_close(result.as_quat(), expected_quat, atol=atol)
371
436
 
372
437
 
373
- def test_from_generic_rotvec():
374
- rotvec = [
438
+ def test_from_generic_rotvec(xp):
439
+ atol = 1e-7
440
+ rotvec = xp.asarray([
375
441
  [1, 2, 2],
376
442
  [1, -1, 0.5],
377
- [0, 0, 0]
378
- ]
379
- expected_quat = np.array([
443
+ [0, 0, 0]])
444
+ expected_quat = xp.asarray([
380
445
  [0.3324983, 0.6649967, 0.6649967, 0.0707372],
381
446
  [0.4544258, -0.4544258, 0.2272129, 0.7316889],
382
447
  [0, 0, 0, 1]
383
448
  ])
384
- assert_array_almost_equal(
385
- Rotation.from_rotvec(rotvec).as_quat(),
386
- expected_quat)
449
+ xp_assert_close(Rotation.from_rotvec(rotvec).as_quat(), expected_quat, atol=atol)
387
450
 
388
451
 
389
- def test_from_rotvec_small_angle():
390
- rotvec = np.array([
391
- [5e-4 / np.sqrt(3), -5e-4 / np.sqrt(3), 5e-4 / np.sqrt(3)],
452
+ def test_from_rotvec_small_angle(xp):
453
+ rotvec = xp.asarray([
454
+ [5e-4 / math.sqrt(3), -5e-4 / math.sqrt(3), 5e-4 / math.sqrt(3)],
392
455
  [0.2, 0.3, 0.4],
393
456
  [0, 0, 0]
394
457
  ])
395
458
 
396
459
  quat = Rotation.from_rotvec(rotvec).as_quat()
397
460
  # cos(theta/2) ~~ 1 for small theta
398
- assert_allclose(quat[0, 3], 1)
461
+ xp_assert_close(quat[0, 3], xp.asarray(1.0)[()])
399
462
  # sin(theta/2) / theta ~~ 0.5 for small theta
400
- assert_allclose(quat[0, :3], rotvec[0] * 0.5)
463
+ xp_assert_close(quat[0, :3], rotvec[0, ...] * 0.5)
401
464
 
402
- assert_allclose(quat[1, 3], 0.9639685)
403
- assert_allclose(
404
- quat[1, :3],
405
- np.array([
465
+ xp_assert_close(quat[1, 3], xp.asarray(0.9639685)[()])
466
+ xp_assert_close(quat[1, :3],
467
+ xp.asarray([
406
468
  0.09879603932153465,
407
469
  0.14819405898230198,
408
- 0.19759207864306931
409
- ]))
470
+ 0.19759207864306931]))
471
+
472
+ xp_assert_equal(quat[2, ...], xp.asarray([0.0, 0, 0, 1]))
473
+
474
+
475
+ def test_from_rotvec_array_like():
476
+ rng = np.random.default_rng(123)
477
+ # Single rotation
478
+ r_expected = Rotation.random(rng=rng)
479
+ r = Rotation.from_rotvec(r_expected.as_rotvec().tolist())
480
+ assert r_expected.approx_equal(r, atol=1e-12)
481
+
482
+ # Multiple rotations
483
+ r_expected = Rotation.random(3, rng=rng)
484
+ r = Rotation.from_rotvec(r_expected.as_rotvec().tolist())
485
+ assert np.all(r_expected.approx_equal(r, atol=1e-12))
410
486
 
411
- assert_equal(quat[2], np.array([0, 0, 0, 1]))
412
487
 
488
+ def test_from_rotvec_int_dtype(xp):
489
+ rotvec = xp.asarray([1, 0, 0])
490
+ r = Rotation.from_rotvec(rotvec)
491
+ assert r.as_quat().dtype == xp_default_dtype(xp)
413
492
 
414
- def test_degrees_from_rotvec():
415
- rotvec1 = [1.0 / np.cbrt(3), 1.0 / np.cbrt(3), 1.0 / np.cbrt(3)]
493
+
494
+ def test_degrees_from_rotvec(xp):
495
+ rotvec1 = xp.asarray([1 / 3 ** (1/3)] * 3)
416
496
  rot1 = Rotation.from_rotvec(rotvec1, degrees=True)
417
497
  quat1 = rot1.as_quat()
418
498
 
419
- rotvec2 = np.deg2rad(rotvec1)
499
+ # deg2rad is not implemented in Array API -> / 180 * xp.pi
500
+ rotvec2 = xp.asarray(rotvec1 / 180 * xp.pi)
420
501
  rot2 = Rotation.from_rotvec(rotvec2)
421
502
  quat2 = rot2.as_quat()
422
503
 
423
- assert_allclose(quat1, quat2)
504
+ xp_assert_close(quat1, quat2)
424
505
 
425
506
 
426
- def test_malformed_1d_from_rotvec():
507
+ def test_malformed_1d_from_rotvec(xp):
427
508
  with pytest.raises(ValueError, match='Expected `rot_vec` to have shape'):
428
- Rotation.from_rotvec([1, 2])
509
+ Rotation.from_rotvec(xp.asarray([1, 2]))
429
510
 
430
511
 
431
- def test_malformed_2d_from_rotvec():
512
+ def test_malformed_2d_from_rotvec(xp):
432
513
  with pytest.raises(ValueError, match='Expected `rot_vec` to have shape'):
433
- Rotation.from_rotvec([
514
+ Rotation.from_rotvec(xp.asarray([
434
515
  [1, 2, 3, 4],
435
516
  [5, 6, 7, 8]
436
- ])
517
+ ]))
437
518
 
438
519
 
439
- def test_as_generic_rotvec():
440
- quat = np.array([
520
+ def test_as_generic_rotvec(xp):
521
+ quat = xp.asarray([
441
522
  [1, 2, -1, 0.5],
442
523
  [1, -1, 1, 0.0003],
443
524
  [0, 0, 0, 1]
444
525
  ])
445
- quat /= np.linalg.norm(quat, axis=1)[:, None]
526
+ quat /= xp_vector_norm(quat, axis=-1, keepdims=True)
446
527
 
447
528
  rotvec = Rotation.from_quat(quat).as_rotvec()
448
- angle = np.linalg.norm(rotvec, axis=1)
529
+ angle = xp_vector_norm(rotvec, axis=-1)
449
530
 
450
- assert_allclose(quat[:, 3], np.cos(angle/2))
451
- assert_allclose(np.cross(rotvec, quat[:, :3]), np.zeros((3, 3)))
531
+ xp_assert_close(quat[:, 3], xp.cos(angle / 2))
532
+ xp_assert_close(xp.linalg.cross(rotvec, quat[:, :3]), xp.zeros((3, 3)), atol=1e-15)
452
533
 
453
534
 
454
- def test_as_rotvec_single_1d_input():
455
- quat = np.array([1, 2, -3, 2])
456
- expected_rotvec = np.array([0.5772381, 1.1544763, -1.7317144])
535
+ def test_as_rotvec_single_1d_input(xp):
536
+ quat = xp.asarray([1, 2, -3, 2])
537
+ expected_rotvec = xp.asarray([0.5772381, 1.1544763, -1.7317144])
457
538
 
458
539
  actual_rotvec = Rotation.from_quat(quat).as_rotvec()
459
540
 
460
541
  assert_equal(actual_rotvec.shape, (3,))
461
- assert_allclose(actual_rotvec, expected_rotvec)
542
+ xp_assert_close(actual_rotvec, expected_rotvec)
462
543
 
463
544
 
464
- def test_as_rotvec_single_2d_input():
465
- quat = np.array([[1, 2, -3, 2]])
466
- expected_rotvec = np.array([[0.5772381, 1.1544763, -1.7317144]])
545
+ def test_as_rotvec_single_2d_input(xp):
546
+ quat = xp.asarray([[1, 2, -3, 2]])
547
+ expected_rotvec = xp.asarray([[0.5772381, 1.1544763, -1.7317144]])
467
548
 
468
549
  actual_rotvec = Rotation.from_quat(quat).as_rotvec()
469
550
 
470
551
  assert_equal(actual_rotvec.shape, (1, 3))
471
- assert_allclose(actual_rotvec, expected_rotvec)
552
+ xp_assert_close(actual_rotvec, expected_rotvec)
472
553
 
473
554
 
474
- def test_as_rotvec_degrees():
555
+ def test_as_rotvec_degrees(xp):
475
556
  # x->y, y->z, z->x
476
- mat = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
557
+ mat = xp.asarray([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
477
558
  rot = Rotation.from_matrix(mat)
478
559
  rotvec = rot.as_rotvec(degrees=True)
479
- angle = np.linalg.norm(rotvec)
480
- assert_allclose(angle, 120.0)
481
- assert_allclose(rotvec[0], rotvec[1])
482
- assert_allclose(rotvec[1], rotvec[2])
560
+ angle = xp_vector_norm(rotvec, axis=-1)
561
+ xp_assert_close(angle, xp.asarray(120.0)[()])
562
+ xp_assert_close(rotvec[0], rotvec[1])
563
+ xp_assert_close(rotvec[1], rotvec[2])
483
564
 
484
565
 
485
- def test_rotvec_calc_pipeline():
566
+ def test_rotvec_calc_pipeline(xp):
486
567
  # Include small angles
487
- rotvec = np.array([
568
+ rotvec = xp.asarray([
488
569
  [0, 0, 0],
489
570
  [1, -1, 2],
490
571
  [-3e-4, 3.5e-4, 7.5e-5]
491
572
  ])
492
- assert_allclose(Rotation.from_rotvec(rotvec).as_rotvec(), rotvec)
493
- assert_allclose(Rotation.from_rotvec(rotvec, degrees=True).as_rotvec(degrees=True),
573
+ xp_assert_close(Rotation.from_rotvec(rotvec).as_rotvec(), rotvec)
574
+ xp_assert_close(Rotation.from_rotvec(rotvec, degrees=True).as_rotvec(degrees=True),
494
575
  rotvec)
495
576
 
496
577
 
497
- def test_from_1d_single_mrp():
498
- mrp = [0, 0, 1.0]
499
- expected_quat = np.array([0, 0, 1, 0])
578
+ def test_from_1d_single_mrp(xp):
579
+ mrp = xp.asarray([0, 0, 1.0])
580
+ expected_quat = xp.asarray([0.0, 0, 1, 0])
500
581
  result = Rotation.from_mrp(mrp)
501
- assert_array_almost_equal(result.as_quat(), expected_quat)
582
+ xp_assert_close(result.as_quat(), expected_quat, atol=1e-12)
502
583
 
503
584
 
504
- def test_from_2d_single_mrp():
505
- mrp = [[0, 0, 1.0]]
506
- expected_quat = np.array([[0, 0, 1, 0]])
585
+ def test_from_2d_single_mrp(xp):
586
+ mrp = xp.asarray([[0, 0, 1.0]])
587
+ expected_quat = xp.asarray([[0.0, 0, 1, 0]])
507
588
  result = Rotation.from_mrp(mrp)
508
- assert_array_almost_equal(result.as_quat(), expected_quat)
589
+ xp_assert_close(result.as_quat(), expected_quat)
590
+
591
+
592
+ def test_from_mrp_array_like():
593
+ rng = np.random.default_rng(123)
594
+ # Single rotation
595
+ r_expected = Rotation.random(rng=rng)
596
+ r = Rotation.from_mrp(r_expected.as_mrp().tolist())
597
+ assert r_expected.approx_equal(r, atol=1e-12)
509
598
 
599
+ # Multiple rotations
600
+ r_expected = Rotation.random(3, rng=rng)
601
+ r = Rotation.from_mrp(r_expected.as_mrp().tolist())
602
+ assert np.all(r_expected.approx_equal(r, atol=1e-12))
510
603
 
511
- def test_from_generic_mrp():
512
- mrp = np.array([
604
+
605
+ def test_from_mrp_int_dtype(xp):
606
+ mrp = xp.asarray([0, 0, 1])
607
+ r = Rotation.from_mrp(mrp)
608
+ assert r.as_quat().dtype == xp_default_dtype(xp)
609
+
610
+
611
+ def test_from_generic_mrp(xp):
612
+ mrp = xp.asarray([
513
613
  [1, 2, 2],
514
614
  [1, -1, 0.5],
515
615
  [0, 0, 0]])
516
- expected_quat = np.array([
616
+ expected_quat = xp.asarray([
517
617
  [0.2, 0.4, 0.4, -0.8],
518
618
  [0.61538462, -0.61538462, 0.30769231, -0.38461538],
519
619
  [0, 0, 0, 1]])
520
- assert_array_almost_equal(Rotation.from_mrp(mrp).as_quat(), expected_quat)
620
+ xp_assert_close(Rotation.from_mrp(mrp).as_quat(), expected_quat)
521
621
 
522
622
 
523
- def test_malformed_1d_from_mrp():
623
+ def test_malformed_1d_from_mrp(xp):
524
624
  with pytest.raises(ValueError, match='Expected `mrp` to have shape'):
525
- Rotation.from_mrp([1, 2])
625
+ Rotation.from_mrp(xp.asarray([1, 2]))
526
626
 
527
627
 
528
- def test_malformed_2d_from_mrp():
628
+ def test_malformed_2d_from_mrp(xp):
529
629
  with pytest.raises(ValueError, match='Expected `mrp` to have shape'):
530
- Rotation.from_mrp([
630
+ Rotation.from_mrp(xp.asarray([
531
631
  [1, 2, 3, 4],
532
632
  [5, 6, 7, 8]
533
- ])
633
+ ]))
534
634
 
535
635
 
536
- def test_as_generic_mrp():
537
- quat = np.array([
636
+ def test_as_generic_mrp(xp):
637
+ quat = xp.asarray([
538
638
  [1, 2, -1, 0.5],
539
639
  [1, -1, 1, 0.0003],
540
640
  [0, 0, 0, 1]])
541
- quat /= np.linalg.norm(quat, axis=1)[:, None]
641
+ quat /= xp_vector_norm(quat, axis=1)[:, None]
542
642
 
543
- expected_mrp = np.array([
643
+ expected_mrp = xp.asarray([
544
644
  [0.33333333, 0.66666667, -0.33333333],
545
645
  [0.57725028, -0.57725028, 0.57725028],
546
646
  [0, 0, 0]])
547
- assert_array_almost_equal(Rotation.from_quat(quat).as_mrp(), expected_mrp)
647
+ xp_assert_close(Rotation.from_quat(quat).as_mrp(), expected_mrp)
648
+
548
649
 
549
- def test_past_180_degree_rotation():
650
+ def test_past_180_degree_rotation(xp):
550
651
  # ensure that a > 180 degree rotation is returned as a <180 rotation in MRPs
551
652
  # in this case 270 should be returned as -90
552
- expected_mrp = np.array([-np.tan(np.pi/2/4), 0.0, 0])
553
- assert_array_almost_equal(
554
- Rotation.from_euler('xyz', [270, 0, 0], degrees=True).as_mrp(),
555
- expected_mrp
653
+ expected_mrp = xp.asarray([-math.tan(xp.pi / 2 / 4), 0.0, 0])
654
+ xp_assert_close(
655
+ Rotation.from_euler('xyz', xp.asarray([270, 0, 0]), degrees=True).as_mrp(),
656
+ expected_mrp,
556
657
  )
557
658
 
558
659
 
559
- def test_as_mrp_single_1d_input():
560
- quat = np.array([1, 2, -3, 2])
561
- expected_mrp = np.array([0.16018862, 0.32037724, -0.48056586])
660
+ def test_as_mrp_single_1d_input(xp):
661
+ quat = xp.asarray([1, 2, -3, 2])
662
+ expected_mrp = xp.asarray([0.16018862, 0.32037724, -0.48056586])
562
663
 
563
664
  actual_mrp = Rotation.from_quat(quat).as_mrp()
564
665
 
565
666
  assert_equal(actual_mrp.shape, (3,))
566
- assert_allclose(actual_mrp, expected_mrp)
667
+ xp_assert_close(actual_mrp, expected_mrp)
567
668
 
568
669
 
569
- def test_as_mrp_single_2d_input():
570
- quat = np.array([[1, 2, -3, 2]])
571
- expected_mrp = np.array([[0.16018862, 0.32037724, -0.48056586]])
670
+ def test_as_mrp_single_2d_input(xp):
671
+ quat = xp.asarray([[1, 2, -3, 2]])
672
+ expected_mrp = xp.asarray([[0.16018862, 0.32037724, -0.48056586]])
572
673
 
573
674
  actual_mrp = Rotation.from_quat(quat).as_mrp()
574
675
 
575
676
  assert_equal(actual_mrp.shape, (1, 3))
576
- assert_allclose(actual_mrp, expected_mrp)
677
+ xp_assert_close(actual_mrp, expected_mrp)
577
678
 
578
679
 
579
- def test_mrp_calc_pipeline():
580
- actual_mrp = np.array([
680
+ def test_mrp_calc_pipeline(xp):
681
+ actual_mrp = xp.asarray([
581
682
  [0, 0, 0],
582
683
  [1, -1, 2],
583
684
  [0.41421356, 0, 0],
584
685
  [0.1, 0.2, 0.1]])
585
- expected_mrp = np.array([
686
+ expected_mrp = xp.asarray([
586
687
  [0, 0, 0],
587
688
  [-0.16666667, 0.16666667, -0.33333333],
588
689
  [0.41421356, 0, 0],
589
690
  [0.1, 0.2, 0.1]])
590
- assert_allclose(Rotation.from_mrp(actual_mrp).as_mrp(), expected_mrp)
691
+ xp_assert_close(Rotation.from_mrp(actual_mrp).as_mrp(), expected_mrp)
591
692
 
592
693
 
593
- def test_from_euler_single_rotation():
594
- quat = Rotation.from_euler('z', 90, degrees=True).as_quat()
595
- expected_quat = np.array([0, 0, 1, 1]) / np.sqrt(2)
596
- assert_allclose(quat, expected_quat)
694
+ def test_from_euler_single_rotation(xp):
695
+ quat = Rotation.from_euler("z", xp.asarray(90), degrees=True).as_quat()
696
+ expected_quat = xp.asarray([0.0, 0, 1, 1]) / math.sqrt(2)
697
+ xp_assert_close(quat, expected_quat)
597
698
 
598
699
 
599
- def test_single_intrinsic_extrinsic_rotation():
600
- extrinsic = Rotation.from_euler('z', 90, degrees=True).as_matrix()
601
- intrinsic = Rotation.from_euler('Z', 90, degrees=True).as_matrix()
602
- assert_allclose(extrinsic, intrinsic)
700
+ def test_single_intrinsic_extrinsic_rotation(xp):
701
+ extrinsic = Rotation.from_euler('z', xp.asarray(90), degrees=True).as_matrix()
702
+ intrinsic = Rotation.from_euler('Z', xp.asarray(90), degrees=True).as_matrix()
703
+ xp_assert_close(extrinsic, intrinsic)
603
704
 
604
705
 
605
- def test_from_euler_rotation_order():
706
+ def test_from_euler_rotation_order(xp):
606
707
  # Intrinsic rotation is same as extrinsic with order reversed
607
708
  rnd = np.random.RandomState(0)
608
- a = rnd.randint(low=0, high=180, size=(6, 3))
609
- b = a[:, ::-1]
709
+ a = xp.asarray(rnd.randint(low=0, high=180, size=(6, 3)))
710
+ b = xp.flip(a, axis=-1)
610
711
  x = Rotation.from_euler('xyz', a, degrees=True).as_quat()
611
712
  y = Rotation.from_euler('ZYX', b, degrees=True).as_quat()
612
- assert_allclose(x, y)
713
+ xp_assert_close(x, y)
613
714
 
614
715
 
615
- def test_from_euler_elementary_extrinsic_rotation():
716
+ def test_from_euler_elementary_extrinsic_rotation(xp):
717
+ atol = 1e-12
616
718
  # Simple test to check if extrinsic rotations are implemented correctly
617
- mat = Rotation.from_euler('zx', [90, 90], degrees=True).as_matrix()
618
- expected_mat = np.array([
619
- [0, -1, 0],
719
+ mat = Rotation.from_euler('zx', xp.asarray([90, 90]), degrees=True).as_matrix()
720
+ expected_mat = xp.asarray([
721
+ [0.0, -1, 0],
620
722
  [0, 0, -1],
621
723
  [1, 0, 0]
622
724
  ])
623
- assert_array_almost_equal(mat, expected_mat)
725
+ xp_assert_close(mat, expected_mat, atol=atol)
624
726
 
625
727
 
626
- def test_from_euler_intrinsic_rotation_312():
627
- angles = [
728
+ def test_from_euler_intrinsic_rotation_312(xp):
729
+ atol = 1e-7
730
+ angles = xp.asarray([
628
731
  [30, 60, 45],
629
732
  [30, 60, 30],
630
733
  [45, 30, 60]
631
- ]
734
+ ])
632
735
  mat = Rotation.from_euler('ZXY', angles, degrees=True).as_matrix()
633
736
 
634
- assert_array_almost_equal(mat[0], np.array([
737
+ xp_assert_close(mat[0, ...], xp.asarray([
635
738
  [0.3061862, -0.2500000, 0.9185587],
636
739
  [0.8838835, 0.4330127, -0.1767767],
637
740
  [-0.3535534, 0.8660254, 0.3535534]
638
- ]))
741
+ ]), atol=atol)
639
742
 
640
- assert_array_almost_equal(mat[1], np.array([
743
+ xp_assert_close(mat[1, ...], xp.asarray([
641
744
  [0.5334936, -0.2500000, 0.8080127],
642
745
  [0.8080127, 0.4330127, -0.3995191],
643
746
  [-0.2500000, 0.8660254, 0.4330127]
644
- ]))
747
+ ]), atol=atol)
645
748
 
646
- assert_array_almost_equal(mat[2], np.array([
749
+ xp_assert_close(mat[2, ...], xp.asarray([
647
750
  [0.0473672, -0.6123725, 0.7891491],
648
751
  [0.6597396, 0.6123725, 0.4355958],
649
752
  [-0.7500000, 0.5000000, 0.4330127]
650
- ]))
753
+ ]), atol=atol)
651
754
 
652
755
 
653
- def test_from_euler_intrinsic_rotation_313():
654
- angles = [
756
+ def test_from_euler_intrinsic_rotation_313(xp):
757
+ angles = xp.asarray([
655
758
  [30, 60, 45],
656
759
  [30, 60, 30],
657
760
  [45, 30, 60]
658
- ]
761
+ ])
659
762
  mat = Rotation.from_euler('ZXZ', angles, degrees=True).as_matrix()
660
763
 
661
- assert_array_almost_equal(mat[0], np.array([
764
+ xp_assert_close(mat[0, ...], xp.asarray([
662
765
  [0.43559574, -0.78914913, 0.4330127],
663
766
  [0.65973961, -0.04736717, -0.750000],
664
767
  [0.61237244, 0.61237244, 0.500000]
665
768
  ]))
666
769
 
667
- assert_array_almost_equal(mat[1], np.array([
770
+ xp_assert_close(mat[1, ...], xp.asarray([
668
771
  [0.6250000, -0.64951905, 0.4330127],
669
772
  [0.64951905, 0.1250000, -0.750000],
670
773
  [0.4330127, 0.750000, 0.500000]
671
774
  ]))
672
775
 
673
- assert_array_almost_equal(mat[2], np.array([
776
+ xp_assert_close(mat[2, ...], xp.asarray([
674
777
  [-0.1767767, -0.91855865, 0.35355339],
675
778
  [0.88388348, -0.30618622, -0.35355339],
676
779
  [0.4330127, 0.25000000, 0.8660254]
677
780
  ]))
678
781
 
679
782
 
680
- def test_from_euler_extrinsic_rotation_312():
681
- angles = [
783
+ def test_from_euler_extrinsic_rotation_312(xp):
784
+ angles = xp.asarray([
682
785
  [30, 60, 45],
683
786
  [30, 60, 30],
684
787
  [45, 30, 60]
685
- ]
788
+ ])
686
789
  mat = Rotation.from_euler('zxy', angles, degrees=True).as_matrix()
687
790
 
688
- assert_array_almost_equal(mat[0], np.array([
791
+ xp_assert_close(mat[0, ...], xp.asarray([
689
792
  [0.91855865, 0.1767767, 0.35355339],
690
793
  [0.25000000, 0.4330127, -0.8660254],
691
794
  [-0.30618622, 0.88388348, 0.35355339]
692
795
  ]))
693
796
 
694
- assert_array_almost_equal(mat[1], np.array([
797
+ xp_assert_close(mat[1, ...], xp.asarray([
695
798
  [0.96650635, -0.0580127, 0.2500000],
696
799
  [0.25000000, 0.4330127, -0.8660254],
697
800
  [-0.0580127, 0.89951905, 0.4330127]
698
801
  ]))
699
802
 
700
- assert_array_almost_equal(mat[2], np.array([
803
+ xp_assert_close(mat[2, ...], xp.asarray([
701
804
  [0.65973961, -0.04736717, 0.7500000],
702
805
  [0.61237244, 0.61237244, -0.5000000],
703
806
  [-0.43559574, 0.78914913, 0.4330127]
704
807
  ]))
705
808
 
706
809
 
707
- def test_from_euler_extrinsic_rotation_313():
708
- angles = [
810
+ def test_from_euler_extrinsic_rotation_313(xp):
811
+ angles = xp.asarray([
709
812
  [30, 60, 45],
710
813
  [30, 60, 30],
711
814
  [45, 30, 60]
712
- ]
815
+ ])
713
816
  mat = Rotation.from_euler('zxz', angles, degrees=True).as_matrix()
714
817
 
715
- assert_array_almost_equal(mat[0], np.array([
818
+ xp_assert_close(mat[0, ...], xp.asarray([
716
819
  [0.43559574, -0.65973961, 0.61237244],
717
820
  [0.78914913, -0.04736717, -0.61237244],
718
821
  [0.4330127, 0.75000000, 0.500000]
719
822
  ]))
720
823
 
721
- assert_array_almost_equal(mat[1], np.array([
824
+ xp_assert_close(mat[1, ...], xp.asarray([
722
825
  [0.62500000, -0.64951905, 0.4330127],
723
826
  [0.64951905, 0.12500000, -0.750000],
724
827
  [0.4330127, 0.75000000, 0.500000]
725
828
  ]))
726
829
 
727
- assert_array_almost_equal(mat[2], np.array([
830
+ xp_assert_close(mat[2, ...], xp.asarray([
728
831
  [-0.1767767, -0.88388348, 0.4330127],
729
832
  [0.91855865, -0.30618622, -0.250000],
730
833
  [0.35355339, 0.35355339, 0.8660254]
731
834
  ]))
732
835
 
733
836
 
837
+ def test_from_euler_array_like():
838
+ rng = np.random.default_rng(123)
839
+ order = "xyz"
840
+ # Single rotation
841
+ r_expected = Rotation.random(rng=rng)
842
+ r = Rotation.from_euler(order, r_expected.as_euler(order).tolist())
843
+ assert r_expected.approx_equal(r, atol=1e-12)
844
+
845
+ # Multiple rotations
846
+ r_expected = Rotation.random(3, rng=rng)
847
+ r = Rotation.from_euler(order, r_expected.as_euler(order).tolist())
848
+ assert np.all(r_expected.approx_equal(r, atol=1e-12))
849
+
850
+
851
+ def test_from_euler_scalar():
852
+ rng = np.random.default_rng(123)
853
+ deg = rng.uniform(low=-180, high=180)
854
+ r_expected = Rotation.from_euler("x", deg, degrees=True)
855
+ r = Rotation.from_euler("x", float(deg), degrees=True)
856
+ assert r_expected.approx_equal(r, atol=1e-12)
857
+
858
+
734
859
  @pytest.mark.parametrize("seq_tuple", permutations("xyz"))
735
860
  @pytest.mark.parametrize("intrinsic", (False, True))
736
- def test_as_euler_asymmetric_axes(seq_tuple, intrinsic):
861
+ def test_as_euler_asymmetric_axes(xp, seq_tuple, intrinsic):
737
862
  # helper function for mean error tests
738
863
  def test_stats(error, mean_max, rms_max):
739
- mean = np.mean(error, axis=0)
740
- std = np.std(error, axis=0)
741
- rms = np.hypot(mean, std)
742
- assert np.all(np.abs(mean) < mean_max)
743
- assert np.all(rms < rms_max)
864
+ mean = xp.mean(error, axis=0)
865
+ std = xp.std(error, axis=0)
866
+ rms = xp.hypot(mean, std)
867
+ assert xp.all(xp.abs(mean) < mean_max)
868
+ assert xp.all(rms < rms_max)
744
869
 
745
870
  rnd = np.random.RandomState(0)
746
871
  n = 1000
@@ -748,6 +873,7 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic):
748
873
  angles[:, 0] = rnd.uniform(low=-np.pi, high=np.pi, size=(n,))
749
874
  angles[:, 1] = rnd.uniform(low=-np.pi / 2, high=np.pi / 2, size=(n,))
750
875
  angles[:, 2] = rnd.uniform(low=-np.pi, high=np.pi, size=(n,))
876
+ angles = xp.asarray(angles)
751
877
 
752
878
  seq = "".join(seq_tuple)
753
879
  if intrinsic:
@@ -756,9 +882,11 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic):
756
882
  seq = seq.upper()
757
883
  rotation = Rotation.from_euler(seq, angles)
758
884
  angles_quat = rotation.as_euler(seq)
885
+ # TODO: Why are we using _as_euler_from_matrix here? As a sanity check? It is not
886
+ # part of the public API and should not be used anywhere else
759
887
  angles_mat = rotation._as_euler_from_matrix(seq)
760
- assert_allclose(angles, angles_quat, atol=0, rtol=1e-12)
761
- assert_allclose(angles, angles_mat, atol=0, rtol=1e-12)
888
+ xp_assert_close(angles, angles_quat, atol=0, rtol=1e-12)
889
+ xp_assert_close(angles, angles_mat, atol=0, rtol=1e-12)
762
890
  test_stats(angles_quat - angles, 1e-15, 1e-14)
763
891
  test_stats(angles_mat - angles, 1e-15, 1e-14)
764
892
 
@@ -766,14 +894,14 @@ def test_as_euler_asymmetric_axes(seq_tuple, intrinsic):
766
894
 
767
895
  @pytest.mark.parametrize("seq_tuple", permutations("xyz"))
768
896
  @pytest.mark.parametrize("intrinsic", (False, True))
769
- def test_as_euler_symmetric_axes(seq_tuple, intrinsic):
897
+ def test_as_euler_symmetric_axes(xp, seq_tuple, intrinsic):
770
898
  # helper function for mean error tests
771
899
  def test_stats(error, mean_max, rms_max):
772
- mean = np.mean(error, axis=0)
773
- std = np.std(error, axis=0)
774
- rms = np.hypot(mean, std)
775
- assert np.all(np.abs(mean) < mean_max)
776
- assert np.all(rms < rms_max)
900
+ mean = xp.mean(error, axis=0)
901
+ std = xp.std(error, axis=0)
902
+ rms = xp.hypot(mean, std)
903
+ assert xp.all(xp.abs(mean) < mean_max)
904
+ assert xp.all(rms < rms_max)
777
905
 
778
906
  rnd = np.random.RandomState(0)
779
907
  n = 1000
@@ -781,6 +909,7 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic):
781
909
  angles[:, 0] = rnd.uniform(low=-np.pi, high=np.pi, size=(n,))
782
910
  angles[:, 1] = rnd.uniform(low=0, high=np.pi, size=(n,))
783
911
  angles[:, 2] = rnd.uniform(low=-np.pi, high=np.pi, size=(n,))
912
+ angles = xp.asarray(angles)
784
913
 
785
914
  # Rotation of the form A/B/A are rotation around symmetric axes
786
915
  seq = "".join([seq_tuple[0], seq_tuple[1], seq_tuple[0]])
@@ -788,9 +917,10 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic):
788
917
  seq = seq.upper()
789
918
  rotation = Rotation.from_euler(seq, angles)
790
919
  angles_quat = rotation.as_euler(seq)
920
+ # TODO: Same as before: Remove _as_euler_from_matrix?
791
921
  angles_mat = rotation._as_euler_from_matrix(seq)
792
- assert_allclose(angles, angles_quat, atol=0, rtol=1e-13)
793
- assert_allclose(angles, angles_mat, atol=0, rtol=1e-9)
922
+ xp_assert_close(angles, angles_quat, atol=0, rtol=1e-13)
923
+ xp_assert_close(angles, angles_mat, atol=0, rtol=1e-9)
794
924
  test_stats(angles_quat - angles, 1e-16, 1e-14)
795
925
  test_stats(angles_mat - angles, 1e-15, 1e-13)
796
926
 
@@ -798,10 +928,11 @@ def test_as_euler_symmetric_axes(seq_tuple, intrinsic):
798
928
  @pytest.mark.thread_unsafe
799
929
  @pytest.mark.parametrize("seq_tuple", permutations("xyz"))
800
930
  @pytest.mark.parametrize("intrinsic", (False, True))
801
- def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic):
931
+ def test_as_euler_degenerate_asymmetric_axes(xp, seq_tuple, intrinsic):
932
+ atol = 1e-12
802
933
  # Since we cannot check for angle equality, we check for rotation matrix
803
934
  # equality
804
- angles = np.array([
935
+ angles = xp.asarray([
805
936
  [45, 90, 35],
806
937
  [35, -90, 20],
807
938
  [35, 90, 25],
@@ -815,20 +946,23 @@ def test_as_euler_degenerate_asymmetric_axes(seq_tuple, intrinsic):
815
946
  rotation = Rotation.from_euler(seq, angles, degrees=True)
816
947
  mat_expected = rotation.as_matrix()
817
948
 
818
- with pytest.warns(UserWarning, match="Gimbal lock"):
949
+ # We can only warn on non-lazy backends because we'd need to condition on traced
950
+ # booleans
951
+ with eager_warns(mat_expected, UserWarning, match="Gimbal lock"):
819
952
  angle_estimates = rotation.as_euler(seq, degrees=True)
820
953
  mat_estimated = Rotation.from_euler(seq, angle_estimates, degrees=True).as_matrix()
821
954
 
822
- assert_array_almost_equal(mat_expected, mat_estimated)
955
+ xp_assert_close(mat_expected, mat_estimated, atol=atol)
823
956
 
824
957
 
825
958
  @pytest.mark.thread_unsafe
826
959
  @pytest.mark.parametrize("seq_tuple", permutations("xyz"))
827
960
  @pytest.mark.parametrize("intrinsic", (False, True))
828
- def test_as_euler_degenerate_symmetric_axes(seq_tuple, intrinsic):
961
+ def test_as_euler_degenerate_symmetric_axes(xp, seq_tuple, intrinsic):
962
+ atol = 1e-12
829
963
  # Since we cannot check for angle equality, we check for rotation matrix
830
964
  # equality
831
- angles = np.array([
965
+ angles = xp.asarray([
832
966
  [15, 0, 60],
833
967
  [35, 0, 75],
834
968
  [60, 180, 35],
@@ -843,22 +977,23 @@ def test_as_euler_degenerate_symmetric_axes(seq_tuple, intrinsic):
843
977
  rotation = Rotation.from_euler(seq, angles, degrees=True)
844
978
  mat_expected = rotation.as_matrix()
845
979
 
846
- with pytest.warns(UserWarning, match="Gimbal lock"):
980
+ # We can only warn on non-lazy backends
981
+ with eager_warns(mat_expected, UserWarning, match="Gimbal lock"):
847
982
  angle_estimates = rotation.as_euler(seq, degrees=True)
848
983
  mat_estimated = Rotation.from_euler(seq, angle_estimates, degrees=True).as_matrix()
849
984
 
850
- assert_array_almost_equal(mat_expected, mat_estimated)
985
+ xp_assert_close(mat_expected, mat_estimated, atol=atol)
851
986
 
852
987
 
853
988
  @pytest.mark.thread_unsafe
854
989
  @pytest.mark.parametrize("seq_tuple", permutations("xyz"))
855
990
  @pytest.mark.parametrize("intrinsic", (False, True))
856
- def test_as_euler_degenerate_compare_algorithms(seq_tuple, intrinsic):
991
+ def test_as_euler_degenerate_compare_algorithms(xp, seq_tuple, intrinsic):
857
992
  # this test makes sure that both algorithms are doing the same choices
858
993
  # in degenerate cases
859
994
 
860
995
  # asymmetric axes
861
- angles = np.array([
996
+ angles = xp.asarray([
862
997
  [45, 90, 35],
863
998
  [35, -90, 20],
864
999
  [35, 90, 25],
@@ -871,21 +1006,20 @@ def test_as_euler_degenerate_compare_algorithms(seq_tuple, intrinsic):
871
1006
  seq = seq.upper()
872
1007
 
873
1008
  rot = Rotation.from_euler(seq, angles, degrees=True)
874
- with pytest.warns(UserWarning, match="Gimbal lock"):
1009
+ with eager_warns(rot, UserWarning, match="Gimbal lock"):
875
1010
  estimates_matrix = rot._as_euler_from_matrix(seq, degrees=True)
876
- with pytest.warns(UserWarning, match="Gimbal lock"):
877
1011
  estimates_quat = rot.as_euler(seq, degrees=True)
878
- assert_allclose(
1012
+ xp_assert_close(
879
1013
  estimates_matrix[:, [0, 2]], estimates_quat[:, [0, 2]], atol=0, rtol=1e-12
880
1014
  )
881
- assert_allclose(estimates_matrix[:, 1], estimates_quat[:, 1], atol=0, rtol=1e-7)
1015
+ xp_assert_close(estimates_matrix[:, 1], estimates_quat[:, 1], atol=0, rtol=1e-7)
882
1016
 
883
1017
  # symmetric axes
884
1018
  # Absolute error tolerance must be looser to directly compare the results
885
1019
  # from both algorithms, because of numerical loss of precision for the
886
1020
  # method _as_euler_from_matrix near a zero angle value
887
1021
 
888
- angles = np.array([
1022
+ angles = xp.asarray([
889
1023
  [15, 0, 60],
890
1024
  [35, 0, 75],
891
1025
  [60, 180, 35],
@@ -901,45 +1035,49 @@ def test_as_euler_degenerate_compare_algorithms(seq_tuple, intrinsic):
901
1035
  seq = seq.upper()
902
1036
 
903
1037
  rot = Rotation.from_euler(seq, angles, degrees=True)
904
- with pytest.warns(UserWarning, match="Gimbal lock"):
1038
+ with eager_warns(rot, UserWarning, match="Gimbal lock"):
905
1039
  estimates_matrix = rot._as_euler_from_matrix(seq, degrees=True)
906
- with pytest.warns(UserWarning, match="Gimbal lock"):
1040
+ with eager_warns(rot, UserWarning, match="Gimbal lock"):
907
1041
  estimates_quat = rot.as_euler(seq, degrees=True)
908
- assert_allclose(
1042
+ xp_assert_close(
909
1043
  estimates_matrix[:, [0, 2]], estimates_quat[:, [0, 2]], atol=0, rtol=1e-12
910
1044
  )
911
1045
 
912
- assert_allclose(
1046
+ xp_assert_close(
913
1047
  estimates_matrix[~idx, 1], estimates_quat[~idx, 1], atol=0, rtol=1e-7
914
1048
  )
915
1049
 
916
- assert_allclose(
1050
+ xp_assert_close(
917
1051
  estimates_matrix[idx, 1], estimates_quat[idx, 1], atol=1e-6
918
1052
  ) # problematic, angles[1] = 0
919
1053
 
920
1054
 
921
- def test_inv():
1055
+ def test_inv(xp):
1056
+ atol = 1e-12
922
1057
  rnd = np.random.RandomState(0)
923
1058
  n = 10
924
1059
  # preserve use of old random_state during SPEC 7 transition
925
1060
  p = Rotation.random(num=n, random_state=rnd)
1061
+ p = Rotation.from_quat(xp.asarray(p.as_quat()))
926
1062
  q = p.inv()
927
1063
 
928
1064
  p_mat = p.as_matrix()
929
1065
  q_mat = q.as_matrix()
930
- result1 = np.einsum('...ij,...jk->...ik', p_mat, q_mat)
931
- result2 = np.einsum('...ij,...jk->...ik', q_mat, p_mat)
1066
+ result1 = xp.asarray(np.einsum("...ij,...jk->...ik", p_mat, q_mat))
1067
+ result2 = xp.asarray(np.einsum("...ij,...jk->...ik", q_mat, p_mat))
932
1068
 
933
- eye3d = np.empty((n, 3, 3))
934
- eye3d[:] = np.eye(3)
1069
+ eye3d = xp.empty((n, 3, 3))
1070
+ eye3d = xpx.at(eye3d)[..., :3, :3].set(xp.eye(3))
935
1071
 
936
- assert_array_almost_equal(result1, eye3d)
937
- assert_array_almost_equal(result2, eye3d)
1072
+ xp_assert_close(result1, eye3d, atol=atol)
1073
+ xp_assert_close(result2, eye3d, atol=atol)
938
1074
 
939
1075
 
940
- def test_inv_single_rotation():
1076
+ def test_inv_single_rotation(xp):
1077
+ atol = 1e-12
941
1078
  rng = np.random.default_rng(146972845698875399755764481408308808739)
942
1079
  p = Rotation.random(rng=rng)
1080
+ p = Rotation.from_quat(xp.asarray(p.as_quat()))
943
1081
  q = p.inv()
944
1082
 
945
1083
  p_mat = p.as_matrix()
@@ -947,93 +1085,105 @@ def test_inv_single_rotation():
947
1085
  res1 = np.dot(p_mat, q_mat)
948
1086
  res2 = np.dot(q_mat, p_mat)
949
1087
 
950
- eye = np.eye(3)
1088
+ eye = xp.eye(3)
951
1089
 
952
- assert_array_almost_equal(res1, eye)
953
- assert_array_almost_equal(res2, eye)
1090
+ xp_assert_close(res1, eye, atol=atol)
1091
+ xp_assert_close(res2, eye, atol=atol)
954
1092
 
955
1093
  x = Rotation.random(num=1, rng=rng)
1094
+ x = Rotation.from_quat(xp.asarray(x.as_quat()))
956
1095
  y = x.inv()
957
1096
 
958
1097
  x_matrix = x.as_matrix()
959
1098
  y_matrix = y.as_matrix()
960
- result1 = np.einsum('...ij,...jk->...ik', x_matrix, y_matrix)
961
- result2 = np.einsum('...ij,...jk->...ik', y_matrix, x_matrix)
1099
+ result1 = xp.linalg.matmul(x_matrix, y_matrix)
1100
+ result2 = xp.linalg.matmul(y_matrix, x_matrix)
962
1101
 
963
- eye3d = np.empty((1, 3, 3))
964
- eye3d[:] = np.eye(3)
1102
+ eye3d = xp.empty((1, 3, 3))
1103
+ eye3d = xpx.at(eye3d)[..., :3, :3].set(xp.eye(3))
965
1104
 
966
- assert_array_almost_equal(result1, eye3d)
967
- assert_array_almost_equal(result2, eye3d)
1105
+ xp_assert_close(result1, eye3d, atol=atol)
1106
+ xp_assert_close(result2, eye3d, atol=atol)
968
1107
 
969
1108
 
970
- def test_identity_magnitude():
1109
+ def test_identity_magnitude(xp):
971
1110
  n = 10
972
- assert_allclose(Rotation.identity(n).magnitude(), 0)
973
- assert_allclose(Rotation.identity(n).inv().magnitude(), 0)
1111
+ r = Rotation.identity(n)
1112
+ r = Rotation.from_quat(xp.asarray(r.as_quat()))
1113
+ expected = xp.zeros(n)
1114
+ xp_assert_close(r.magnitude(), expected)
1115
+ xp_assert_close(r.inv().magnitude(), expected)
974
1116
 
975
1117
 
976
- def test_single_identity_magnitude():
977
- assert Rotation.identity().magnitude() == 0
978
- assert Rotation.identity().inv().magnitude() == 0
1118
+ def test_single_identity_magnitude(xp):
1119
+ r = Rotation.from_quat(xp.asarray(Rotation.identity().as_quat()))
1120
+ assert r.magnitude() == 0
1121
+ assert r.inv().magnitude() == 0
979
1122
 
980
1123
 
981
- def test_identity_invariance():
1124
+ def test_identity_invariance(xp):
1125
+ atol = 1e-12
982
1126
  n = 10
983
1127
  p = Rotation.random(n, rng=0)
984
-
985
- result = p * Rotation.identity(n)
986
- assert_array_almost_equal(p.as_quat(), result.as_quat())
1128
+ p = Rotation.from_quat(xp.asarray(p.as_quat()))
1129
+ q = Rotation.from_quat(xp.asarray(Rotation.identity(n).as_quat()))
1130
+ result = p * q
1131
+ xp_assert_close(p.as_quat(), result.as_quat())
987
1132
 
988
1133
  result = result * p.inv()
989
- assert_array_almost_equal(result.magnitude(), np.zeros(n))
1134
+ xp_assert_close(result.magnitude(), xp.zeros(n), atol=atol)
990
1135
 
991
1136
 
992
- def test_single_identity_invariance():
1137
+ def test_single_identity_invariance(xp):
1138
+ atol = 1e-12
993
1139
  n = 10
994
1140
  p = Rotation.random(n, rng=0)
1141
+ p = Rotation.from_quat(xp.asarray(p.as_quat()))
995
1142
 
996
- result = p * Rotation.identity()
997
- assert_array_almost_equal(p.as_quat(), result.as_quat())
1143
+ q = Rotation.from_quat(xp.asarray(Rotation.identity().as_quat()))
1144
+ result = p * q
1145
+ xp_assert_close(p.as_quat(), result.as_quat())
998
1146
 
999
1147
  result = result * p.inv()
1000
- assert_array_almost_equal(result.magnitude(), np.zeros(n))
1148
+ xp_assert_close(result.magnitude(), xp.zeros(n), atol=atol)
1001
1149
 
1002
1150
 
1003
- def test_magnitude():
1004
- r = Rotation.from_quat(np.eye(4))
1151
+ def test_magnitude(xp):
1152
+ r = Rotation.from_quat(xp.eye(4))
1005
1153
  result = r.magnitude()
1006
- assert_array_almost_equal(result, [np.pi, np.pi, np.pi, 0])
1154
+ xp_assert_close(result, xp.asarray([xp.pi, xp.pi, xp.pi, 0]))
1007
1155
 
1008
- r = Rotation.from_quat(-np.eye(4))
1156
+ r = Rotation.from_quat(-xp.eye(4))
1009
1157
  result = r.magnitude()
1010
- assert_array_almost_equal(result, [np.pi, np.pi, np.pi, 0])
1158
+ xp_assert_close(result, xp.asarray([xp.pi, xp.pi, xp.pi, 0]))
1011
1159
 
1012
1160
 
1013
- def test_magnitude_single_rotation():
1014
- r = Rotation.from_quat(np.eye(4))
1161
+ def test_magnitude_single_rotation(xp):
1162
+ r = Rotation.from_quat(xp.eye(4))
1015
1163
  result1 = r[0].magnitude()
1016
- assert_allclose(result1, np.pi)
1164
+ xp_assert_close(result1, xp.pi)
1017
1165
 
1018
1166
  result2 = r[3].magnitude()
1019
- assert_allclose(result2, 0)
1167
+ xp_assert_close(result2, 0.0)
1020
1168
 
1021
1169
 
1022
- def test_approx_equal():
1170
+ def test_approx_equal(xp):
1023
1171
  rng = np.random.default_rng(146972845698875399755764481408308808739)
1024
1172
  p = Rotation.random(10, rng=rng)
1025
1173
  q = Rotation.random(10, rng=rng)
1174
+ p = Rotation.from_quat(xp.asarray(p.as_quat()))
1175
+ q = Rotation.from_quat(xp.asarray(q.as_quat()))
1026
1176
  r = p * q.inv()
1027
1177
  r_mag = r.magnitude()
1028
- atol = np.median(r_mag) # ensure we get mix of Trues and Falses
1029
- assert_equal(p.approx_equal(q, atol), (r_mag < atol))
1178
+ atol = xp.asarray(np.median(r_mag)) # ensure we get mix of Trues and Falses
1179
+ xp_assert_equal(p.approx_equal(q, atol), (r_mag < atol))
1030
1180
 
1031
1181
 
1032
1182
  @pytest.mark.thread_unsafe
1033
- def test_approx_equal_single_rotation():
1183
+ def test_approx_equal_single_rotation(xp):
1034
1184
  # also tests passing single argument to approx_equal
1035
- p = Rotation.from_rotvec([0, 0, 1e-9]) # less than default atol of 1e-8
1036
- q = Rotation.from_quat(np.eye(4))
1185
+ p = Rotation.from_rotvec(xp.asarray([0, 0, 1e-9])) # less than default atol of 1e-8
1186
+ q = Rotation.from_quat(xp.eye(4))
1037
1187
  assert p.approx_equal(q[3])
1038
1188
  assert not p.approx_equal(q[0])
1039
1189
 
@@ -1044,40 +1194,47 @@ def test_approx_equal_single_rotation():
1044
1194
  assert p.approx_equal(q[3], degrees=True)
1045
1195
 
1046
1196
 
1047
- def test_mean():
1197
+ def test_mean(xp):
1198
+ axes = xp.concat((-xp.eye(3), xp.eye(3)))
1048
1199
  axes = np.concatenate((-np.eye(3), np.eye(3)))
1049
- thetas = np.linspace(0, np.pi / 2, 100)
1200
+ thetas = xp.linspace(0, xp.pi / 2, 100)
1050
1201
  for t in thetas:
1051
1202
  r = Rotation.from_rotvec(t * axes)
1052
- assert_allclose(r.mean().magnitude(), 0, atol=1E-10)
1203
+ xp_assert_close(r.mean().magnitude(), 0.0, atol=1e-10)
1053
1204
 
1054
1205
 
1055
- def test_weighted_mean():
1206
+ def test_weighted_mean(xp):
1056
1207
  # test that doubling a weight is equivalent to including a rotation twice.
1057
- axes = np.array([[0, 0, 0], [1, 0, 0], [1, 0, 0]])
1058
- thetas = np.linspace(0, np.pi / 2, 100)
1208
+ axes = xp.asarray([[0.0, 0, 0], [1, 0, 0], [1, 0, 0]])
1209
+ thetas = xp.linspace(0, xp.pi / 2, 100)
1059
1210
  for t in thetas:
1060
- rw = Rotation.from_rotvec(t * axes[:2])
1211
+ rw = Rotation.from_rotvec(t * axes[:2, ...])
1061
1212
  mw = rw.mean(weights=[1, 2])
1062
1213
 
1063
1214
  r = Rotation.from_rotvec(t * axes)
1064
1215
  m = r.mean()
1065
- assert_allclose((m * mw.inv()).magnitude(), 0, atol=1E-10)
1216
+ xp_assert_close((m * mw.inv()).magnitude(), 0.0, atol=1e-10)
1066
1217
 
1067
1218
 
1068
- def test_mean_invalid_weights():
1069
- with pytest.raises(ValueError, match="non-negative"):
1070
- r = Rotation.from_quat(np.eye(4))
1071
- r.mean(weights=-np.ones(4))
1219
+ def test_mean_invalid_weights(xp):
1220
+ r = Rotation.from_quat(xp.eye(4))
1221
+ if is_lazy_array(r.as_quat()):
1222
+ m = r.mean(weights=-xp.ones(4))
1223
+ assert all(xp.isnan(m._quat))
1224
+ else:
1225
+ with pytest.raises(ValueError, match="non-negative"):
1226
+ r.mean(weights=-xp.ones(4))
1072
1227
 
1073
1228
 
1074
- def test_reduction_no_indices():
1075
- result = Rotation.identity().reduce(return_indices=False)
1229
+ def test_reduction_no_indices(xp):
1230
+ r = Rotation.from_quat(xp.asarray([0.0, 0.0, 0.0, 1.0]))
1231
+ result = r.reduce(return_indices=False)
1076
1232
  assert isinstance(result, Rotation)
1077
1233
 
1078
1234
 
1079
- def test_reduction_none_indices():
1080
- result = Rotation.identity().reduce(return_indices=True)
1235
+ def test_reduction_none_indices(xp):
1236
+ r = Rotation.from_quat(xp.asarray([0.0, 0.0, 0.0, 1.0]))
1237
+ result = r.reduce(return_indices=True)
1081
1238
  assert type(result) is tuple
1082
1239
  assert len(result) == 3
1083
1240
 
@@ -1086,11 +1243,12 @@ def test_reduction_none_indices():
1086
1243
  assert right_best is None
1087
1244
 
1088
1245
 
1089
- def test_reduction_scalar_calculation():
1246
+ def test_reduction_scalar_calculation(xp):
1247
+ atol = 1e-12
1090
1248
  rng = np.random.default_rng(146972845698875399755764481408308808739)
1091
- l = Rotation.random(5, rng=rng)
1092
- r = Rotation.random(10, rng=rng)
1093
- p = Rotation.random(7, rng=rng)
1249
+ l = Rotation.from_quat(xp.asarray(Rotation.random(5, rng=rng).as_quat()))
1250
+ r = Rotation.from_quat(xp.asarray(Rotation.random(10, rng=rng).as_quat()))
1251
+ p = Rotation.from_quat(xp.asarray(Rotation.random(7, rng=rng).as_quat()))
1094
1252
  reduced, left_best, right_best = p.reduce(l, r, return_indices=True)
1095
1253
 
1096
1254
  # Loop implementation of the vectorized calculation in Rotation.reduce
@@ -1102,66 +1260,66 @@ def test_reduction_scalar_calculation():
1102
1260
  scalars = np.reshape(np.moveaxis(scalars, 1, 0), (scalars.shape[1], -1))
1103
1261
 
1104
1262
  max_ind = np.argmax(np.reshape(scalars, (len(p), -1)), axis=1)
1105
- left_best_check = max_ind // len(r)
1106
- right_best_check = max_ind % len(r)
1107
- assert (left_best == left_best_check).all()
1108
- assert (right_best == right_best_check).all()
1263
+ left_best_check = xp.asarray(max_ind // len(r))
1264
+ right_best_check = xp.asarray(max_ind % len(r))
1265
+ assert xp.all(left_best == left_best_check)
1266
+ assert xp.all(right_best == right_best_check)
1109
1267
 
1110
1268
  reduced_check = l[left_best_check] * p * r[right_best_check]
1111
1269
  mag = (reduced.inv() * reduced_check).magnitude()
1112
- assert_array_almost_equal(mag, np.zeros(len(p)))
1270
+ xp_assert_close(mag, xp.zeros(len(p)), atol=atol)
1113
1271
 
1114
1272
 
1115
- def test_apply_single_rotation_single_point():
1116
- mat = np.array([
1273
+ def test_apply_single_rotation_single_point(xp):
1274
+ mat = xp.asarray([
1117
1275
  [0, -1, 0],
1118
1276
  [1, 0, 0],
1119
1277
  [0, 0, 1]
1120
1278
  ])
1121
1279
  r_1d = Rotation.from_matrix(mat)
1122
- r_2d = Rotation.from_matrix(np.expand_dims(mat, axis=0))
1280
+ r_2d = Rotation.from_matrix(xp.expand_dims(mat, axis=0))
1123
1281
 
1124
- v_1d = np.array([1, 2, 3])
1125
- v_2d = np.expand_dims(v_1d, axis=0)
1126
- v1d_rotated = np.array([-2, 1, 3])
1127
- v2d_rotated = np.expand_dims(v1d_rotated, axis=0)
1282
+ v_1d = xp.asarray([1.0, 2, 3])
1283
+ v_2d = xp.expand_dims(v_1d, axis=0)
1284
+ v1d_rotated = xp.asarray([-2.0, 1, 3])
1285
+ v2d_rotated = xp.expand_dims(v1d_rotated, axis=0)
1128
1286
 
1129
- assert_allclose(r_1d.apply(v_1d), v1d_rotated)
1130
- assert_allclose(r_1d.apply(v_2d), v2d_rotated)
1131
- assert_allclose(r_2d.apply(v_1d), v2d_rotated)
1132
- assert_allclose(r_2d.apply(v_2d), v2d_rotated)
1287
+ xp_assert_close(r_1d.apply(v_1d), v1d_rotated)
1288
+ xp_assert_close(r_1d.apply(v_2d), v2d_rotated)
1289
+ xp_assert_close(r_2d.apply(v_1d), v2d_rotated)
1290
+ xp_assert_close(r_2d.apply(v_2d), v2d_rotated)
1133
1291
 
1134
- v1d_inverse = np.array([2, -1, 3])
1135
- v2d_inverse = np.expand_dims(v1d_inverse, axis=0)
1292
+ v1d_inverse = xp.asarray([2.0, -1, 3])
1293
+ v2d_inverse = xp.expand_dims(v1d_inverse, axis=0)
1136
1294
 
1137
- assert_allclose(r_1d.apply(v_1d, inverse=True), v1d_inverse)
1138
- assert_allclose(r_1d.apply(v_2d, inverse=True), v2d_inverse)
1139
- assert_allclose(r_2d.apply(v_1d, inverse=True), v2d_inverse)
1140
- assert_allclose(r_2d.apply(v_2d, inverse=True), v2d_inverse)
1295
+ xp_assert_close(r_1d.apply(v_1d, inverse=True), v1d_inverse)
1296
+ xp_assert_close(r_1d.apply(v_2d, inverse=True), v2d_inverse)
1297
+ xp_assert_close(r_2d.apply(v_1d, inverse=True), v2d_inverse)
1298
+ xp_assert_close(r_2d.apply(v_2d, inverse=True), v2d_inverse)
1141
1299
 
1142
1300
 
1143
- def test_apply_single_rotation_multiple_points():
1144
- mat = np.array([
1301
+ def test_apply_single_rotation_multiple_points(xp):
1302
+ mat = xp.asarray([
1145
1303
  [0, -1, 0],
1146
1304
  [1, 0, 0],
1147
1305
  [0, 0, 1]
1148
1306
  ])
1149
1307
  r1 = Rotation.from_matrix(mat)
1150
- r2 = Rotation.from_matrix(np.expand_dims(mat, axis=0))
1308
+ r2 = Rotation.from_matrix(xp.expand_dims(mat, axis=0))
1151
1309
 
1152
- v = np.array([[1, 2, 3], [4, 5, 6]])
1153
- v_rotated = np.array([[-2, 1, 3], [-5, 4, 6]])
1310
+ v = xp.asarray([[1, 2, 3], [4, 5, 6]])
1311
+ v_rotated = xp.asarray([[-2.0, 1, 3], [-5, 4, 6]])
1154
1312
 
1155
- assert_allclose(r1.apply(v), v_rotated)
1156
- assert_allclose(r2.apply(v), v_rotated)
1313
+ xp_assert_close(r1.apply(v), v_rotated)
1314
+ xp_assert_close(r2.apply(v), v_rotated)
1157
1315
 
1158
- v_inverse = np.array([[2, -1, 3], [5, -4, 6]])
1316
+ v_inverse = xp.asarray([[2.0, -1, 3], [5, -4, 6]])
1159
1317
 
1160
- assert_allclose(r1.apply(v, inverse=True), v_inverse)
1161
- assert_allclose(r2.apply(v, inverse=True), v_inverse)
1318
+ xp_assert_close(r1.apply(v, inverse=True), v_inverse)
1319
+ xp_assert_close(r2.apply(v, inverse=True), v_inverse)
1162
1320
 
1163
1321
 
1164
- def test_apply_multiple_rotations_single_point():
1322
+ def test_apply_multiple_rotations_single_point(xp):
1165
1323
  mat = np.empty((2, 3, 3))
1166
1324
  mat[0] = np.array([
1167
1325
  [0, -1, 0],
@@ -1173,23 +1331,24 @@ def test_apply_multiple_rotations_single_point():
1173
1331
  [0, 0, -1],
1174
1332
  [0, 1, 0]
1175
1333
  ])
1334
+ mat = xp.asarray(mat)
1176
1335
  r = Rotation.from_matrix(mat)
1177
1336
 
1178
- v1 = np.array([1, 2, 3])
1179
- v2 = np.expand_dims(v1, axis=0)
1337
+ v1 = xp.asarray([1, 2, 3])
1338
+ v2 = xp.expand_dims(v1, axis=0)
1180
1339
 
1181
- v_rotated = np.array([[-2, 1, 3], [1, -3, 2]])
1340
+ v_rotated = xp.asarray([[-2.0, 1, 3], [1, -3, 2]])
1182
1341
 
1183
- assert_allclose(r.apply(v1), v_rotated)
1184
- assert_allclose(r.apply(v2), v_rotated)
1342
+ xp_assert_close(r.apply(v1), v_rotated)
1343
+ xp_assert_close(r.apply(v2), v_rotated)
1185
1344
 
1186
- v_inverse = np.array([[2, -1, 3], [1, 3, -2]])
1345
+ v_inverse = xp.asarray([[2.0, -1, 3], [1, 3, -2]])
1187
1346
 
1188
- assert_allclose(r.apply(v1, inverse=True), v_inverse)
1189
- assert_allclose(r.apply(v2, inverse=True), v_inverse)
1347
+ xp_assert_close(r.apply(v1, inverse=True), v_inverse)
1348
+ xp_assert_close(r.apply(v2, inverse=True), v_inverse)
1190
1349
 
1191
1350
 
1192
- def test_apply_multiple_rotations_multiple_points():
1351
+ def test_apply_multiple_rotations_multiple_points(xp):
1193
1352
  mat = np.empty((2, 3, 3))
1194
1353
  mat[0] = np.array([
1195
1354
  [0, -1, 0],
@@ -1201,17 +1360,52 @@ def test_apply_multiple_rotations_multiple_points():
1201
1360
  [0, 0, -1],
1202
1361
  [0, 1, 0]
1203
1362
  ])
1363
+ mat = xp.asarray(mat)
1204
1364
  r = Rotation.from_matrix(mat)
1205
1365
 
1206
- v = np.array([[1, 2, 3], [4, 5, 6]])
1207
- v_rotated = np.array([[-2, 1, 3], [4, -6, 5]])
1208
- assert_allclose(r.apply(v), v_rotated)
1209
-
1210
- v_inverse = np.array([[2, -1, 3], [4, 6, -5]])
1211
- assert_allclose(r.apply(v, inverse=True), v_inverse)
1212
-
1213
-
1214
- def test_getitem():
1366
+ v = xp.asarray([[1, 2, 3], [4, 5, 6]])
1367
+ v_rotated = xp.asarray([[-2.0, 1, 3], [4, -6, 5]])
1368
+ xp_assert_close(r.apply(v), v_rotated)
1369
+
1370
+ v_inverse = xp.asarray([[2.0, -1, 3], [4, 6, -5]])
1371
+ xp_assert_close(r.apply(v, inverse=True), v_inverse)
1372
+
1373
+
1374
+ def test_apply_shapes(xp):
1375
+ vector0 = xp.asarray([1.0, 2.0, 3.0])
1376
+ vector1 = xp.asarray([vector0])
1377
+ vector2 = xp.asarray([vector0, vector0])
1378
+ matrix0 = xp.eye(3)
1379
+ matrix1 = xp.asarray([matrix0])
1380
+ matrix2 = xp.asarray([matrix0, matrix0])
1381
+
1382
+ for m, v in product([matrix0, matrix1, matrix2], [vector0, vector1, vector2]):
1383
+ r = Rotation.from_matrix(m)
1384
+ shape = v.shape
1385
+ if not r.single and (v.shape == (3,) or v.shape == (1, 3)):
1386
+ shape = (len(r), 3)
1387
+ x = r.apply(v)
1388
+ assert x.shape == shape
1389
+ x = r.apply(v, inverse=True)
1390
+ assert x.shape == shape
1391
+
1392
+
1393
+ def test_apply_array_like():
1394
+ rng = np.random.default_rng(123)
1395
+ # Single vector
1396
+ r = Rotation.random(rng=rng)
1397
+ t = rng.uniform(-100, 100, size=(3,))
1398
+ v = r.apply(t.tolist())
1399
+ v_expected = r.apply(t)
1400
+ xp_assert_close(v, v_expected, atol=1e-12)
1401
+ # Multiple vectors
1402
+ t = rng.uniform(-100, 100, size=(2, 3))
1403
+ v = r.apply(t.tolist())
1404
+ v_expected = r.apply(t)
1405
+ xp_assert_close(v, v_expected, atol=1e-12)
1406
+
1407
+
1408
+ def test_getitem(xp):
1215
1409
  mat = np.empty((2, 3, 3))
1216
1410
  mat[0] = np.array([
1217
1411
  [0, -1, 0],
@@ -1223,47 +1417,60 @@ def test_getitem():
1223
1417
  [0, 0, -1],
1224
1418
  [0, 1, 0]
1225
1419
  ])
1420
+ mat = xp.asarray(mat)
1226
1421
  r = Rotation.from_matrix(mat)
1227
1422
 
1228
- assert_allclose(r[0].as_matrix(), mat[0], atol=1e-15)
1229
- assert_allclose(r[1].as_matrix(), mat[1], atol=1e-15)
1230
- assert_allclose(r[:-1].as_matrix(), np.expand_dims(mat[0], axis=0), atol=1e-15)
1423
+ xp_assert_close(r[0].as_matrix(), mat[0], atol=1e-15)
1424
+ xp_assert_close(r[1].as_matrix(), mat[1, ...], atol=1e-15)
1425
+ xp_assert_close(r[:-1].as_matrix(), xp.expand_dims(mat[0, ...], axis=0), atol=1e-15)
1231
1426
 
1232
1427
 
1233
- def test_getitem_single():
1428
+ def test_getitem_single(xp):
1234
1429
  with pytest.raises(TypeError, match='not subscriptable'):
1235
- Rotation.identity()[0]
1430
+ Rotation.from_quat(xp.asarray([0, 0, 0, 1]))[0]
1431
+
1432
+
1433
+ def test_getitem_array_like():
1434
+ mat = np.array([[[0.0, -1, 0],
1435
+ [1, 0, 0],
1436
+ [0, 0, 1]],
1437
+ [[1, 0, 0],
1438
+ [0, 0, -1],
1439
+ [0, 1, 0]]])
1440
+ r = Rotation.from_matrix(mat)
1441
+ xp_assert_close(r[[0]].as_matrix(), mat[[0]], atol=1e-15)
1442
+ xp_assert_close(r[[0, 1]].as_matrix(), mat[[0, 1]], atol=1e-15)
1236
1443
 
1237
1444
 
1238
- def test_setitem_single():
1239
- r = Rotation.identity()
1445
+ def test_setitem_single(xp):
1446
+ r = Rotation.from_quat(xp.asarray([0, 0, 0, 1]))
1240
1447
  with pytest.raises(TypeError, match='not subscriptable'):
1241
- r[0] = Rotation.identity()
1448
+ r[0] = Rotation.from_quat(xp.asarray([0, 0, 0, 1]))
1242
1449
 
1243
1450
 
1244
- def test_setitem_slice():
1451
+ def test_setitem_slice(xp):
1245
1452
  rng = np.random.default_rng(146972845698875399755764481408308808739)
1246
- r1 = Rotation.random(10, rng=rng)
1247
- r2 = Rotation.random(5, rng=rng)
1453
+ r1 = Rotation.from_quat(xp.asarray(Rotation.random(10, rng=rng).as_quat()))
1454
+ r2 = Rotation.from_quat(xp.asarray(Rotation.random(5, rng=rng).as_quat()))
1248
1455
  r1[1:6] = r2
1249
- assert_equal(r1[1:6].as_quat(), r2.as_quat())
1456
+ xp_assert_equal(r1[1:6].as_quat(), r2.as_quat())
1250
1457
 
1251
1458
 
1252
- def test_setitem_integer():
1459
+ def test_setitem_integer(xp):
1253
1460
  rng = np.random.default_rng(146972845698875399755764481408308808739)
1254
- r1 = Rotation.random(10, rng=rng)
1255
- r2 = Rotation.random(rng=rng)
1461
+ r1 = Rotation.from_quat(xp.asarray(Rotation.random(10, rng=rng).as_quat()))
1462
+ r2 = Rotation.from_quat(xp.asarray(Rotation.random(rng=rng).as_quat()))
1256
1463
  r1[1] = r2
1257
- assert_equal(r1[1].as_quat(), r2.as_quat())
1464
+ xp_assert_equal(r1[1].as_quat(), r2.as_quat())
1258
1465
 
1259
1466
 
1260
- def test_setitem_wrong_type():
1261
- r = Rotation.random(10, rng=0)
1467
+ def test_setitem_wrong_type(xp):
1468
+ r = Rotation.from_quat(xp.asarray(Rotation.random(10, rng=0).as_quat()))
1262
1469
  with pytest.raises(TypeError, match='Rotation object'):
1263
1470
  r[0] = 1
1264
1471
 
1265
1472
 
1266
- def test_n_rotations():
1473
+ def test_n_rotations(xp):
1267
1474
  mat = np.empty((2, 3, 3))
1268
1475
  mat[0] = np.array([
1269
1476
  [0, -1, 0],
@@ -1275,6 +1482,7 @@ def test_n_rotations():
1275
1482
  [0, 0, -1],
1276
1483
  [0, 1, 0]
1277
1484
  ])
1485
+ mat = xp.asarray(mat)
1278
1486
  r = Rotation.from_matrix(mat)
1279
1487
 
1280
1488
  assert_equal(len(r), 2)
@@ -1282,6 +1490,7 @@ def test_n_rotations():
1282
1490
 
1283
1491
 
1284
1492
  def test_random_rotation_shape():
1493
+ # No xp testing since random rotations are always using NumPy
1285
1494
  rng = np.random.default_rng(146972845698875399755764481408308808739)
1286
1495
  assert_equal(Rotation.random(rng=rng).as_quat().shape, (4,))
1287
1496
  assert_equal(Rotation.random(None, rng=rng).as_quat().shape, (4,))
@@ -1290,80 +1499,80 @@ def test_random_rotation_shape():
1290
1499
  assert_equal(Rotation.random(5, rng=rng).as_quat().shape, (5, 4))
1291
1500
 
1292
1501
 
1293
- def test_align_vectors_no_rotation():
1294
- x = np.array([[1, 2, 3], [4, 5, 6]])
1295
- y = x.copy()
1502
+ def test_align_vectors_no_rotation(xp):
1503
+ x = xp.asarray([[1, 2, 3], [4, 5, 6]])
1504
+ y = xp.asarray(x, copy=True)
1296
1505
 
1297
1506
  r, rssd = Rotation.align_vectors(x, y)
1298
- assert_array_almost_equal(r.as_matrix(), np.eye(3))
1299
- assert_allclose(rssd, 0, atol=1e-6)
1507
+ xp_assert_close(r.as_matrix(), xp.eye(3), atol=1e-12)
1508
+ xp_assert_close(rssd, xp.asarray(0.0)[()], check_shape=False, atol=1e-6)
1300
1509
 
1301
1510
 
1302
- def test_align_vectors_no_noise():
1511
+ def test_align_vectors_no_noise(xp):
1303
1512
  rng = np.random.default_rng(14697284569885399755764481408308808739)
1304
- c = Rotation.random(rng=rng)
1305
- b = rng.normal(size=(5, 3))
1513
+ c = Rotation.from_quat(xp.asarray(Rotation.random(rng=rng).as_quat()))
1514
+ b = xp.asarray(rng.normal(size=(5, 3)))
1306
1515
  a = c.apply(b)
1307
1516
 
1308
1517
  est, rssd = Rotation.align_vectors(a, b)
1309
- assert_allclose(c.as_quat(), est.as_quat())
1310
- assert_allclose(rssd, 0, atol=1e-7)
1518
+ xp_assert_close(c.as_quat(), est.as_quat())
1519
+ xp_assert_close(rssd, xp.asarray(0.0)[()], check_shape=False, atol=1e-7)
1311
1520
 
1312
1521
 
1313
- def test_align_vectors_improper_rotation():
1522
+ def test_align_vectors_improper_rotation(xp):
1314
1523
  # Tests correct logic for issue #10444
1315
- x = np.array([[0.89299824, -0.44372674, 0.0752378],
1316
- [0.60221789, -0.47564102, -0.6411702]])
1317
- y = np.array([[0.02386536, -0.82176463, 0.5693271],
1318
- [-0.27654929, -0.95191427, -0.1318321]])
1524
+ x = xp.asarray([[0.89299824, -0.44372674, 0.0752378],
1525
+ [0.60221789, -0.47564102, -0.6411702]])
1526
+ y = xp.asarray([[0.02386536, -0.82176463, 0.5693271],
1527
+ [-0.27654929, -0.95191427, -0.1318321]])
1319
1528
 
1320
1529
  est, rssd = Rotation.align_vectors(x, y)
1321
- assert_allclose(x, est.apply(y), atol=1e-6)
1322
- assert_allclose(rssd, 0, atol=1e-7)
1530
+ xp_assert_close(x, est.apply(y), atol=1e-6)
1531
+ xp_assert_close(rssd, xp.asarray(0.0)[()], check_shape=False, atol=1e-7)
1323
1532
 
1324
1533
 
1325
- def test_align_vectors_rssd_sensitivity():
1326
- rssd_expected = 0.141421356237308
1327
- sens_expected = np.array([[0.2, 0. , 0.],
1328
- [0. , 1.5, 1.],
1329
- [0. , 1. , 1.]])
1534
+ def test_align_vectors_rssd_sensitivity(xp):
1535
+ rssd_expected = xp.asarray(0.141421356237308)[()]
1536
+ sens_expected = xp.asarray([[0.2, 0. , 0.],
1537
+ [0. , 1.5, 1.],
1538
+ [0. , 1. , 1.]])
1330
1539
  atol = 1e-6
1331
- a = [[0, 1, 0], [0, 1, 1], [0, 1, 1]]
1332
- b = [[1, 0, 0], [1, 1.1, 0], [1, 0.9, 0]]
1540
+ a = xp.asarray([[0, 1, 0], [0, 1, 1], [0, 1, 1]])
1541
+ b = xp.asarray([[1, 0, 0], [1, 1.1, 0], [1, 0.9, 0]])
1333
1542
  rot, rssd, sens = Rotation.align_vectors(a, b, return_sensitivity=True)
1334
- assert np.isclose(rssd, rssd_expected, atol=atol)
1335
- assert np.allclose(sens, sens_expected, atol=atol)
1543
+ xp_assert_close(rssd, rssd_expected, atol=atol)
1544
+ xp_assert_close(sens, sens_expected, atol=atol)
1336
1545
 
1337
1546
 
1338
- def test_align_vectors_scaled_weights():
1547
+ def test_align_vectors_scaled_weights(xp):
1339
1548
  n = 10
1340
- a = Rotation.random(n, rng=0).apply([1, 0, 0])
1341
- b = Rotation.random(n, rng=1).apply([1, 0, 0])
1549
+ a = xp.asarray(Rotation.random(n, rng=0).apply([1, 0, 0]))
1550
+ b = xp.asarray(Rotation.random(n, rng=1).apply([1, 0, 0]))
1342
1551
  scale = 2
1343
1552
 
1344
- est1, rssd1, cov1 = Rotation.align_vectors(a, b, np.ones(n), True)
1345
- est2, rssd2, cov2 = Rotation.align_vectors(a, b, scale * np.ones(n), True)
1553
+ est1, rssd1, cov1 = Rotation.align_vectors(a, b, xp.ones(n), True)
1554
+ est2, rssd2, cov2 = Rotation.align_vectors(a, b, scale * xp.ones(n), True)
1346
1555
 
1347
- assert_allclose(est1.as_matrix(), est2.as_matrix())
1348
- assert_allclose(np.sqrt(scale) * rssd1, rssd2, atol=1e-6)
1349
- assert_allclose(cov1, cov2)
1556
+ xp_assert_close(est1.as_matrix(), est2.as_matrix())
1557
+ xp_assert_close(math.sqrt(scale) * rssd1, rssd2, atol=1e-6)
1558
+ xp_assert_close(cov1, cov2)
1350
1559
 
1351
1560
 
1352
- def test_align_vectors_noise():
1561
+ def test_align_vectors_noise(xp):
1353
1562
  rng = np.random.default_rng(146972845698875399755764481408308808739)
1354
1563
  n_vectors = 100
1355
- rot = Rotation.random(rng=rng)
1356
- vectors = rng.normal(size=(n_vectors, 3))
1564
+ rot = rotation_to_xp(Rotation.random(rng=rng), xp)
1565
+ vectors = xp.asarray(rng.normal(size=(n_vectors, 3)))
1357
1566
  result = rot.apply(vectors)
1358
1567
 
1359
1568
  # The paper adds noise as independently distributed angular errors
1360
1569
  sigma = np.deg2rad(1)
1361
1570
  tolerance = 1.5 * sigma
1362
1571
  noise = Rotation.from_rotvec(
1363
- rng.normal(
1572
+ xp.asarray(rng.normal(
1364
1573
  size=(n_vectors, 3),
1365
1574
  scale=sigma
1366
- )
1575
+ ))
1367
1576
  )
1368
1577
 
1369
1578
  # Attitude errors must preserve norm. Hence apply individual random
@@ -1375,99 +1584,134 @@ def test_align_vectors_noise():
1375
1584
 
1376
1585
  # Use rotation compositions to find out closeness
1377
1586
  error_vector = (rot * est.inv()).as_rotvec()
1378
- assert_allclose(error_vector[0], 0, atol=tolerance)
1379
- assert_allclose(error_vector[1], 0, atol=tolerance)
1380
- assert_allclose(error_vector[2], 0, atol=tolerance)
1587
+ xp_assert_close(error_vector[0], xp.asarray(0.0)[()], atol=tolerance)
1588
+ xp_assert_close(error_vector[1], xp.asarray(0.0)[()], atol=tolerance)
1589
+ xp_assert_close(error_vector[2], xp.asarray(0.0)[()], atol=tolerance)
1381
1590
 
1382
1591
  # Check error bounds using covariance matrix
1383
1592
  cov *= sigma
1384
- assert_allclose(cov[0, 0], 0, atol=tolerance)
1385
- assert_allclose(cov[1, 1], 0, atol=tolerance)
1386
- assert_allclose(cov[2, 2], 0, atol=tolerance)
1593
+ xp_assert_close(cov[0, 0], xp.asarray(0.0)[()], atol=tolerance)
1594
+ xp_assert_close(cov[1, 1], xp.asarray(0.0)[()], atol=tolerance)
1595
+ xp_assert_close(cov[2, 2], xp.asarray(0.0)[()], atol=tolerance)
1387
1596
 
1388
- assert_allclose(rssd, np.sum((noisy_result - est.apply(vectors))**2)**0.5)
1597
+ rssd_check = xp.sum((noisy_result - est.apply(vectors)) ** 2) ** 0.5
1598
+ xp_assert_close(rssd, rssd_check, check_shape=False)
1389
1599
 
1390
1600
 
1391
- def test_align_vectors_invalid_input():
1601
+ def test_align_vectors_invalid_input(xp):
1392
1602
  with pytest.raises(ValueError, match="Expected input `a` to have shape"):
1393
- Rotation.align_vectors([1, 2, 3, 4], [1, 2, 3])
1603
+ a, b = xp.asarray([1, 2, 3, 4]), xp.asarray([1, 2, 3])
1604
+ Rotation.align_vectors(a, b)
1394
1605
 
1395
1606
  with pytest.raises(ValueError, match="Expected input `b` to have shape"):
1396
- Rotation.align_vectors([1, 2, 3], [1, 2, 3, 4])
1607
+ a, b = xp.asarray([1, 2, 3]), xp.asarray([1, 2, 3, 4])
1608
+ Rotation.align_vectors(a, b)
1397
1609
 
1398
1610
  with pytest.raises(ValueError, match="Expected inputs `a` and `b` "
1399
1611
  "to have same shapes"):
1400
- Rotation.align_vectors([[1, 2, 3],[4, 5, 6]], [[1, 2, 3]])
1612
+ a, b = xp.asarray([[1, 2, 3], [4, 5, 6]]), xp.asarray([[1, 2, 3]])
1613
+ Rotation.align_vectors(a, b)
1401
1614
 
1402
1615
  with pytest.raises(ValueError,
1403
1616
  match="Expected `weights` to be 1 dimensional"):
1404
- Rotation.align_vectors([[1, 2, 3]], [[1, 2, 3]], weights=[[1]])
1617
+ a, b = xp.asarray([[1, 2, 3]]), xp.asarray([[1, 2, 3]])
1618
+ weights = xp.asarray([[1]])
1619
+ Rotation.align_vectors(a, b, weights)
1405
1620
 
1406
1621
  with pytest.raises(ValueError,
1407
1622
  match="Expected `weights` to have number of values"):
1408
- Rotation.align_vectors([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1409
- weights=[1, 2, 3])
1410
-
1411
- with pytest.raises(ValueError,
1412
- match="`weights` may not contain negative values"):
1413
- Rotation.align_vectors([[1, 2, 3]], [[1, 2, 3]], weights=[-1])
1414
-
1415
- with pytest.raises(ValueError,
1416
- match="Only one infinite weight is allowed"):
1417
- Rotation.align_vectors([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1418
- weights=[np.inf, np.inf])
1419
-
1420
- with pytest.raises(ValueError,
1421
- match="Cannot align zero length primary vectors"):
1422
- Rotation.align_vectors([[0, 0, 0]], [[1, 2, 3]])
1423
-
1424
- with pytest.raises(ValueError,
1425
- match="Cannot return sensitivity matrix"):
1426
- Rotation.align_vectors([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1427
- return_sensitivity=True, weights=[np.inf, 1])
1428
-
1429
- with pytest.raises(ValueError,
1430
- match="Cannot return sensitivity matrix"):
1431
- Rotation.align_vectors([[1, 2, 3]], [[1, 2, 3]],
1432
- return_sensitivity=True)
1433
-
1434
-
1435
- def test_align_vectors_align_constrain():
1623
+ a, b = xp.asarray([[1, 2, 3], [4, 5, 6]]), xp.asarray([[1, 2, 3], [4, 5, 6]])
1624
+ weights = xp.asarray([1, 2, 3])
1625
+ Rotation.align_vectors(a, b, weights)
1626
+
1627
+ a, b = xp.asarray([[1, 2, 3]]), xp.asarray([[1, 2, 3]])
1628
+ weights = xp.asarray([-1])
1629
+ if is_lazy_array(weights):
1630
+ r, rssd = Rotation.align_vectors(a, b, weights)
1631
+ assert xp.all(xp.isnan(r.as_quat())), "Quaternion should be nan"
1632
+ assert xp.isnan(rssd), "RSSD should be nan"
1633
+ else:
1634
+ with pytest.raises(ValueError,
1635
+ match="`weights` may not contain negative values"):
1636
+ Rotation.align_vectors(a, b, weights)
1637
+
1638
+ a, b = xp.asarray([[1, 2, 3], [4, 5, 6]]), xp.asarray([[1, 2, 3], [4, 5, 6]])
1639
+ weights = xp.asarray([xp.inf, xp.inf])
1640
+ if is_lazy_array(weights):
1641
+ r, rssd = Rotation.align_vectors(a, b, weights)
1642
+ assert xp.all(xp.isnan(r.as_quat())), "Quaternion should be nan"
1643
+ assert xp.isnan(rssd), "RSSD should be nan"
1644
+ else:
1645
+ with pytest.raises(ValueError,
1646
+ match="Only one infinite weight is allowed"):
1647
+ Rotation.align_vectors(a, b, weights)
1648
+
1649
+ a, b = xp.asarray([[0, 0, 0]]), xp.asarray([[1, 2, 3]])
1650
+ if is_lazy_array(a):
1651
+ r, rssd = Rotation.align_vectors(a, b)
1652
+ assert xp.all(xp.isnan(r.as_quat())), "Quaternion should be nan"
1653
+ assert xp.isnan(rssd), "RSSD should be nan"
1654
+ else:
1655
+ with pytest.raises(ValueError,
1656
+ match="Cannot align zero length primary vectors"):
1657
+ Rotation.align_vectors(a, b)
1658
+
1659
+ a, b = xp.asarray([[1, 2, 3], [4, 5, 6]]), xp.asarray([[1, 2, 3], [4, 5, 6]])
1660
+ weights = xp.asarray([xp.inf, 1])
1661
+ if is_lazy_array(a):
1662
+ r, rssd, sens = Rotation.align_vectors(a, b, weights, return_sensitivity=True)
1663
+ assert xp.all(xp.isnan(sens)), "Sensitivity matrix should be nan"
1664
+ else:
1665
+ with pytest.raises(ValueError,
1666
+ match="Cannot return sensitivity matrix"):
1667
+ Rotation.align_vectors(a, b, weights, return_sensitivity=True)
1668
+
1669
+ a, b = xp.asarray([[1, 2, 3]]), xp.asarray([[1, 2, 3]])
1670
+ if is_lazy_array(a):
1671
+ r, rssd, sens = Rotation.align_vectors(a, b, return_sensitivity=True)
1672
+ assert xp.all(xp.isnan(sens)), "Sensitivity matrix should be nan"
1673
+ else:
1674
+ with pytest.raises(ValueError,
1675
+ match="Cannot return sensitivity matrix"):
1676
+ Rotation.align_vectors(a, b, return_sensitivity=True)
1677
+
1678
+
1679
+ def test_align_vectors_align_constrain(xp):
1436
1680
  # Align the primary +X B axis with the primary +Y A axis, and rotate about
1437
1681
  # it such that the +Y B axis (residual of the [1, 1, 0] secondary b vector)
1438
1682
  # is aligned with the +Z A axis (residual of the [0, 1, 1] secondary a
1439
1683
  # vector)
1440
1684
  atol = 1e-12
1441
- b = [[1, 0, 0], [1, 1, 0]]
1442
- a = [[0, 1, 0], [0, 1, 1]]
1443
- m_expected = np.array([[0, 0, 1],
1444
- [1, 0, 0],
1445
- [0, 1, 0]])
1446
- R, rssd = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1447
- assert_allclose(R.as_matrix(), m_expected, atol=atol)
1448
- assert_allclose(R.apply(b), a, atol=atol) # Pri and sec align exactly
1449
- assert np.isclose(rssd, 0, atol=atol)
1685
+ b = xp.asarray([[1, 0, 0], [1, 1, 0]])
1686
+ a = xp.asarray([[0.0, 1, 0], [0, 1, 1]])
1687
+ m_expected = xp.asarray([[0.0, 0, 1],
1688
+ [1, 0, 0],
1689
+ [0, 1, 0]])
1690
+ R, rssd = Rotation.align_vectors(a, b, weights=xp.asarray([xp.inf, 1]))
1691
+ xp_assert_close(R.as_matrix(), m_expected, atol=atol)
1692
+ xp_assert_close(R.apply(b), a, atol=atol) # Pri and sec align exactly
1693
+ assert xpx.isclose(rssd, 0.0, atol=atol, xp=xp)
1450
1694
 
1451
1695
  # Do the same but with an inexact secondary rotation
1452
- b = [[1, 0, 0], [1, 2, 0]]
1696
+ b = xp.asarray([[1, 0, 0], [1, 2, 0]])
1453
1697
  rssd_expected = 1.0
1454
- R, rssd = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1455
- assert_allclose(R.as_matrix(), m_expected, atol=atol)
1456
- assert_allclose(R.apply(b)[0], a[0], atol=atol) # Only pri aligns exactly
1457
- assert np.isclose(rssd, rssd_expected, atol=atol)
1458
- a_expected = [[0, 1, 0], [0, 1, 2]]
1459
- assert_allclose(R.apply(b), a_expected, atol=atol)
1698
+ R, rssd = Rotation.align_vectors(a, b, weights=xp.asarray([xp.inf, 1]))
1699
+ xp_assert_close(R.as_matrix(), m_expected, atol=atol)
1700
+ xp_assert_close(R.apply(b)[0, ...], a[0, ...], atol=atol) # Only pri aligns exactly
1701
+ assert xpx.isclose(rssd, rssd_expected, atol=atol, xp=xp)
1702
+ a_expected = xp.asarray([[0.0, 1, 0], [0, 1, 2]])
1703
+ xp_assert_close(R.apply(b), a_expected, atol=atol)
1460
1704
 
1461
1705
  # Check random vectors
1462
- b = [[1, 2, 3], [-2, 3, -1]]
1463
- a = [[-1, 3, 2], [1, -1, 2]]
1706
+ b = xp.asarray([[1, 2, 3], [-2, 3, -1]])
1707
+ a = xp.asarray([[-1.0, 3, 2], [1, -1, 2]])
1464
1708
  rssd_expected = 1.3101595297515016
1465
- R, rssd = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1466
- assert_allclose(R.apply(b)[0], a[0], atol=atol) # Only pri aligns exactly
1467
- assert np.isclose(rssd, rssd_expected, atol=atol)
1709
+ R, rssd = Rotation.align_vectors(a, b, weights=xp.asarray([xp.inf, 1]))
1710
+ xp_assert_close(R.apply(b)[0, ...], a[0, ...], atol=atol) # Only pri aligns exactly
1711
+ assert xpx.isclose(rssd, rssd_expected, atol=atol, xp=xp)
1468
1712
 
1469
1713
 
1470
- def test_align_vectors_near_inf():
1714
+ def test_align_vectors_near_inf(xp):
1471
1715
  # align_vectors should return near the same result for high weights as for
1472
1716
  # infinite weights. rssd will be different with floating point error on the
1473
1717
  # exactly aligned vector being multiplied by a large non-infinite weight
@@ -1478,58 +1722,60 @@ def test_align_vectors_near_inf():
1478
1722
 
1479
1723
  for i in range(n):
1480
1724
  # Get random pairs of 3-element vectors
1481
- a = [1*mats[0][i][0], 2*mats[1][i][0]]
1482
- b = [3*mats[2][i][0], 4*mats[3][i][0]]
1725
+ a = xp.asarray([1 * mats[0][i][0], 2 * mats[1][i][0]])
1726
+ b = xp.asarray([3 * mats[2][i][0], 4 * mats[3][i][0]])
1483
1727
 
1484
1728
  R, _ = Rotation.align_vectors(a, b, weights=[1e10, 1])
1485
- R2, _ = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1486
- assert_allclose(R.as_matrix(), R2.as_matrix(), atol=1e-4)
1729
+ R2, _ = Rotation.align_vectors(a, b, weights=[xp.inf, 1])
1730
+ xp_assert_close(R.as_matrix(), R2.as_matrix(), atol=1e-4)
1487
1731
 
1488
1732
  for i in range(n):
1489
1733
  # Get random triplets of 3-element vectors
1490
- a = [1*mats[0][i][0], 2*mats[1][i][0], 3*mats[2][i][0]]
1491
- b = [4*mats[3][i][0], 5*mats[4][i][0], 6*mats[5][i][0]]
1734
+ a = xp.asarray([1*mats[0][i][0], 2*mats[1][i][0], 3*mats[2][i][0]])
1735
+ b = xp.asarray([4*mats[3][i][0], 5*mats[4][i][0], 6*mats[5][i][0]])
1492
1736
 
1493
1737
  R, _ = Rotation.align_vectors(a, b, weights=[1e10, 2, 1])
1494
- R2, _ = Rotation.align_vectors(a, b, weights=[np.inf, 2, 1])
1495
- assert_allclose(R.as_matrix(), R2.as_matrix(), atol=1e-4)
1738
+ R2, _ = Rotation.align_vectors(a, b, weights=[xp.inf, 2, 1])
1739
+ xp_assert_close(R.as_matrix(), R2.as_matrix(), atol=1e-4)
1496
1740
 
1497
1741
 
1498
- def test_align_vectors_parallel():
1742
+ def test_align_vectors_parallel(xp):
1499
1743
  atol = 1e-12
1500
- a = [[1, 0, 0], [0, 1, 0]]
1501
- b = [[0, 1, 0], [0, 1, 0]]
1502
- m_expected = np.array([[0, 1, 0],
1503
- [-1, 0, 0],
1504
- [0, 0, 1]])
1505
- R, _ = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1506
- assert_allclose(R.as_matrix(), m_expected, atol=atol)
1507
- R, _ = Rotation.align_vectors(a[0], b[0])
1508
- assert_allclose(R.as_matrix(), m_expected, atol=atol)
1509
- assert_allclose(R.apply(b[0]), a[0], atol=atol)
1510
-
1511
- b = [[1, 0, 0], [1, 0, 0]]
1512
- m_expected = np.array([[1, 0, 0],
1513
- [0, 1, 0],
1514
- [0, 0, 1]])
1515
- R, _ = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1516
- assert_allclose(R.as_matrix(), m_expected, atol=atol)
1517
- R, _ = Rotation.align_vectors(a[0], b[0])
1518
- assert_allclose(R.as_matrix(), m_expected, atol=atol)
1519
- assert_allclose(R.apply(b[0]), a[0], atol=atol)
1520
-
1521
-
1522
- def test_align_vectors_antiparallel():
1744
+ a = xp.asarray([[1.0, 0, 0], [0, 1, 0]])
1745
+ b = xp.asarray([[0.0, 1, 0], [0, 1, 0]])
1746
+ m_expected = xp.asarray([[0.0, 1, 0],
1747
+ [-1, 0, 0],
1748
+ [0, 0, 1]])
1749
+ R, _ = Rotation.align_vectors(a, b, weights=[xp.inf, 1])
1750
+ xp_assert_close(R.as_matrix(), m_expected, atol=atol)
1751
+ R, _ = Rotation.align_vectors(a[0, ...], b[0, ...])
1752
+ xp_assert_close(R.as_matrix(), m_expected, atol=atol)
1753
+ xp_assert_close(R.apply(b[0, ...]), a[0, ...], atol=atol)
1754
+
1755
+ b = xp.asarray([[1, 0, 0], [1, 0, 0]])
1756
+ m_expected = xp.asarray([[1.0, 0, 0],
1757
+ [0, 1, 0],
1758
+ [0, 0, 1]])
1759
+ R, _ = Rotation.align_vectors(a, b, weights=[xp.inf, 1])
1760
+ xp_assert_close(R.as_matrix(), m_expected, atol=atol)
1761
+ R, _ = Rotation.align_vectors(a[0, ...], b[0, ...])
1762
+ xp_assert_close(R.as_matrix(), m_expected, atol=atol)
1763
+ xp_assert_close(R.apply(b[0, ...]), a[0, ...], atol=atol)
1764
+
1765
+
1766
+ def test_align_vectors_antiparallel(xp):
1523
1767
  # Test exact 180 deg rotation
1524
1768
  atol = 1e-12
1525
- as_to_test = np.array([[[1, 0, 0], [0, 1, 0]],
1769
+ as_to_test = np.array([[[1.0, 0, 0], [0, 1, 0]],
1526
1770
  [[0, 1, 0], [1, 0, 0]],
1527
1771
  [[0, 0, 1], [0, 1, 0]]])
1772
+
1528
1773
  bs_to_test = [[-a[0], a[1]] for a in as_to_test]
1529
1774
  for a, b in zip(as_to_test, bs_to_test):
1530
- R, _ = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1531
- assert_allclose(R.magnitude(), np.pi, atol=atol)
1532
- assert_allclose(R.apply(b[0]), a[0], atol=atol)
1775
+ a, b = xp.asarray(a), xp.asarray(b)
1776
+ R, _ = Rotation.align_vectors(a, b, weights=[xp.inf, 1])
1777
+ xp_assert_close(R.magnitude(), xp.pi, atol=atol)
1778
+ xp_assert_close(R.apply(b[0, ...]), a[0, ...], atol=atol)
1533
1779
 
1534
1780
  # Test exact rotations near 180 deg
1535
1781
  Rs = Rotation.random(100, rng=0)
@@ -1538,218 +1784,303 @@ def test_align_vectors_antiparallel():
1538
1784
  b = [[-1, 0, 0], [0, 1, 0]]
1539
1785
  as_to_test = []
1540
1786
  for dR in dRs:
1541
- as_to_test.append([dR.apply(a[0]), a[1]])
1787
+ as_to_test.append(np.array([dR.apply(a[0]), a[1]]))
1542
1788
  for a in as_to_test:
1543
- R, _ = Rotation.align_vectors(a, b, weights=[np.inf, 1])
1789
+ a, b = xp.asarray(a), xp.asarray(b)
1790
+ R, _ = Rotation.align_vectors(a, b, weights=[xp.inf, 1])
1544
1791
  R2, _ = Rotation.align_vectors(a, b, weights=[1e10, 1])
1545
- assert_allclose(R.as_matrix(), R2.as_matrix(), atol=atol)
1792
+ xp_assert_close(R.as_matrix(), R2.as_matrix(), atol=atol)
1546
1793
 
1547
1794
 
1548
- def test_align_vectors_primary_only():
1795
+ def test_align_vectors_primary_only(xp):
1549
1796
  atol = 1e-12
1550
1797
  mats_a = Rotation.random(100, rng=0).as_matrix()
1551
1798
  mats_b = Rotation.random(100, rng=1).as_matrix()
1799
+
1552
1800
  for mat_a, mat_b in zip(mats_a, mats_b):
1553
1801
  # Get random 3-element unit vectors
1554
- a = mat_a[0]
1555
- b = mat_b[0]
1802
+ a = xp.asarray(mat_a[0])
1803
+ b = xp.asarray(mat_b[0])
1556
1804
 
1557
1805
  # Compare to align_vectors with primary only
1558
1806
  R, rssd = Rotation.align_vectors(a, b)
1559
- assert_allclose(R.apply(b), a, atol=atol)
1807
+ xp_assert_close(R.apply(b), a, atol=atol)
1560
1808
  assert np.isclose(rssd, 0, atol=atol)
1561
1809
 
1562
1810
 
1563
- def test_slerp():
1811
+ def test_align_vectors_array_like():
1812
+ rng = np.random.default_rng(123)
1813
+ c = Rotation.random(rng=rng)
1814
+ b = rng.normal(size=(5, 3))
1815
+ a = c.apply(b)
1816
+
1817
+ est_expected, rssd_expected = Rotation.align_vectors(a, b)
1818
+ est, rssd = Rotation.align_vectors(a.tolist(), b.tolist())
1819
+ xp_assert_close(est_expected.as_quat(), est.as_quat())
1820
+ xp_assert_close(rssd, rssd_expected)
1821
+
1822
+
1823
+ def test_repr_single_rotation(xp):
1824
+ q = xp.asarray([0, 0, 0, 1])
1825
+ actual = repr(Rotation.from_quat(q))
1826
+ if is_numpy(xp):
1827
+ expected = """\
1828
+ Rotation.from_matrix(array([[1., 0., 0.],
1829
+ [0., 1., 0.],
1830
+ [0., 0., 1.]]))"""
1831
+ assert actual == expected
1832
+ else:
1833
+ assert actual.startswith("Rotation.from_matrix(")
1834
+
1835
+
1836
+ def test_repr_rotation_sequence(xp):
1837
+ q = xp.asarray([[0.0, 1, 0, 1], [0, 0, 1, 1]]) / math.sqrt(2)
1838
+ actual = f"{Rotation.from_quat(q)!r}"
1839
+ if is_numpy(xp):
1840
+ expected = """\
1841
+ Rotation.from_matrix(array([[[ 0., 0., 1.],
1842
+ [ 0., 1., 0.],
1843
+ [-1., 0., 0.]],
1844
+
1845
+ [[ 0., -1., 0.],
1846
+ [ 1., 0., 0.],
1847
+ [ 0., 0., 1.]]]))"""
1848
+ assert actual == expected
1849
+ else:
1850
+ assert actual.startswith("Rotation.from_matrix(")
1851
+
1852
+
1853
+ def test_slerp(xp):
1564
1854
  rnd = np.random.RandomState(0)
1565
1855
 
1566
- key_rots = Rotation.from_quat(rnd.uniform(size=(5, 4)))
1856
+ key_rots = Rotation.from_quat(xp.asarray(rnd.uniform(size=(5, 4))))
1567
1857
  key_quats = key_rots.as_quat()
1568
1858
 
1569
1859
  key_times = [0, 1, 2, 3, 4]
1570
1860
  interpolator = Slerp(key_times, key_rots)
1861
+ assert isinstance(interpolator.times, type(xp.asarray(0)))
1571
1862
 
1572
1863
  times = [0, 0.5, 0.25, 1, 1.5, 2, 2.75, 3, 3.25, 3.60, 4]
1573
1864
  interp_rots = interpolator(times)
1574
1865
  interp_quats = interp_rots.as_quat()
1575
1866
 
1576
1867
  # Dot products are affected by sign of quaternions
1577
- interp_quats[interp_quats[:, -1] < 0] *= -1
1868
+ mask = (interp_quats[:, -1] < 0)[:, None]
1869
+ interp_quats = xp.where(mask, -interp_quats, interp_quats)
1578
1870
  # Checking for quaternion equality, perform same operation
1579
- key_quats[key_quats[:, -1] < 0] *= -1
1871
+ mask = (key_quats[:, -1] < 0)[:, None]
1872
+ key_quats = xp.where(mask, -key_quats, key_quats)
1580
1873
 
1581
1874
  # Equality at keyframes, including both endpoints
1582
- assert_allclose(interp_quats[0], key_quats[0])
1583
- assert_allclose(interp_quats[3], key_quats[1])
1584
- assert_allclose(interp_quats[5], key_quats[2])
1585
- assert_allclose(interp_quats[7], key_quats[3])
1586
- assert_allclose(interp_quats[10], key_quats[4])
1875
+ xp_assert_close(interp_quats[0, ...], key_quats[0, ...])
1876
+ xp_assert_close(interp_quats[3, ...], key_quats[1, ...])
1877
+ xp_assert_close(interp_quats[5, ...], key_quats[2, ...])
1878
+ xp_assert_close(interp_quats[7, ...], key_quats[3, ...])
1879
+ xp_assert_close(interp_quats[10, ...], key_quats[4, ...])
1587
1880
 
1588
1881
  # Constant angular velocity between keyframes. Check by equating
1589
1882
  # cos(theta) between quaternion pairs with equal time difference.
1590
- cos_theta1 = np.sum(interp_quats[0] * interp_quats[2])
1591
- cos_theta2 = np.sum(interp_quats[2] * interp_quats[1])
1592
- assert_allclose(cos_theta1, cos_theta2)
1883
+ cos_theta1 = xp.sum(interp_quats[0, ...] * interp_quats[2, ...])
1884
+ cos_theta2 = xp.sum(interp_quats[2, ...] * interp_quats[1, ...])
1885
+ xp_assert_close(cos_theta1, cos_theta2)
1593
1886
 
1594
- cos_theta4 = np.sum(interp_quats[3] * interp_quats[4])
1595
- cos_theta5 = np.sum(interp_quats[4] * interp_quats[5])
1596
- assert_allclose(cos_theta4, cos_theta5)
1887
+ cos_theta4 = xp.sum(interp_quats[3, ...] * interp_quats[4, ...])
1888
+ cos_theta5 = xp.sum(interp_quats[4, ...] * interp_quats[5, ...])
1889
+ xp_assert_close(cos_theta4, cos_theta5)
1597
1890
 
1598
1891
  # theta1: 0 -> 0.25, theta3 : 0.5 -> 1
1599
1892
  # Use double angle formula for double the time difference
1600
- cos_theta3 = np.sum(interp_quats[1] * interp_quats[3])
1601
- assert_allclose(cos_theta3, 2 * (cos_theta1**2) - 1)
1893
+ cos_theta3 = xp.sum(interp_quats[1, ...] * interp_quats[3, ...])
1894
+ xp_assert_close(cos_theta3, 2 * (cos_theta1**2) - 1)
1602
1895
 
1603
1896
  # Miscellaneous checks
1604
1897
  assert_equal(len(interp_rots), len(times))
1605
1898
 
1606
1899
 
1607
- def test_slerp_rot_is_rotation():
1900
+ def test_slerp_rot_is_rotation(xp):
1608
1901
  with pytest.raises(TypeError, match="must be a `Rotation` instance"):
1609
- r = np.array([[1,2,3,4],
1610
- [0,0,0,1]])
1611
- t = np.array([0, 1])
1902
+ r = xp.asarray([[1,2,3,4],
1903
+ [0,0,0,1]])
1904
+ t = xp.asarray([0, 1])
1612
1905
  Slerp(t, r)
1613
1906
 
1614
1907
 
1615
- def test_slerp_single_rot():
1616
- msg = "must be a sequence of at least 2 rotations"
1617
- with pytest.raises(ValueError, match=msg):
1618
- r = Rotation.from_quat([1, 2, 3, 4])
1908
+ SLERP_EXCEPTION_MESSAGE = "must be a sequence of at least 2 rotations"
1909
+
1910
+
1911
+ def test_slerp_single_rot(xp):
1912
+ r = Rotation.from_quat(xp.asarray([[1.0, 2, 3, 4]]))
1913
+ with pytest.raises(ValueError, match=SLERP_EXCEPTION_MESSAGE):
1619
1914
  Slerp([1], r)
1620
1915
 
1621
1916
 
1622
- def test_slerp_rot_len1():
1623
- msg = "must be a sequence of at least 2 rotations"
1624
- with pytest.raises(ValueError, match=msg):
1625
- r = Rotation.from_quat([[1, 2, 3, 4]])
1917
+ def test_slerp_rot_len0(xp):
1918
+ r = Rotation.random()
1919
+ r = Rotation.from_quat(xp.asarray(r.as_quat()))
1920
+ with pytest.raises(ValueError, match=SLERP_EXCEPTION_MESSAGE):
1921
+ Slerp([], r)
1922
+
1923
+
1924
+ def test_slerp_rot_len1(xp):
1925
+ r = Rotation.random(1)
1926
+ r = Rotation.from_quat(xp.asarray(r.as_quat()))
1927
+ with pytest.raises(ValueError, match=SLERP_EXCEPTION_MESSAGE):
1626
1928
  Slerp([1], r)
1627
1929
 
1628
1930
 
1629
- def test_slerp_time_dim_mismatch():
1931
+ def test_slerp_time_dim_mismatch(xp):
1630
1932
  with pytest.raises(ValueError,
1631
1933
  match="times to be specified in a 1 dimensional array"):
1632
1934
  rnd = np.random.RandomState(0)
1633
- r = Rotation.from_quat(rnd.uniform(size=(2, 4)))
1634
- t = np.array([[1],
1635
- [2]])
1935
+ r = Rotation.from_quat(xp.asarray(rnd.uniform(size=(2, 4))))
1936
+ t = xp.asarray([[1],
1937
+ [2]])
1636
1938
  Slerp(t, r)
1637
1939
 
1638
1940
 
1639
- def test_slerp_num_rotations_mismatch():
1941
+ def test_slerp_num_rotations_mismatch(xp):
1640
1942
  with pytest.raises(ValueError, match="number of rotations to be equal to "
1641
1943
  "number of timestamps"):
1642
1944
  rnd = np.random.RandomState(0)
1643
- r = Rotation.from_quat(rnd.uniform(size=(5, 4)))
1644
- t = np.arange(7)
1945
+ r = Rotation.from_quat(xp.asarray(rnd.uniform(size=(5, 4))))
1946
+ t = xp.arange(7)
1645
1947
  Slerp(t, r)
1646
1948
 
1647
1949
 
1648
- def test_slerp_equal_times():
1649
- with pytest.raises(ValueError, match="strictly increasing order"):
1650
- rnd = np.random.RandomState(0)
1651
- r = Rotation.from_quat(rnd.uniform(size=(5, 4)))
1652
- t = [0, 1, 2, 2, 4]
1653
- Slerp(t, r)
1950
+ def test_slerp_equal_times(xp):
1951
+ rnd = np.random.RandomState(0)
1952
+ q = xp.asarray(rnd.uniform(size=(5, 4)))
1953
+ r = Rotation.from_quat(q)
1954
+ t = [0, 1, 2, 2, 4]
1955
+ if is_lazy_array(q):
1956
+ s = Slerp(t, r)
1957
+ assert xp.all(xp.isnan(s.times))
1958
+ else:
1959
+ with pytest.raises(ValueError, match="strictly increasing order"):
1960
+ Slerp(t, r)
1654
1961
 
1655
1962
 
1656
- def test_slerp_decreasing_times():
1657
- with pytest.raises(ValueError, match="strictly increasing order"):
1658
- rnd = np.random.RandomState(0)
1659
- r = Rotation.from_quat(rnd.uniform(size=(5, 4)))
1660
- t = [0, 1, 3, 2, 4]
1661
- Slerp(t, r)
1963
+ def test_slerp_decreasing_times(xp):
1964
+ rnd = np.random.RandomState(0)
1965
+ q = xp.asarray(rnd.uniform(size=(5, 4)))
1966
+ r = Rotation.from_quat(q)
1967
+ t = [0, 1, 3, 2, 4]
1968
+ if is_lazy_array(q):
1969
+ s = Slerp(t, r)
1970
+ assert xp.all(xp.isnan(s.times))
1971
+ else:
1972
+ with pytest.raises(ValueError, match="strictly increasing order"):
1973
+ Slerp(t, r)
1662
1974
 
1663
1975
 
1664
- def test_slerp_call_time_dim_mismatch():
1976
+ def test_slerp_call_time_dim_mismatch(xp):
1665
1977
  rnd = np.random.RandomState(0)
1666
- r = Rotation.from_quat(rnd.uniform(size=(5, 4)))
1667
- t = np.arange(5)
1978
+ r = Rotation.from_quat(xp.asarray(rnd.uniform(size=(5, 4))))
1979
+ t = xp.arange(5)
1668
1980
  s = Slerp(t, r)
1669
1981
 
1670
1982
  with pytest.raises(ValueError,
1671
1983
  match="`times` must be at most 1-dimensional."):
1672
- interp_times = np.array([[3.5],
1673
- [4.2]])
1984
+ interp_times = xp.asarray([[3.5],
1985
+ [4.2]])
1674
1986
  s(interp_times)
1675
1987
 
1676
1988
 
1677
- def test_slerp_call_time_out_of_range():
1989
+ def test_slerp_call_time_out_of_range(xp):
1678
1990
  rnd = np.random.RandomState(0)
1679
- r = Rotation.from_quat(rnd.uniform(size=(5, 4)))
1680
- t = np.arange(5) + 1
1991
+ r = Rotation.from_quat(xp.asarray(rnd.uniform(size=(5, 4))))
1992
+ t = xp.arange(5) + 1
1681
1993
  s = Slerp(t, r)
1682
1994
 
1683
- with pytest.raises(ValueError, match="times must be within the range"):
1684
- s([0, 1, 2])
1685
- with pytest.raises(ValueError, match="times must be within the range"):
1686
- s([1, 2, 6])
1687
-
1688
-
1689
- def test_slerp_call_scalar_time():
1690
- r = Rotation.from_euler('X', [0, 80], degrees=True)
1995
+ times_low = xp.asarray([0, 1, 2])
1996
+ times_high = xp.asarray([1, 2, 6])
1997
+ if is_lazy_array(times_low):
1998
+ q = s(times_low).as_quat()
1999
+ in_range = xp.logical_and(times_low >= xp.min(t), times_low <= xp.max(t))
2000
+ assert xp.all(xp.isnan(q[~in_range, ...]))
2001
+ assert xp.all(~xp.isnan(q[in_range, ...]))
2002
+ q = s(times_high).as_quat()
2003
+ in_range = xp.logical_and(times_high >= xp.min(t), times_high <= xp.max(t))
2004
+ assert xp.all(xp.isnan(q[~in_range, ...]))
2005
+ assert xp.all(~xp.isnan(q[in_range, ...]))
2006
+ else:
2007
+ with pytest.raises(ValueError, match="times must be within the range"):
2008
+ s(times_low)
2009
+ with pytest.raises(ValueError, match="times must be within the range"):
2010
+ s(times_high)
2011
+
2012
+
2013
+ def test_slerp_call_scalar_time(xp):
2014
+ r = Rotation.from_euler('X', xp.asarray([0, 80]), degrees=True)
1691
2015
  s = Slerp([0, 1], r)
1692
2016
 
1693
2017
  r_interpolated = s(0.25)
1694
- r_interpolated_expected = Rotation.from_euler('X', 20, degrees=True)
2018
+ r_interpolated_expected = Rotation.from_euler('X', xp.asarray(20), degrees=True)
1695
2019
 
1696
2020
  delta = r_interpolated * r_interpolated_expected.inv()
1697
2021
 
1698
- assert_allclose(delta.magnitude(), 0, atol=1e-16)
2022
+ assert xp.allclose(delta.magnitude(), 0, atol=1e-16)
1699
2023
 
1700
2024
 
1701
- def test_multiplication_stability():
2025
+ def test_multiplication_stability(xp):
1702
2026
  qs = Rotation.random(50, rng=0)
2027
+ qs = Rotation.from_quat(xp.asarray(qs.as_quat()))
1703
2028
  rs = Rotation.random(1000, rng=1)
2029
+ rs = Rotation.from_quat(xp.asarray(rs.as_quat()))
2030
+ expected = xp.ones(len(rs))
1704
2031
  for q in qs:
1705
2032
  rs *= q * rs
1706
- assert_allclose(np.linalg.norm(rs.as_quat(), axis=1), 1)
2033
+ xp_assert_close(xp_vector_norm(rs.as_quat(), axis=1), expected)
1707
2034
 
1708
2035
 
1709
- def test_pow():
2036
+ def test_pow(xp):
1710
2037
  atol = 1e-14
1711
2038
  p = Rotation.random(10, rng=0)
2039
+ p = Rotation.from_quat(xp.asarray(p.as_quat()))
1712
2040
  p_inv = p.inv()
1713
2041
  # Test the short-cuts and other integers
1714
2042
  for n in [-5, -2, -1, 0, 1, 2, 5]:
1715
2043
  # Test accuracy
1716
2044
  q = p ** n
1717
2045
  r = Rotation.identity(10)
2046
+ r = Rotation.from_quat(xp.asarray(r.as_quat()))
1718
2047
  for _ in range(abs(n)):
1719
2048
  if n > 0:
1720
2049
  r = r * p
1721
2050
  else:
1722
2051
  r = r * p_inv
1723
2052
  ang = (q * r.inv()).magnitude()
1724
- assert np.all(ang < atol)
2053
+ assert xp.all(ang < atol)
1725
2054
 
1726
2055
  # Test shape preservation
1727
- r = Rotation.from_quat([0, 0, 0, 1])
2056
+ r = Rotation.from_quat(xp.asarray([0, 0, 0, 1]))
1728
2057
  assert (r**n).as_quat().shape == (4,)
1729
- r = Rotation.from_quat([[0, 0, 0, 1]])
2058
+ r = Rotation.from_quat(xp.asarray([[0, 0, 0, 1]]))
1730
2059
  assert (r**n).as_quat().shape == (1, 4)
1731
2060
 
1732
2061
  # Large angle fractional
1733
2062
  for n in [-1.5, -0.5, -0.0, 0.0, 0.5, 1.5]:
1734
2063
  q = p ** n
1735
2064
  r = Rotation.from_rotvec(n * p.as_rotvec())
1736
- assert_allclose(q.as_quat(), r.as_quat(), atol=atol)
2065
+ xp_assert_close(q.as_quat(), r.as_quat(), atol=atol)
1737
2066
 
1738
2067
  # Small angle
1739
- p = Rotation.from_rotvec([1e-12, 0, 0])
2068
+ p = Rotation.from_rotvec(xp.asarray([1e-12, 0, 0]))
1740
2069
  n = 3
1741
2070
  q = p ** n
1742
2071
  r = Rotation.from_rotvec(n * p.as_rotvec())
1743
- assert_allclose(q.as_quat(), r.as_quat(), atol=atol)
2072
+ xp_assert_close(q.as_quat(), r.as_quat(), atol=atol)
1744
2073
 
1745
2074
 
1746
- def test_pow_errors():
2075
+ def test_pow_errors(xp):
1747
2076
  p = Rotation.random(rng=0)
2077
+ p = Rotation.from_quat(xp.asarray(p.as_quat()))
1748
2078
  with pytest.raises(NotImplementedError, match='modulus not supported'):
1749
2079
  pow(p, 1, 1)
1750
2080
 
1751
2081
 
1752
2082
  def test_rotation_within_numpy_array():
2083
+ # TODO: Do we want to support this for all Array API frameworks?
1753
2084
  single = Rotation.random(rng=0)
1754
2085
  multiple = Rotation.random(2, rng=1)
1755
2086
 
@@ -1758,8 +2089,8 @@ def test_rotation_within_numpy_array():
1758
2089
 
1759
2090
  array = np.array(multiple)
1760
2091
  assert_equal(array.shape, (2,))
1761
- assert_allclose(array[0].as_matrix(), multiple[0].as_matrix())
1762
- assert_allclose(array[1].as_matrix(), multiple[1].as_matrix())
2092
+ xp_assert_close(array[0].as_matrix(), multiple[0].as_matrix())
2093
+ xp_assert_close(array[1].as_matrix(), multiple[1].as_matrix())
1763
2094
 
1764
2095
  array = np.array([single])
1765
2096
  assert_equal(array.shape, (1,))
@@ -1767,8 +2098,8 @@ def test_rotation_within_numpy_array():
1767
2098
 
1768
2099
  array = np.array([multiple])
1769
2100
  assert_equal(array.shape, (1, 2))
1770
- assert_allclose(array[0, 0].as_matrix(), multiple[0].as_matrix())
1771
- assert_allclose(array[0, 1].as_matrix(), multiple[1].as_matrix())
2101
+ xp_assert_close(array[0, 0].as_matrix(), multiple[0].as_matrix())
2102
+ xp_assert_close(array[0, 1].as_matrix(), multiple[1].as_matrix())
1772
2103
 
1773
2104
  array = np.array([single, multiple], dtype=object)
1774
2105
  assert_equal(array.shape, (2,))
@@ -1779,20 +2110,25 @@ def test_rotation_within_numpy_array():
1779
2110
  assert_equal(array.shape, (3, 2))
1780
2111
 
1781
2112
 
1782
- def test_pickling():
1783
- r = Rotation.from_quat([0, 0, np.sin(np.pi/4), np.cos(np.pi/4)])
2113
+ def test_pickling(xp):
2114
+ # Note: Array API makes no provision for arrays to be pickleable, so
2115
+ # it's OK to skip this test for the backends that don't support it
2116
+ r = Rotation.from_quat(xp.asarray([0, 0, math.sin(np.pi/4), math.cos(np.pi/4)]))
1784
2117
  pkl = pickle.dumps(r)
1785
2118
  unpickled = pickle.loads(pkl)
1786
- assert_allclose(r.as_matrix(), unpickled.as_matrix(), atol=1e-15)
2119
+ xp_assert_close(r.as_matrix(), unpickled.as_matrix(), atol=1e-15)
1787
2120
 
1788
2121
 
1789
- def test_deepcopy():
1790
- r = Rotation.from_quat([0, 0, np.sin(np.pi/4), np.cos(np.pi/4)])
2122
+ def test_deepcopy(xp):
2123
+ # Note: Array API makes no provision for arrays to support the `__copy__`
2124
+ # protocol, so it's OK to skip this test for the backends that don't
2125
+ r = Rotation.from_quat(xp.asarray([0, 0, math.sin(np.pi/4), math.cos(np.pi/4)]))
1791
2126
  r1 = copy.deepcopy(r)
1792
- assert_allclose(r.as_matrix(), r1.as_matrix(), atol=1e-15)
2127
+ xp_assert_close(r.as_matrix(), r1.as_matrix(), atol=1e-15)
1793
2128
 
1794
2129
 
1795
2130
  def test_as_euler_contiguous():
2131
+ # The Array API does not specify contiguous arrays, so we can only check for NumPy
1796
2132
  r = Rotation.from_quat([0, 0, 0, 1])
1797
2133
  e1 = r.as_euler('xyz') # extrinsic euler rotation
1798
2134
  e2 = r.as_euler('XYZ') # intrinsic
@@ -1802,36 +2138,39 @@ def test_as_euler_contiguous():
1802
2138
  assert all(i >= 0 for i in e2.strides)
1803
2139
 
1804
2140
 
1805
- def test_concatenate():
2141
+ def test_concatenate(xp):
1806
2142
  rotation = Rotation.random(10, rng=0)
2143
+ rotation = Rotation.from_quat(xp.asarray(rotation.as_quat()))
1807
2144
  sizes = [1, 2, 3, 1, 3]
1808
2145
  starts = [0] + list(np.cumsum(sizes))
1809
2146
  split = [rotation[i:i + n] for i, n in zip(starts, sizes)]
1810
2147
  result = Rotation.concatenate(split)
1811
- assert_equal(rotation.as_quat(), result.as_quat())
2148
+ xp_assert_equal(rotation.as_quat(), result.as_quat())
1812
2149
 
1813
2150
  # Test Rotation input for multiple rotations
1814
2151
  result = Rotation.concatenate(rotation)
1815
- assert_equal(rotation.as_quat(), result.as_quat())
2152
+ xp_assert_equal(rotation.as_quat(), result.as_quat())
1816
2153
 
1817
2154
  # Test that a copy is returned
1818
2155
  assert rotation is not result
1819
2156
 
1820
2157
  # Test Rotation input for single rotations
1821
- result = Rotation.concatenate(Rotation.identity())
1822
- assert_equal(Rotation.identity().as_quat(), result.as_quat())
2158
+ rot = Rotation.from_quat(xp.asarray(Rotation.identity().as_quat()))
2159
+ result = Rotation.concatenate(rot)
2160
+ xp_assert_equal(rot.as_quat(), result.as_quat())
1823
2161
 
1824
2162
 
1825
- def test_concatenate_wrong_type():
2163
+ def test_concatenate_wrong_type(xp):
1826
2164
  with pytest.raises(TypeError, match='Rotation objects only'):
1827
- Rotation.concatenate([Rotation.identity(), 1, None])
2165
+ rot = Rotation(xp.asarray(Rotation.identity().as_quat()))
2166
+ Rotation.concatenate([rot, 1, None])
1828
2167
 
1829
2168
 
1830
2169
  # Regression test for gh-16663
1831
- def test_len_and_bool():
1832
- rotation_multi_one = Rotation([[0, 0, 0, 1]])
1833
- rotation_multi = Rotation([[0, 0, 0, 1], [0, 0, 0, 1]])
1834
- rotation_single = Rotation([0, 0, 0, 1])
2170
+ def test_len_and_bool(xp):
2171
+ rotation_multi_one = Rotation(xp.asarray([[0, 0, 0, 1]]))
2172
+ rotation_multi = Rotation(xp.asarray([[0, 0, 0, 1], [0, 0, 0, 1]]))
2173
+ rotation_single = Rotation(xp.asarray([0, 0, 0, 1]))
1835
2174
 
1836
2175
  assert len(rotation_multi_one) == 1
1837
2176
  assert len(rotation_multi) == 2
@@ -1844,61 +2183,93 @@ def test_len_and_bool():
1844
2183
  assert rotation_single
1845
2184
 
1846
2185
 
1847
- def test_from_davenport_single_rotation():
1848
- axis = [0, 0, 1]
2186
+ def test_from_davenport_single_rotation(xp):
2187
+ axis = xp.asarray([0, 0, 1])
1849
2188
  quat = Rotation.from_davenport(axis, 'extrinsic', 90,
1850
2189
  degrees=True).as_quat()
1851
- expected_quat = np.array([0, 0, 1, 1]) / np.sqrt(2)
1852
- assert_allclose(quat, expected_quat)
2190
+ expected_quat = xp.asarray([0.0, 0, 1, 1]) / math.sqrt(2)
2191
+ xp_assert_close(quat, expected_quat)
1853
2192
 
1854
2193
 
1855
- def test_from_davenport_one_or_two_axes():
1856
- ez = [0, 0, 1]
1857
- ey = [0, 1, 0]
2194
+ def test_from_davenport_one_or_two_axes(xp):
2195
+ ez = xp.asarray([0.0, 0, 1])
2196
+ ey = xp.asarray([0.0, 1, 0])
1858
2197
 
1859
2198
  # Single rotation, single axis, axes.shape == (3, )
1860
- rot = Rotation.from_rotvec(np.array(ez) * np.pi/4)
1861
- rot_dav = Rotation.from_davenport(ez, 'e', np.pi/4)
1862
- assert_allclose(rot.as_quat(canonical=True),
2199
+ rot = Rotation.from_rotvec(ez * xp.pi/4)
2200
+ rot_dav = Rotation.from_davenport(ez, 'e', xp.pi/4)
2201
+ xp_assert_close(rot.as_quat(canonical=True),
1863
2202
  rot_dav.as_quat(canonical=True))
1864
2203
 
1865
2204
  # Single rotation, single axis, axes.shape == (1, 3)
1866
- rot = Rotation.from_rotvec([np.array(ez) * np.pi/4])
1867
- rot_dav = Rotation.from_davenport([ez], 'e', [np.pi/4])
1868
- assert_allclose(rot.as_quat(canonical=True),
2205
+ axes = xp.reshape(ez, (1, 3)) # Torch can't create tensors from xp.asarray([ez])
2206
+ rot = Rotation.from_rotvec(axes * xp.pi/4)
2207
+ rot_dav = Rotation.from_davenport(axes, 'e', [xp.pi/4])
2208
+ xp_assert_close(rot.as_quat(canonical=True),
1869
2209
  rot_dav.as_quat(canonical=True))
1870
2210
 
1871
2211
  # Single rotation, two axes, axes.shape == (2, 3)
1872
- rot = Rotation.from_rotvec([np.array(ez) * np.pi/4,
1873
- np.array(ey) * np.pi/6])
2212
+ axes = xp.stack([ez, ey], axis=0)
2213
+ rot = Rotation.from_rotvec(axes * xp.asarray([[xp.pi/4], [xp.pi/6]]))
1874
2214
  rot = rot[0] * rot[1]
1875
- rot_dav = Rotation.from_davenport([ey, ez], 'e', [np.pi/6, np.pi/4])
1876
- assert_allclose(rot.as_quat(canonical=True),
2215
+ axes_dav = xp.stack([ey, ez], axis=0)
2216
+ rot_dav = Rotation.from_davenport(axes_dav, 'e', [xp.pi/6, xp.pi/4])
2217
+ xp_assert_close(rot.as_quat(canonical=True),
1877
2218
  rot_dav.as_quat(canonical=True))
1878
2219
 
1879
2220
  # Two rotations, single axis, axes.shape == (3, )
1880
- rot = Rotation.from_rotvec([np.array(ez) * np.pi/6,
1881
- np.array(ez) * np.pi/4])
1882
- rot_dav = Rotation.from_davenport([ez], 'e', [np.pi/6, np.pi/4])
1883
- assert_allclose(rot.as_quat(canonical=True),
2221
+ axes = xp.stack([ez, ez], axis=0)
2222
+ rot = Rotation.from_rotvec(axes * xp.asarray([[xp.pi/6], [xp.pi/4]]))
2223
+ axes_dav = xp.reshape(ez, (1, 3))
2224
+ rot_dav = Rotation.from_davenport(axes_dav, 'e', [xp.pi/6, xp.pi/4])
2225
+ xp_assert_close(rot.as_quat(canonical=True),
1884
2226
  rot_dav.as_quat(canonical=True))
1885
2227
 
1886
2228
 
1887
- def test_from_davenport_invalid_input():
2229
+ def test_from_davenport_invalid_input(xp):
1888
2230
  ez = [0, 0, 1]
1889
2231
  ey = [0, 1, 0]
1890
2232
  ezy = [0, 1, 1]
1891
- with pytest.raises(ValueError, match="must be orthogonal"):
1892
- Rotation.from_davenport([ez, ezy], 'e', [0, 0])
1893
- with pytest.raises(ValueError, match="must be orthogonal"):
1894
- Rotation.from_davenport([ez, ey, ezy], 'e', [0, 0, 0])
2233
+ # We can only raise in non-lazy frameworks.
2234
+ axes = xp.asarray([ez, ezy])
2235
+ if is_lazy_array(axes):
2236
+ q = Rotation.from_davenport(axes, 'e', [0, 0]).as_quat()
2237
+ assert xp.all(xp.isnan(q))
2238
+ else:
2239
+ with pytest.raises(ValueError, match="must be orthogonal"):
2240
+ Rotation.from_davenport(axes, 'e', [0, 0])
2241
+ axes = xp.asarray([ez, ey, ezy])
2242
+ if is_lazy_array(axes):
2243
+ q = Rotation.from_davenport(axes, 'e', [0, 0, 0]).as_quat()
2244
+ assert xp.all(xp.isnan(q))
2245
+ else:
2246
+ with pytest.raises(ValueError, match="must be orthogonal"):
2247
+ Rotation.from_davenport(axes, 'e', [0, 0, 0])
1895
2248
  with pytest.raises(ValueError, match="order should be"):
1896
- Rotation.from_davenport([ez], 'xyz', [0])
2249
+ Rotation.from_davenport(xp.asarray([ez]), 'xyz', [0])
1897
2250
  with pytest.raises(ValueError, match="Expected `angles`"):
1898
- Rotation.from_davenport([ez, ey, ez], 'e', [0, 1, 2, 3])
2251
+ Rotation.from_davenport(xp.asarray([ez, ey, ez]), 'e', [0, 1, 2, 3])
2252
+
2253
+
2254
+ def test_from_davenport_array_like():
2255
+ rng = np.random.default_rng(123)
2256
+ # Single rotation
2257
+ e1 = np.array([1, 0, 0])
2258
+ e2 = np.array([0, 1, 0])
2259
+ e3 = np.array([0, 0, 1])
2260
+ r_expected = Rotation.random(rng=rng)
2261
+ angles = r_expected.as_davenport([e1, e2, e3], 'e')
2262
+ r = Rotation.from_davenport([e1, e2, e3], 'e', angles.tolist())
2263
+ assert r_expected.approx_equal(r, atol=1e-12)
1899
2264
 
2265
+ # Multiple rotations
2266
+ r_expected = Rotation.random(2, rng=rng)
2267
+ angles = r_expected.as_davenport([e1, e2, e3], 'e')
2268
+ r = Rotation.from_davenport([e1, e2, e3], 'e', angles.tolist())
2269
+ assert np.all(r_expected.approx_equal(r, atol=1e-12))
1900
2270
 
1901
- def test_as_davenport():
2271
+
2272
+ def test_as_davenport(xp):
1902
2273
  rnd = np.random.RandomState(0)
1903
2274
  n = 100
1904
2275
  angles = np.empty((n, 3))
@@ -1907,21 +2278,22 @@ def test_as_davenport():
1907
2278
  angles[:, 2] = rnd.uniform(low=-np.pi, high=np.pi, size=(n,))
1908
2279
  lambdas = rnd.uniform(low=0, high=np.pi, size=(20,))
1909
2280
 
1910
- e1 = np.array([1, 0, 0])
1911
- e2 = np.array([0, 1, 0])
2281
+ e1 = xp.asarray([1.0, 0, 0])
2282
+ e2 = xp.asarray([0.0, 1, 0])
1912
2283
 
1913
2284
  for lamb in lambdas:
1914
- ax_lamb = [e1, e2, Rotation.from_rotvec(lamb*e2).apply(e1)]
2285
+ e3 = xp.asarray(Rotation.from_rotvec(lamb*e2).apply(e1))
2286
+ ax_lamb = xp.stack([e1, e2, e3], axis=0)
1915
2287
  angles[:, 1] = angles_middle - lamb
1916
2288
  for order in ['extrinsic', 'intrinsic']:
1917
- ax = ax_lamb if order == 'intrinsic' else ax_lamb[::-1]
1918
- rot = Rotation.from_davenport(ax, order, angles)
1919
- angles_dav = rot.as_davenport(ax, order)
1920
- assert_allclose(angles_dav, angles)
2289
+ ax = ax_lamb if order == "intrinsic" else xp.flip(ax_lamb, axis=0)
2290
+ rot = Rotation.from_davenport(xp.asarray(ax), order, angles)
2291
+ angles_dav = rot.as_davenport(xp.asarray(ax), order)
2292
+ xp_assert_close(angles_dav, xp.asarray(angles))
1921
2293
 
1922
2294
 
1923
2295
  @pytest.mark.thread_unsafe
1924
- def test_as_davenport_degenerate():
2296
+ def test_as_davenport_degenerate(xp):
1925
2297
  # Since we cannot check for angle equality, we check for rotation matrix
1926
2298
  # equality
1927
2299
  rnd = np.random.RandomState(0)
@@ -1934,23 +2306,25 @@ def test_as_davenport_degenerate():
1934
2306
  angles[:, 2] = rnd.uniform(low=-np.pi, high=np.pi, size=(n,))
1935
2307
  lambdas = rnd.uniform(low=0, high=np.pi, size=(5,))
1936
2308
 
1937
- e1 = np.array([1, 0, 0])
1938
- e2 = np.array([0, 1, 0])
2309
+ e1 = xp.asarray([1.0, 0, 0])
2310
+ e2 = xp.asarray([0.0, 1, 0])
1939
2311
 
1940
2312
  for lamb in lambdas:
1941
- ax_lamb = [e1, e2, Rotation.from_rotvec(lamb*e2).apply(e1)]
2313
+ e3 = xp.asarray(Rotation.from_rotvec(lamb*e2).apply(e1))
2314
+ ax_lamb = xp.stack([e1, e2, e3], axis=0)
1942
2315
  angles[:, 1] = angles_middle - lamb
1943
2316
  for order in ['extrinsic', 'intrinsic']:
1944
2317
  ax = ax_lamb if order == 'intrinsic' else ax_lamb[::-1]
1945
- rot = Rotation.from_davenport(ax, order, angles)
1946
- with pytest.warns(UserWarning, match="Gimbal lock"):
1947
- angles_dav = rot.as_davenport(ax, order)
2318
+ rot = Rotation.from_davenport(xp.asarray(ax), order, angles)
2319
+ with eager_warns(rot, UserWarning, match="Gimbal lock"):
2320
+ angles_dav = rot.as_davenport(xp.asarray(ax), order)
1948
2321
  mat_expected = rot.as_matrix()
1949
- mat_estimated = Rotation.from_davenport(ax, order, angles_dav).as_matrix()
1950
- assert_array_almost_equal(mat_expected, mat_estimated)
2322
+ rot_estimated = Rotation.from_davenport(xp.asarray(ax), order, angles_dav)
2323
+ mat_estimated = rot_estimated.as_matrix()
2324
+ xp_assert_close(mat_expected, mat_estimated, atol=1e-12)
1951
2325
 
1952
2326
 
1953
- def test_compare_from_davenport_from_euler():
2327
+ def test_compare_from_davenport_from_euler(xp):
1954
2328
  rnd = np.random.RandomState(0)
1955
2329
  n = 100
1956
2330
  angles = np.empty((n, 3))
@@ -1965,9 +2339,9 @@ def test_compare_from_davenport_from_euler():
1965
2339
  ax = [basis_vec(i) for i in seq]
1966
2340
  if order == 'intrinsic':
1967
2341
  seq = seq.upper()
1968
- eul = Rotation.from_euler(seq, angles)
1969
- dav = Rotation.from_davenport(ax, order, angles)
1970
- assert_allclose(eul.as_quat(canonical=True), dav.as_quat(canonical=True),
2342
+ eul = Rotation.from_euler(seq, xp.asarray(angles))
2343
+ dav = Rotation.from_davenport(xp.asarray(ax), order, xp.asarray(angles))
2344
+ xp_assert_close(eul.as_quat(canonical=True), dav.as_quat(canonical=True),
1971
2345
  rtol=1e-12)
1972
2346
 
1973
2347
  # asymmetric sequences
@@ -1978,12 +2352,12 @@ def test_compare_from_davenport_from_euler():
1978
2352
  ax = [basis_vec(i) for i in seq]
1979
2353
  if order == 'intrinsic':
1980
2354
  seq = seq.upper()
1981
- eul = Rotation.from_euler(seq, angles)
1982
- dav = Rotation.from_davenport(ax, order, angles)
1983
- assert_allclose(eul.as_quat(), dav.as_quat(), rtol=1e-12)
2355
+ eul = Rotation.from_euler(seq, xp.asarray(angles))
2356
+ dav = Rotation.from_davenport(xp.asarray(ax), order, xp.asarray(angles))
2357
+ xp_assert_close(eul.as_quat(), dav.as_quat(), rtol=1e-12)
1984
2358
 
1985
2359
 
1986
- def test_compare_as_davenport_as_euler():
2360
+ def test_compare_as_davenport_as_euler(xp):
1987
2361
  rnd = np.random.RandomState(0)
1988
2362
  n = 100
1989
2363
  angles = np.empty((n, 3))
@@ -1998,10 +2372,10 @@ def test_compare_as_davenport_as_euler():
1998
2372
  ax = [basis_vec(i) for i in seq]
1999
2373
  if order == 'intrinsic':
2000
2374
  seq = seq.upper()
2001
- rot = Rotation.from_euler(seq, angles)
2375
+ rot = Rotation.from_euler(seq, xp.asarray(angles))
2002
2376
  eul = rot.as_euler(seq)
2003
- dav = rot.as_davenport(ax, order)
2004
- assert_allclose(eul, dav, rtol=1e-12)
2377
+ dav = rot.as_davenport(xp.asarray(ax), order)
2378
+ xp_assert_close(eul, dav, rtol=1e-12)
2005
2379
 
2006
2380
  # asymmetric sequences
2007
2381
  angles[:, 1] -= np.pi / 2
@@ -2011,7 +2385,185 @@ def test_compare_as_davenport_as_euler():
2011
2385
  ax = [basis_vec(i) for i in seq]
2012
2386
  if order == 'intrinsic':
2013
2387
  seq = seq.upper()
2014
- rot = Rotation.from_euler(seq, angles)
2388
+ rot = Rotation.from_euler(seq, xp.asarray(angles))
2015
2389
  eul = rot.as_euler(seq)
2016
- dav = rot.as_davenport(ax, order)
2017
- assert_allclose(eul, dav, rtol=1e-12)
2390
+ dav = rot.as_davenport(xp.asarray(ax), order)
2391
+ xp_assert_close(eul, dav, rtol=1e-12)
2392
+
2393
+
2394
+ def test_zero_rotation_construction(xp):
2395
+ r = Rotation.random(num=0)
2396
+ assert len(r) == 0
2397
+
2398
+ r_ide = Rotation.identity(num=0)
2399
+ assert len(r_ide) == 0
2400
+
2401
+ r_get = Rotation.random(num=3)[[]]
2402
+ assert len(r_get) == 0
2403
+
2404
+ r_quat = Rotation.from_quat(xp.zeros((0, 4)))
2405
+ assert len(r_quat) == 0
2406
+
2407
+ r_matrix = Rotation.from_matrix(xp.zeros((0, 3, 3)))
2408
+ assert len(r_matrix) == 0
2409
+
2410
+ r_euler = Rotation.from_euler("xyz", xp.zeros((0, 3)))
2411
+ assert len(r_euler) == 0
2412
+
2413
+ r_vec = Rotation.from_rotvec(xp.zeros((0, 3)))
2414
+ assert len(r_vec) == 0
2415
+
2416
+ r_dav = Rotation.from_davenport(xp.eye(3), "extrinsic", xp.zeros((0, 3)))
2417
+ assert len(r_dav) == 0
2418
+
2419
+ r_mrp = Rotation.from_mrp(xp.zeros((0, 3)))
2420
+ assert len(r_mrp) == 0
2421
+
2422
+
2423
+ def test_zero_rotation_representation(xp):
2424
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2425
+ assert r.as_quat().shape == (0, 4)
2426
+ assert r.as_matrix().shape == (0, 3, 3)
2427
+ assert r.as_euler("xyz").shape == (0, 3)
2428
+ assert r.as_rotvec().shape == (0, 3)
2429
+ assert r.as_mrp().shape == (0, 3)
2430
+ assert r.as_davenport(xp.eye(3), "extrinsic").shape == (0, 3)
2431
+
2432
+
2433
+ def test_zero_rotation_array_rotation(xp):
2434
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2435
+
2436
+ v = xp.asarray([1, 2, 3])
2437
+ v_rotated = r.apply(v)
2438
+ assert v_rotated.shape == (0, 3)
2439
+
2440
+ v0 = xp.zeros((0, 3))
2441
+ v0_rot = r.apply(v0)
2442
+ assert v0_rot.shape == (0, 3)
2443
+
2444
+ v2 = xp.ones((2, 3))
2445
+ with pytest.raises(
2446
+ ValueError, match="Expected equal numbers of rotations and vectors"):
2447
+ r.apply(v2)
2448
+
2449
+
2450
+ def test_zero_rotation_multiplication(xp):
2451
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2452
+
2453
+ r_single = Rotation.from_quat(xp.asarray([0.0, 0, 0, 1]))
2454
+ r_mult_left = r * r_single
2455
+ assert len(r_mult_left) == 0
2456
+
2457
+ r_mult_right = r_single * r
2458
+ assert len(r_mult_right) == 0
2459
+
2460
+ r0 = Rotation.from_quat(xp.zeros((0, 4)))
2461
+ r_mult = r * r0
2462
+ assert len(r_mult) == 0
2463
+
2464
+ msg_rotation_error = "Expected equal number of rotations"
2465
+ r2 = Rotation.random(2)
2466
+ r2 = Rotation.from_quat(xp.asarray(r2.as_quat()))
2467
+ with pytest.raises(ValueError, match=msg_rotation_error):
2468
+ r0 * r2
2469
+
2470
+ with pytest.raises(ValueError, match=msg_rotation_error):
2471
+ r2 * r0
2472
+
2473
+
2474
+ def test_zero_rotation_concatentation(xp):
2475
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2476
+
2477
+ r0 = Rotation.concatenate([r, r])
2478
+ assert len(r0) == 0
2479
+
2480
+ r1 = Rotation.from_quat(xp.asarray([0.0, 0, 0, 1]))
2481
+ r1 = r.concatenate([r1, r])
2482
+ assert len(r1) == 1
2483
+
2484
+ r3 = Rotation.from_quat(xp.asarray(Rotation.random(3).as_quat()))
2485
+ r3 = r.concatenate([r3, r])
2486
+ assert len(r3) == 3
2487
+
2488
+ r4 = Rotation.from_quat(xp.asarray(Rotation.random(4).as_quat()))
2489
+ r4 = r.concatenate([r, r4])
2490
+ r4 = r.concatenate([r, r4])
2491
+ assert len(r4) == 4
2492
+
2493
+
2494
+ def test_zero_rotation_power(xp):
2495
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2496
+ for pp in [-1.5, -1, 0, 1, 1.5]:
2497
+ pow0 = r**pp
2498
+ assert len(pow0) == 0
2499
+
2500
+
2501
+ def test_zero_rotation_inverse(xp):
2502
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2503
+ r_inv = r.inv()
2504
+ assert len(r_inv) == 0
2505
+
2506
+
2507
+ def test_zero_rotation_magnitude(xp):
2508
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2509
+ magnitude = r.magnitude()
2510
+ assert magnitude.shape == (0,)
2511
+
2512
+
2513
+ def test_zero_rotation_mean(xp):
2514
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2515
+ with pytest.raises(ValueError, match="Mean of an empty rotation set is undefined."):
2516
+ r.mean()
2517
+
2518
+
2519
+ def test_zero_rotation_approx_equal(xp):
2520
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2521
+ r0 = Rotation.from_quat(xp.zeros((0, 4)))
2522
+ assert r.approx_equal(r0).shape == (0,)
2523
+ r1 = Rotation.from_quat(xp.asarray([0.0, 0, 0, 1]))
2524
+ assert r.approx_equal(r1).shape == (0,)
2525
+ r2 = Rotation.from_quat(xp.asarray(Rotation.random().as_quat()))
2526
+ assert r2.approx_equal(r).shape == (0,)
2527
+
2528
+ approx_msg = "Expected equal number of rotations"
2529
+ r3 = Rotation.from_quat(xp.asarray(Rotation.random(2).as_quat()))
2530
+ with pytest.raises(ValueError, match=approx_msg):
2531
+ r.approx_equal(r3)
2532
+
2533
+ with pytest.raises(ValueError, match=approx_msg):
2534
+ r3.approx_equal(r)
2535
+
2536
+
2537
+ def test_zero_rotation_get_set(xp):
2538
+ r = Rotation.from_quat(xp.zeros((0, 4)))
2539
+
2540
+ r_get = r[xp.asarray([], dtype=xp.bool)]
2541
+ assert len(r_get) == 0
2542
+
2543
+ r_slice = r[:0]
2544
+ assert len(r_slice) == 0
2545
+
2546
+ with pytest.raises(IndexError):
2547
+ r[xp.asarray([0])]
2548
+
2549
+ with pytest.raises(IndexError):
2550
+ r[xp.asarray([True])]
2551
+
2552
+ with pytest.raises(IndexError):
2553
+ r[0] = Rotation.from_quat(xp.asarray([0, 0, 0, 1]))
2554
+
2555
+
2556
+ def test_boolean_indexes(xp):
2557
+ r = Rotation.from_quat(xp.asarray(Rotation.random(3).as_quat()))
2558
+
2559
+ r0 = r[xp.asarray([False, False, False])]
2560
+ assert len(r0) == 0
2561
+
2562
+ r1 = r[xp.asarray([False, True, False])]
2563
+ assert len(r1) == 1
2564
+
2565
+ r3 = r[xp.asarray([True, True, True])]
2566
+ assert len(r3) == 3
2567
+
2568
+ with pytest.raises(IndexError):
2569
+ r[xp.asarray([True, True])]