warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/tests/test_mat.py ADDED
@@ -0,0 +1,3515 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]
np_float_types = [np.float16, np.float32, np.float64]


def randvals(rng, shape, dtype):
    """Draw random test values of the NumPy scalar type ``dtype``.

    Floating-point types are sampled from a standard normal distribution.
    Integer types are sampled uniformly from small positive ranges:
    ``[1, 2]`` for 8-bit types and ``[1, 4]`` otherwise (presumably to keep
    sums/products of matrix entries from overflowing -- confirm with callers).

    Args:
        rng: A NumPy ``Generator`` (e.g. ``np.random.default_rng(seed)``).
        shape: Shape of the returned array.
        dtype: NumPy scalar type of the returned array.

    Returns:
        A NumPy array with the requested ``shape`` and ``dtype``.
    """
    if dtype in np_float_types:
        normal = rng.standard_normal(size=shape)
        return normal.astype(dtype)

    is_8bit = dtype in (np.int8, np.uint8, np.byte, np.ubyte)
    upper = 3 if is_8bit else 5  # exclusive upper bound for ``integers``
    return rng.integers(1, high=upper, size=shape, dtype=dtype)
34
+
35
+
36
# Kernels built by ``getkernel``, keyed by "<function name>_<suffix>".
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return a ``wp.Kernel`` for ``func``, constructing it at most once per key.

    The key is ``func.__name__ + "_" + suffix``; it is used both as the cache key
    and as the kernel's ``key``. Callers pass a per-dtype suffix (e.g.
    ``dtype.__name__``) so one Python function yields one kernel per dtype.
    If a kernel with the same key is already cached, it is returned even when
    ``func`` is a different function object.
    """
    cache_key = func.__name__ + "_" + suffix
    try:
        return kernel_cache[cache_key]
    except KeyError:
        pass

    built = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = built
    return built
44
+
45
+
46
def get_select_kernel(dtype):
    """Return a kernel that copies ``input[index]`` into ``out[0]``.

    Used after a kernel that writes many output components (e.g. in
    ``test_negation``) so that a single component can act as the scalar loss
    for ``tape.backward``. The kernel is built once per ``dtype`` via ``getkernel``.
    """

    def output_select_kernel_fn(input: wp.array(dtype=dtype), index: int, out: wp.array(dtype=dtype)):
        out[0] = input[index]

    dtype_suffix = dtype.__name__
    return getkernel(output_select_kernel_fn, suffix=dtype_suffix)
51
+
52
+
53
def test_shape_mismatch(test, device):
    """Matrices of different shapes must never compare equal.

    In Python scope a 3x3 and a 2x2 matrix simply compare unequal in both
    argument orders. Inside a kernel, ``wp.expect_neq`` on mismatched matrix
    types is expected to make ``wp.launch`` raise a ``RuntimeError`` whose
    message ends with "Can't test equality for objects with different types"
    (presumably raised while generating the kernel's code).
    """
    big, small = wp.mat33f(0.0), wp.mat22f(0.0)
    test.assertNotEqual(big, small)
    test.assertNotEqual(small, big)

    @wp.kernel(module="unique")
    def kernel():
        wp.expect_neq(wp.mat33f(0.0), wp.mat22f(0.0))
        wp.expect_neq(wp.mat22f(0.0), wp.mat33f(0.0))

    expected_error = r"Can't test equality for objects with different types$"
    with test.assertRaisesRegex(RuntimeError, expected_error):
        wp.launch(kernel, dim=1, inputs=[], device=device)
67
+
68
+
69
def test_py_arithmetic_ops(test, device, dtype):
    """Exercise Python-scope arithmetic on 3x3 matrices and 3-vectors.

    Covers unary ``+``/``-``, matrix +/- matrix, matrix-vector and
    vector-matrix products (via both ``*`` and ``@``), and scalar
    multiplication/division on either side. For integer ``dtype`` the expected
    values are round-tripped through the ctypes scalar type so they wrap the
    same way fixed-width integer arithmetic does.
    """
    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    wraps = wptype in wp._src.types.int_types

    def as_wptype(x):
        # Constructing the ctypes scalar truncates ``x`` to the type's bit width.
        return wptype._type_(x).value

    def make_vec(*args):
        if not wraps:
            return args
        return tuple(map(as_wptype, args))

    def make_mat(*rows):
        if not wraps:
            return rows
        return tuple(make_vec(*row) for row in rows)

    mat_cls = wp.mat((3, 3), wptype)
    vec_cls = wp.vec(3, wptype)
    fives = (5,) * 9

    # Mixed-sign matrix: row sums are (4, 5, 6), column sums are (10, 5, 0).
    m = mat_cls(((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
    test.assertSequenceEqual(+m, make_mat((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
    test.assertSequenceEqual(-m, make_mat((1, -2, -3), (-4, 5, -6), (-7, -8, 9)))
    test.assertSequenceEqual(m + mat_cls(fives), make_mat((4, 7, 8), (9, 0, 11), (12, 13, -4)))
    test.assertSequenceEqual(m - mat_cls(fives), make_mat((-6, -3, -2), (-1, -10, 1), (2, 3, -14)))
    test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(20, 25, 30))
    test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(20, 25, 30))
    test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(50, 25, 0))
    test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(50, 25, 0))

    # Even entries that all divide 5040: row sums are (12, 30, 48),
    # column sums are (24, 30, 36).
    m = mat_cls(((2, 4, 6), (8, 10, 12), (14, 16, 18)))
    test.assertSequenceEqual(m * wptype(2), make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
    test.assertSequenceEqual(wptype(2) * m, make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
    test.assertSequenceEqual(m / wptype(2), make_mat((1, 2, 3), (4, 5, 6), (7, 8, 9)))
    test.assertSequenceEqual(wptype(5040) / m, make_mat((2520, 1260, 840), (630, 504, 420), (360, 315, 280)))
    test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(60, 150, 240))
    test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(60, 150, 240))
    test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(120, 150, 180))
    test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(120, 150, 180))
108
+
109
+
110
def test_negation(test, device, dtype, register_kernels=False):
    """Check unary negation of 2x2, 3x3, 4x4 and 5x5 matrices and its adjoint.

    The kernel negates one matrix of each size and writes ``2 * (-m)[i, j]``
    for every component into a flat output array of 4 + 9 + 16 + 25 = 54
    entries. Forward values are compared with NumPy. For float dtypes, each
    output component is then back-propagated individually (selected through
    ``get_select_kernel``) and the gradient of the corresponding input matrix
    must be ``-2`` at that component and zero elsewhere.

    When ``register_kernels`` is true, the kernels are created (and cached)
    but nothing is launched.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp._src.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp._src.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp._src.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_negation(
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        mat2 = -m2[0]
        mat3 = -m3[0]
        mat4 = -m4[0]
        mat5 = -m5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * mat2[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * mat3[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * mat4[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * mat5[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_negation, suffix=dtype.__name__)

    if register_kernels:
        return

    sizes = (2, 3, 4, 5)
    # The generator is consumed in order, so ``rng`` is sampled 2x2 -> 5x5.
    m2, m3, m4, m5 = (
        wp.array(randvals(rng, [1, n, n], dtype), dtype=mat_type, requires_grad=True, device=device)
        for n, mat_type in zip(sizes, (mat22, mat33, mat44, mat55))
    )
    matrices = (m2, m3, m4, m5)
    outcomponents = wp.zeros(sum(n * n for n in sizes), dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)

    # Each matrix occupies a contiguous n*n slice of the flat output, row-major.
    flat = outcomponents.numpy()
    start = 0
    for n, arr in zip(sizes, matrices):
        stop = start + n * n
        assert_np_equal(flat[start:stop], -2 * arr.numpy().reshape(-1), tol=tol)
        start = stop

    if dtype not in np_float_types:
        return

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    component = 0
    for n, arr in zip(sizes, matrices):
        for i in range(n):
            for j in range(n):
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
                    wp.launch(
                        output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device
                    )
                tape.backward(loss=out)
                expected = np.zeros((n, n), dtype=dtype)
                expected[i, j] = -2
                assert_np_equal(tape.gradients[arr].numpy()[0], expected)
                tape.zero()
                component += 1
197
+
198
+
199
def test_matmul(test, device, dtype, register_kernels=False):
    """Check ``@`` for 2x3 @ 3x2, 3x2 @ 2x3 and 4x4 @ 4x4 matrices, plus adjoints.

    Forward results are compared with NumPy's batched ``@``. For float dtypes,
    the output adjoints are seeded with identity matrices and the input
    gradients are compared with the analytic result. With ``A = i23``,
    ``B = i32``, ``C = i44``:

    * ``o22 = A @ B`` and ``o33 = B @ A`` give ``dA = B^T + B^T = 2 B^T`` and
      ``dB = A^T + A^T = 2 A^T``;
    * ``o44 = C @ C`` gives ``dC = I @ C^T + C^T @ I = 2 C^T``.

    When ``register_kernels`` is true, the kernel is created (and cached) but
    nothing is launched.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-12,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp._src.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)
    mat23 = wp._src.types.matrix(shape=(2, 3), dtype=wptype)
    mat32 = wp._src.types.matrix(shape=(3, 2), dtype=wptype)
    mat44 = wp._src.types.matrix(shape=(4, 4), dtype=wptype)

    def check_mat_mul(
        i23: wp.array(dtype=mat23),
        i32: wp.array(dtype=mat32),
        i44: wp.array(dtype=mat44),
        o22: wp.array(dtype=mat22),
        o33: wp.array(dtype=mat33),
        o44: wp.array(dtype=mat44),
    ):
        i = wp.tid()
        o22[i] = i23[i] @ i32[i]
        o33[i] = i32[i] @ i23[i]
        o44[i] = i44[i] @ i44[i]

    kernel = getkernel(check_mat_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    test_adj = dtype in np_float_types

    i23 = wp.array(randvals(rng, [1, 2, 3], dtype), dtype=mat23, requires_grad=test_adj, device=device)
    i32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=test_adj, device=device)
    i44 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=test_adj, device=device)
    o22 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=test_adj, device=device)
    o33 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=test_adj, device=device)
    o44 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=test_adj, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[i23, i32, i44],
            outputs=[o22, o33, o44],
            device=device,
        )

    # NumPy's ``@`` broadcasts over the leading batch axis of the (1, r, c) arrays.
    assert_np_equal(o22.numpy(), i23.numpy() @ i32.numpy(), tol=tol)
    assert_np_equal(o33.numpy(), i32.numpy() @ i23.numpy(), tol=tol)
    assert_np_equal(o44.numpy(), i44.numpy() @ i44.numpy(), tol=tol)

    if test_adj:
        o22.grad.assign([np.eye(2)])
        o33.grad.assign([np.eye(3)])
        o44.grad.assign([np.eye(4)])

        tape.backward()

        def batched_transpose(a):
            # Transpose only the two matrix axes. ``a.T`` on a (1, r, c) array
            # reverses *all* axes, giving (c, r, 1), which only matched before
            # because the batch size is 1 and the comparison flattens.
            return np.swapaxes(a, -1, -2)

        assert_np_equal(i23.grad.numpy(), 2.0 * batched_transpose(i32.numpy()), tol=tol)
        assert_np_equal(i32.grad.numpy(), 2.0 * batched_transpose(i23.numpy()), tol=tol)
        assert_np_equal(i44.grad.numpy(), 2.0 * batched_transpose(i44.numpy()), tol=tol)
268
+
269
+
270
def test_subtraction(test, device, dtype, register_kernels=False):
    """Check elementwise matrix subtraction for square sizes 2 through 5.

    The kernel writes ``2 * (v - s)`` component by component into one flat array.
    Forward values are compared with NumPy. For float dtypes, each component is
    backpropagated on its own, and the gradients w.r.t. ``v`` and ``s`` must be the
    one-hot matrices with ``+2`` and ``-2`` respectively.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp._src.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp._src.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp._src.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_sub(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        v2result = v2[0] - s2[0]
        v3result = v3[0] - s3[0]
        v4result = v4[0] - s4[0]
        v5result = v5[0] - s5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * v2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * v3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * v4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * v5result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_sub, suffix=dtype.__name__)

    if register_kernels:
        return

    sizes = (2, 3, 4, 5)
    mat_types = (mat22, mat33, mat44, mat55)

    def random_matrices():
        # One random (1, n, n) array per size, drawn in increasing size order.
        return [
            wp.array(randvals(rng, [1, n, n], dtype), dtype=mtype, requires_grad=True, device=device)
            for n, mtype in zip(sizes, mat_types)
        ]

    s2, s3, s4, s5 = random_matrices()
    v2, v3, v4, v5 = random_matrices()
    subtrahends = (s2, s3, s4, s5)
    minuends = (v2, v3, v4, v5)
    kernel_inputs = [*subtrahends, *minuends]
    outcomponents = wp.zeros(sum(n * n for n in sizes), dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=kernel_inputs, outputs=[outcomponents], device=device)

    # Each size owns a contiguous n*n slice of the flat output.
    start = 0
    for n, minuend, subtrahend in zip(sizes, minuends, subtrahends):
        stop = start + n * n
        # The 5x5 slice is checked with a looser tolerance.
        slice_tol = 10 * tol if n == 5 else tol
        assert_np_equal(
            outcomponents.numpy()[start:stop],
            2 * (minuend.numpy() - subtrahend.numpy()).reshape(-1),
            tol=slice_tol,
        )
        start = stop

    if dtype not in np_float_types:
        return

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    component = 0
    for n, subtrahend, minuend in zip(sizes, subtrahends, minuends):
        for i in range(n):
            for j in range(n):
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=kernel_inputs, outputs=[outcomponents], device=device)
                    wp.launch(
                        output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device
                    )
                tape.backward(loss=out)
                expected_result = np.zeros((n, n), dtype=dtype)
                expected_result[i, j] = 2
                assert_np_equal(tape.gradients[minuend].numpy()[0], expected_result, tol=10 * tol)
                expected_result[i, j] = -2
                assert_np_equal(tape.gradients[subtrahend].numpy()[0], expected_result, tol=10 * tol)
                tape.zero()

                component += 1
389
+
390
+
391
def test_determinant(test, device, dtype, register_kernels=False):
    """Check ``wp.determinant`` for 2x2, 3x3 and 4x4 matrices.

    Forward values are compared with ``np.linalg.det`` (rounded for integer
    dtypes). For float dtypes, the tape gradients are compared with central
    finite differences. The comparison is relative because determinant
    magnitudes vary.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp._src.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp._src.types.matrix(shape=(4, 4), dtype=wptype)

    def check_mat_det(
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        det2: wp.array(dtype=wptype),
        det3: wp.array(dtype=wptype),
        det4: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        det2[0] = wptype(2) * wp.determinant(v2[0])
        det3[0] = wptype(2) * wp.determinant(v3[0])
        det4[0] = wptype(2) * wp.determinant(v4[0])

    kernel = getkernel(check_mat_det, suffix=dtype.__name__)
    if register_kernels:
        return

    v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    det2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    det3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    det4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[v2, v3, v4], outputs=[det2, det3, det4], device=device)

    if dtype in np_float_types:
        assert_np_equal(det2.numpy()[0], 2 * np.linalg.det(v2.numpy()[0].astype(np.float64)), tol=100 * tol)
        assert_np_equal(det3.numpy()[0], 2 * np.linalg.det(v3.numpy()[0].astype(np.float64)), tol=100 * tol)
        assert_np_equal(det4.numpy()[0], 2 * np.linalg.det(v4.numpy()[0].astype(np.float64)), tol=420 * tol)
    else:
        assert_np_equal(det2.numpy()[0], 2 * np.around(np.linalg.det(v2.numpy()[0])).astype(int))
        assert_np_equal(det3.numpy()[0], 2 * np.around(np.linalg.det(v3.numpy()[0])).astype(int))
        assert_np_equal(det4.numpy()[0], 2 * np.around(np.linalg.det(v4.numpy()[0])).astype(int))

    if dtype not in np_float_types:
        return

    # ``1.0 * ...`` takes a copy so the gradients survive ``tape.zero()``.
    tape.backward(loss=det2)
    v2grads = 1.0 * tape.gradients[v2].numpy()[0]
    tape.zero()

    tape.backward(loss=det3)
    v3grads = 1.0 * tape.gradients[v3].numpy()[0]
    tape.zero()

    tape.backward(loss=det4)
    v4grads = 1.0 * tape.gradients[v4].numpy()[0]
    tape.zero()

    # Finite differences are noisy, hence the large tolerances; float16 needs a
    # much larger step and tolerance still.
    dx = 0.01 if dtype == np.float16 else 0.0001
    fdtol = 2.0e-1 if dtype == np.float16 else 2.0e-3

    base_inputs = [v2, v3, v4]

    def check_finite_differences(arg_index, det_out, grads):
        """Compare ``grads`` with central differences of ``det_out`` w.r.t. kernel input ``arg_index``."""
        source = base_inputs[arg_index]
        n = grads.shape[0]
        for i in range(n):
            for j in range(n):
                # Copy: on CPU ``numpy()`` aliases the warp array's memory, so editing it
                # in place would also shift the original input for every later entry.
                perturbed = source.numpy().copy()
                values = []
                # First step evaluates at +dx, the second at -dx.
                for step in (dx, -2.0 * dx):
                    perturbed[0, i, j] += step
                    inputs = list(base_inputs)
                    inputs[arg_index] = wp.array(perturbed, dtype=source.dtype, requires_grad=True, device=device)
                    wp.launch(kernel, dim=1, inputs=inputs, outputs=[det2, det3, det4], device=device)
                    values.append(det_out.numpy()[0])
                dplus, dminus = values
                # Dividing both sides by dplus makes the tolerance relative to the determinant.
                assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), grads[i, j] / dplus, tol=fdtol)

    check_finite_differences(0, det2, v2grads)
    check_finite_differences(1, det3, v3grads)
    check_finite_differences(2, det4, v4grads)
528
+
529
+
530
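# NOTE(review): test_get_diag below is commented out and therefore unused; the reason it
# was disabled is not recorded here. Before re-enabling it, note that `mat55` is built with
# `wp._src.types.vector(shape=(5, 5), ...)` rather than `matrix(...)`, and that `randvals`
# is called without the `rng` argument the other tests pass.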
531
+ # def test_get_diag(test, device, dtype, register_kernels=False):
532
+ # tol = {
533
+ # np.float16: 1.0e-3,
534
+ # np.float32: 1.0e-6,
535
+ # np.float64: 1.0e-8,
536
+ # }.get(dtype, 0)
537
+ #
538
+ # wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
539
+ # mat55 = wp._src.types.vector(shape=(5, 5), dtype=wptype)
540
+ #
541
+ # output_select_kernel = get_select_kernel(wptype)
542
+ #
543
+ # def check_mat_diag(
544
+ # m55: wp.array(dtype=mat55),
545
+ # outcomponents: wp.array(dtype=wptype),
546
+ # ):
547
+ # # multiply outputs by 2 so we've got something to backpropagate:
548
+ # vec5result = wptype(2) * wp.get_diag(m55[0])
549
+ #
550
+ # idx = 0
551
+ # for i in range(5):
552
+ # outcomponents[idx] = vec5result[i]
553
+ # idx = idx + 1
554
+ #
555
+ # kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
556
+ #
557
+ # if register_kernels:
558
+ # return
559
+ #
560
+ # m55 = wp.array(randvals((1, 5, 5), dtype), dtype=mat55, requires_grad=True, device=device)
561
+ # outcomponents = wp.zeros(5, dtype=wptype, requires_grad=True, device=device)
562
+ # out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
563
+ #
564
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
565
+ #
566
+ # assert_np_equal(outcomponents.numpy(), 2 * np.diag(m55.numpy()[0]), tol=tol)
567
+ #
568
+ # if dtype in np_float_types:
569
+ # idx = 0
570
+ # for i in range(5):
571
+ # tape = wp.Tape()
572
+ # with tape:
573
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
574
+ # wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
575
+ # tape.backward(loss=out)
576
+ # expectedresult = np.zeros((5, 5), dtype=dtype)
577
+ # expectedresult[i, i] = 2
578
+ # assert_np_equal(tape.gradients[m55].numpy()[0], expectedresult, tol=10 * tol)
579
+ # tape.zero()
580
+ #
581
+ # idx = idx + 1
582
+
583
+
584
def test_inverse(test, device, dtype, register_kernels=False):
    """Check ``wp.inverse`` for 2x2, 3x3 and 4x4 matrices.

    Forward values are compared with ``np.linalg.inv``. For float dtypes, the
    gradient of every output component is compared with the analytic derivative
    of the inverse. The 2x2 case is also cross-checked against explicit
    closed-form expressions for the inverse and (for float dtypes) its gradient.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp._src.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp._src.types.matrix(shape=(4, 4), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_inverse(
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        outcomponents: wp.array(dtype=wptype),
    ):
        m2result = wp.inverse(m2[0])
        m3result = wp.inverse(m3[0])
        m4result = wp.inverse(m4[0])

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * m2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * m3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * m4result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_inverse, suffix=dtype.__name__)

    if register_kernels:
        return

    # Adding 0.2 * I to the random matrices is meant to keep them away from singular.
    m2 = wp.array(
        2 * (randvals(rng, [1, 2, 2], dtype) + 0.2 * np.eye(2)), dtype=mat22, requires_grad=True, device=device
    )
    m3 = wp.array(
        2 * (randvals(rng, [1, 3, 3], dtype) + 0.2 * np.eye(3)), dtype=mat33, requires_grad=True, device=device
    )
    m4 = wp.array(
        2 * (randvals(rng, [1, 4, 4], dtype) + 0.2 * np.eye(4)), dtype=mat44, requires_grad=True, device=device
    )

    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4, dtype=wptype, requires_grad=True, device=device)
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)

    assert_np_equal(outcomponents.numpy()[:4], 2 * np.linalg.inv(m2.numpy()[0].astype(np.float64)), tol=tol)
    assert_np_equal(outcomponents.numpy()[4:13], 2 * np.linalg.inv(m3.numpy()[0].astype(np.float64)), tol=5 * tol)
    assert_np_equal(outcomponents.numpy()[13:], 2 * np.linalg.inv(m4.numpy()[0].astype(np.float64)), tol=5 * tol)

    def record_component(component):
        """Record the kernel plus selection of flat output ``component`` into ``out`` on a new tape."""
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device)
        return tape

    if dtype in np_float_types:
        # d(A^-1) = -A^-1 dA A^-1, so the gradient of component (i, j) of 2 * A^-1 is
        # -(A^-1 D A^-1)^T with D = 2 e_j e_i^T.
        idx = 0
        for dim, mat in [(2, m2), (3, m3), (4, m4)]:
            minv = np.linalg.inv(mat.numpy()[0].astype(np.float64))
            for i in range(dim):
                for j in range(dim):
                    tape = record_component(idx)
                    tape.backward(loss=out)
                    d = np.zeros((dim, dim))
                    d[j, i] = 2
                    assert_np_equal(
                        tape.gradients[mat].numpy()[0], -np.matmul(minv, np.matmul(d, minv)).T, tol=10 * tol
                    )
                    tape.zero()

                    idx = idx + 1

    # Cross-check the 2x2 case with explicit formulae. With m = [[a, b], [c, d]] and
    # det = a*d - b*c, inv(m) = [[d, -b], [-c, a]] / det.
    m = m2.numpy()[0]

    det = m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]
    expected = 2 * np.array([[m[1, 1], -m[0, 1]], [-m[1, 0], m[0, 0]]], dtype=dtype) / det
    assert_np_equal(expected, outcomponents.numpy()[:4], tol=tol)

    # For each flat component of 2 * inv(m): its closed-form value, and the numerators
    # (over det**2) of its partial derivatives w.r.t. each entry (row, col) of m.
    closed_forms = [
        (
            2 * m[1, 1] / det,
            {
                (0, 0): -2 * m[1, 1] * m[1, 1],
                (1, 0): 2 * m[1, 1] * m[0, 1],
                (1, 1): -2 * m[0, 1] * m[1, 0],
                (0, 1): 2 * m[1, 1] * m[1, 0],
            },
        ),
        (
            -2 * m[0, 1] / det,
            {
                (0, 0): 2 * m[0, 1] * m[1, 1],
                (1, 0): -2 * m[0, 1] * m[0, 1],
                (1, 1): 2 * m[0, 0] * m[0, 1],
                (0, 1): -2 * m[1, 1] * m[0, 0],
            },
        ),
        (
            -2 * m[1, 0] / det,
            {
                (0, 0): 2 * m[1, 1] * m[1, 0],
                (1, 0): -2 * m[0, 0] * m[1, 1],
                (1, 1): 2 * m[0, 0] * m[1, 0],
                (0, 1): -2 * m[1, 0] * m[1, 0],
            },
        ),
        (
            2 * m[0, 0] / det,
            {
                (0, 0): -2 * m[0, 1] * m[1, 0],
                (1, 0): 2 * m[0, 0] * m[0, 1],
                (0, 1): 2 * m[0, 0] * m[1, 0],
                (1, 1): -2 * m[0, 0] * m[0, 0],
            },
        ),
    ]

    for component, (expected_value, grad_numerators) in enumerate(closed_forms):
        assert_np_equal(expected_value, outcomponents.numpy()[component], tol=tol)

        tape = record_component(component)
        if dtype in np_float_types:
            tape.backward(loss=out)
            g = tape.gradients[m2].numpy()[0]
            for (row, col), numerator in grad_numerators.items():
                assert_np_equal(numerator / det**2, g[row, col], tol=tol)
            tape.zero()
752
+
753
+
754
def test_svd(test, device, dtype, register_kernels=False):
    """Check ``wp.svd3`` on a 3x3 matrix.

    Verifies that ``U @ diag(sigma) @ V^T`` reconstructs the input. Except for
    float16, the tape gradient of every output component of U, sigma and V is
    compared with central finite differences.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-12,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(length=3, dtype=wptype)
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)

    def check_mat_svd(
        m3: wp.array(dtype=mat33),
        Uout: wp.array(dtype=mat33),
        sigmaout: wp.array(dtype=vec3),
        Vout: wp.array(dtype=mat33),
        outcomponents: wp.array(dtype=wptype),
    ):
        U = mat33()
        sigma = vec3()
        V = mat33()

        wp.svd3(m3[0], U, sigma, V)

        Uout[0] = U
        sigmaout[0] = sigma
        Vout[0] = V

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * U[i, j]
                idx = idx + 1

        for i in range(3):
            outcomponents[idx] = wptype(2) * sigma[i]
            idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * V[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_svd, suffix=dtype.__name__)

    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    m3 = wp.array(randvals(rng, [1, 3, 3], dtype) + np.eye(3), dtype=mat33, requires_grad=True, device=device)

    outcomponents = wp.zeros(2 * 3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
    Uout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
    sigmaout = wp.zeros(1, dtype=vec3, requires_grad=True, device=device)
    Vout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m3], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)

    Uout_np = Uout.numpy()[0].astype(np.float64)
    sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
    Vout_np = Vout.numpy()[0].astype(np.float64)

    assert_np_equal(
        np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m3.numpy()[0].astype(np.float64), tol=30 * tol
    )

    if dtype == np.float16:
        # float16 rounding errors are too large for a meaningful gradient check.
        return

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    dx = 0.0001
    fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2

    def eval_component(ii, jj, delta, component):
        """Return flat output ``component`` for ``m3`` with entry ``[ii, jj]`` shifted by ``delta``."""
        m3test = 1.0 * m3.numpy()  # copy, so m3 itself is never modified
        m3test[0, ii, jj] += delta
        wp.launch(
            kernel,
            dim=1,
            inputs=[wp.array(m3test, dtype=mat33, device=device)],
            outputs=[Uout, sigmaout, Vout, outcomponents],
            device=device,
        )
        wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device)
        return out.numpy()[0]

    # check gradients of every U, sigma and V component against finite differences:
    for idx in range(3 * 3 + 3 + 3 * 3):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[m3], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
        tape.backward(out)
        m3grads = 1.0 * tape.gradients[m3].numpy()[0]

        tape.zero()

        for ii in range(3):
            for jj in range(3):
                plusval = eval_component(ii, jj, dx, idx)
                minusval = eval_component(ii, jj, -dx, idx)
                assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
871
+
872
+
873
def test_svd_2D(test, device, dtype, register_kernels=False):
    """Check ``wp.svd2`` on a batch of 2x2 matrices, including hand-built edge cases.

    Verifies that U and V are orthonormal and that ``U @ diag(sigma) @ V^T``
    reconstructs every input. Except for float16, the tape gradient of every
    output component for the first matrix is compared with central finite
    differences.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-12,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp._src.types.vector(length=2, dtype=wptype)
    mat22 = wp._src.types.matrix(shape=(2, 2), dtype=wptype)

    def check_mat_svd2(
        m2: wp.array(dtype=mat22),
        Uout: wp.array(dtype=mat22),
        sigmaout: wp.array(dtype=vec2),
        Vout: wp.array(dtype=mat22),
        outcomponents: wp.array(dtype=wptype),
    ):
        tid = wp.tid()

        U = mat22()
        sigma = vec2()
        V = mat22()

        # Factor m2[tid] into U, sigma, V (checked on the host as U @ diag(sigma) @ V^T).
        wp.svd2(m2[tid], U, sigma, V)

        Uout[tid] = U
        sigmaout[tid] = sigma
        Vout[tid] = V

        # backprop test only for first input
        if tid > 0:
            return

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * U[i, j]
                idx = idx + 1

        for i in range(2):
            outcomponents[idx] = wptype(2) * sigma[i]
            idx = idx + 1

        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * V[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_svd2, suffix=dtype.__name__)

    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    mats = np.concatenate(
        (
            randvals(rng, [24, 2, 2], dtype) + np.eye(2),
            # rng unlikely to hit edge cases, build them manually
            [
                np.zeros((2, 2)),
                np.eye(2),
                5.0 * np.eye(2),
                np.array([[1.0, 0.0], [0.0, 0.0]]),
                np.array([[0.0, 0.0], [0.0, 2.0]]),
                np.array([[1.0, 1.0], [-1.0, -1.0]]),
                np.array([[3.0, 0.0], [4.0, 5.0]]),
                np.eye(2) + tol * np.array([[1.0, 1.0], [-1.0, -1.0]]),
            ],
        ),
        axis=0,
    )
    num_mats = len(mats)
    m2 = wp.array(mats, dtype=mat22, requires_grad=True, device=device)

    outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
    Uout = wp.zeros(num_mats, dtype=mat22, requires_grad=True, device=device)
    sigmaout = wp.zeros(num_mats, dtype=vec2, requires_grad=True, device=device)
    Vout = wp.zeros(num_mats, dtype=mat22, requires_grad=True, device=device)

    wp.launch(kernel, dim=num_mats, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)

    Uout_np = Uout.numpy().astype(np.float64)
    sigmaout_np = sigmaout.numpy().astype(np.float64)
    Vout_np = Vout.numpy().astype(np.float64)

    # Scaling the rows of V^T by sigma is diag(sigma) @ V^T, batched.
    USVt_np = Uout_np @ (sigmaout_np[..., None] * np.transpose(Vout_np, axes=(0, 2, 1)))

    identities = np.broadcast_to(np.eye(2), shape=(num_mats, 2, 2))
    assert_np_equal(Uout_np @ np.transpose(Uout_np, axes=(0, 2, 1)), identities, tol=30 * tol)
    assert_np_equal(Vout_np @ np.transpose(Vout_np, axes=(0, 2, 1)), identities, tol=30 * tol)
    assert_np_equal(USVt_np, m2.numpy().astype(np.float64), tol=30 * tol)

    if dtype == np.float16:
        # Skip gradient check for float16 due to rounding errors
        return

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    dx = 0.001
    fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2

    def eval_component(ii, jj, delta, component):
        """Return flat output ``component`` with entry ``[ii, jj]`` of the first matrix shifted by ``delta``."""
        m2test = 1.0 * m2.numpy()  # copy, so m2 itself is never modified
        m2test[0, ii, jj] += delta
        wp.launch(
            kernel,
            dim=1,
            inputs=[wp.array(m2test, dtype=mat22, device=device)],
            outputs=[Uout, sigmaout, Vout, outcomponents],
            device=device,
        )
        wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device)
        return out.numpy()[0]

    # Check gradients of every U, sigma and V component (first matrix only):
    for idx in range(len(outcomponents)):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
        tape.backward(out)
        m2grads = 1.0 * tape.gradients[m2].numpy()[0]

        tape.zero()

        for ii in range(2):
            for jj in range(2):
                plusval = eval_component(ii, jj, dx, idx)
                minusval = eval_component(ii, jj, -dx, idx)
                assert_np_equal((plusval - minusval) / (2 * dx), m2grads[ii, jj], tol=fdtol)
1019
+
1020
+
1021
def test_qr(test, device, dtype, register_kernels=False):
    """Check ``wp.qr3`` on a 3x3 matrix.

    Verifies that Q is orthogonal, R is upper triangular and ``Q @ R``
    reconstructs the input. Except for float16, the tape gradient of every
    component of Q and R is compared with central finite differences.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 2.5e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-12,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)

    def check_mat_qr(
        m3: wp.array(dtype=mat33),
        Qout: wp.array(dtype=mat33),
        Rout: wp.array(dtype=mat33),
        outcomponents: wp.array(dtype=wptype),
    ):
        Q = mat33()
        R = mat33()

        wp.qr3(m3[0], Q, R)

        Qout[0] = Q
        Rout[0] = R

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * Q[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * R[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_qr, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    m3 = wp.array(0.5 * (randvals(rng, [1, 3, 3], dtype) + np.eye(3)), dtype=mat33, requires_grad=True, device=device)

    outcomponents = wp.zeros(2 * 3 * 3, dtype=wptype, requires_grad=True, device=device)
    Qout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
    Rout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, Rout, outcomponents], device=device)

    Qout_np = Qout.numpy()[0].astype(np.float64)
    Rout_np = Rout.numpy()[0].astype(np.float64)

    # Q must be orthogonal and R upper triangular:
    assert_np_equal(np.matmul(Qout_np.T, Qout_np), np.eye(3, dtype=np.float64), tol=tol)
    assert_np_equal(Rout_np[1, [0]], np.zeros(1, dtype=np.float64), tol=tol)
    assert_np_equal(Rout_np[2, [0, 1]], np.zeros(2, dtype=np.float64), tol=tol)

    # ... and together they must reproduce the input:
    assert_np_equal(np.matmul(Qout_np, Rout_np), m3.numpy()[0].astype(np.float64), tol=30 * tol)

    if dtype == np.float16:
        # float16 rounding errors are too large for a meaningful gradient check.
        return

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    dx = 0.0001
    fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2

    def eval_component(ii, jj, delta, component):
        """Return flat output ``component`` for ``m3`` with entry ``[ii, jj]`` shifted by ``delta``."""
        m3test = 1.0 * m3.numpy()  # copy, so m3 itself is never modified
        m3test[0, ii, jj] += delta
        wp.launch(
            kernel,
            dim=1,
            inputs=[wp.array(m3test, dtype=mat33, device=device)],
            outputs=[Qout, Rout, outcomponents],
            device=device,
        )
        wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device)
        return out.numpy()[0]

    # check gradients of every Q and R component against finite differences:
    for idx in range(len(outcomponents)):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, Rout, outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
        tape.backward(out)
        m3grads = 1.0 * tape.gradients[m3].numpy()[0]

        tape.zero()

        for ii in range(3):
            for jj in range(3):
                plusval = eval_component(ii, jj, dx, idx)
                minusval = eval_component(ii, jj, -dx, idx)
                assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
1131
+
1132
+
1133
def test_eig(test, device, dtype, register_kernels=False):
    """Check ``wp.eig3`` on a symmetrized random 3x3 matrix, plus its gradients.

    The kernel decomposes ``A + A^T`` into ``Q`` and ``d`` and writes
    ``2 * Q`` (row-major) followed by ``2 * d`` into ``outcomponents``. The
    test asserts that ``Q^T Q == I`` and ``Q diag(d) Q^T == A + A^T``. For
    non-float16 dtypes, the tape gradient of every output component w.r.t.
    ``A`` is compared with a central finite-difference estimate.

    Args:
        test: Test-case instance (part of the shared signature; unused here).
        device: Warp device to run on.
        dtype: NumPy scalar type used to pick the Warp scalar type.
        register_kernels: If True, only build the kernels and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 4.0e-2,
        np.float32: 1.0e-5,
        np.float64: 1.0e-5,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(length=3, dtype=wptype)
    mat33 = wp._src.types.matrix(shape=(3, 3), dtype=wptype)

    def check_mat_eig(
        m3: wp.array(dtype=mat33),
        Qout: wp.array(dtype=mat33),
        dout: wp.array(dtype=vec3),
        outcomponents: wp.array(dtype=wptype),
    ):
        Q = mat33()
        d = vec3()

        wp.eig3(m3[0] + wp.transpose(m3[0]), Q, d)

        Qout[0] = Q
        dout[0] = d

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * Q[i, j]
                idx = idx + 1

        for i in range(3):
            outcomponents[idx] = wptype(2) * d[i]
            idx = idx + 1

    kernel = getkernel(check_mat_eig, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    m3_np = randvals(rng, [1, 3, 3], dtype) + np.eye(3, dtype=dtype)
    m3 = wp.array(m3_np, dtype=mat33, requires_grad=True, device=device)

    outcomponents = wp.zeros(3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
    Qout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
    dout = wp.zeros(1, dtype=vec3, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, dout, outcomponents], device=device)

    Qout_np = Qout.numpy()[0].astype(np.float64)
    dout_np = dout.numpy()[0].astype(np.float64)
    Dout_np = np.diag(dout_np)

    # check Q is orthogonal:
    assert_np_equal(np.matmul(Qout_np.T, Qout_np), np.eye(3), tol=tol)

    # check Q contains eigenvectors:
    assert_np_equal(np.matmul(Qout_np, np.matmul(Dout_np, Qout_np.T)), (m3_np[0] + m3_np[0].transpose()), tol=tol)

    if dtype == np.float16:
        # float16 rounding error is too large for a meaningful gradient check.
        return

    # check gradients:
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    # Finite-difference step and tolerance do not depend on the component.
    dx = 0.0001
    fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2

    def eval_component(m3_values, component):
        # Run the kernel (outside any tape) on ``m3_values`` and return the
        # selected output component.
        wp.launch(
            kernel,
            dim=1,
            inputs=[wp.array(m3_values, dtype=mat33, device=device)],
            outputs=[Qout, dout, outcomponents],
            device=device,
        )
        wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, component], outputs=[out], device=device)
        return out.numpy()[0]

    for idx in range(len(outcomponents)):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, dout, outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
        tape.backward(out)
        m3grads = 1.0 * tape.gradients[m3].numpy()[0]

        tape.zero()

        # Central differences on each entry of the input matrix.
        for ii in range(3):
            for jj in range(3):
                m3test = 1.0 * m3.numpy()
                m3test[0, ii, jj] += dx
                plusval = eval_component(m3test, idx)

                m3test = 1.0 * m3.numpy()
                m3test[0, ii, jj] -= dx
                minusval = eval_component(m3test, idx)

                assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
1243
+
1244
+
1245
def test_skew(test, device, dtype, register_kernels=False):
    """Check that ``wp.skew`` builds the cross-product matrix, plus its gradients.

    The kernel writes ``2 * skew(v)`` (row-major) into ``outcomponents``.

    The test asserts three things:

    * ``skew(v) @ e_k == v x e_k`` for every basis vector ``e_k``.
    * The result equals the explicit cross-product matrix.
    * For float dtypes, each entry's gradient w.r.t. ``v`` matches the
      closed form.

    Args:
        test: Test-case instance (part of the shared signature; unused here).
        device: Warp device to run on.
        dtype: NumPy scalar type used to pick the Warp scalar type.
        register_kernels: If True, only build the kernels and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(length=3, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_skew(
        v3: wp.array(dtype=vec3),
        outcomponents: wp.array(dtype=wptype),
    ):
        m3result = wp.skew(v3[0])

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * m3result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_skew, suffix=dtype.__name__)

    if register_kernels:
        return

    v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)

    outcomponents = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[v3], outputs=[outcomponents], device=device)

    # make sure it gives you a cross product matrix: skew(v) @ e_k == v x e_k
    crossprodmat = outcomponents.numpy().reshape(3, 3)
    v3_np = v3.numpy()[0]
    for basis in (np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])):
        assert_np_equal(
            np.matmul(crossprodmat, basis).reshape(-1),
            2 * np.cross(v3_np, basis),
            tol=tol,
        )

    # check it another way:
    x0, x1, x2 = v3_np
    crossprodmat_expected = np.array(
        [
            [0, -x2, x1],
            [x2, 0, -x0],
            [-x1, x0, 0],
        ],
        dtype=dtype,
    )
    assert_np_equal(crossprodmat, 2 * crossprodmat_expected, tol=tol)

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        # Gradient of 2 * skew(v)[i, j] w.r.t. v for each off-diagonal entry.
        # Diagonal entries are constant zero, so their gradient is zero.
        expected_grads = {
            (0, 1): [0, 0, -2],
            (1, 0): [0, 0, 2],
            (0, 2): [0, 2, 0],
            (2, 0): [0, -2, 0],
            (1, 2): [-2, 0, 0],
            (2, 1): [2, 0, 0],
        }

        for i in range(3):
            for j in range(3):
                idx = 3 * i + j  # row-major position of (i, j) in outcomponents
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=[v3], outputs=[outcomponents], device=device)
                    wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
                tape.backward(loss=out)

                if (i, j) in expected_grads:
                    expected = np.array(expected_grads[(i, j)])
                else:
                    expected = np.zeros(3)
                assert_np_equal(tape.gradients[v3].numpy()[0], expected)
                tape.zero()
1344
+
1345
+
1346
def test_transform_point(test, device, dtype, register_kernels=False):
    """Compare ``wp.transform_point`` with a NumPy homogeneous product.

    ``transform_point(M, p)`` must equal ``(M @ [p, 1])[:3]``. For float
    dtypes, the gradients of each output component w.r.t. ``p`` and ``M`` are
    checked against their closed forms.

    Args:
        test: Test-case instance (part of the shared signature; unused here).
        device: Warp device to run on.
        dtype: NumPy scalar type used to pick the Warp scalar type.
        register_kernels: If True, only build the kernels and return.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(length=3, dtype=wptype)
    mat44 = wp._src.types.matrix(shape=(4, 4), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_transform_point(
        v3: wp.array(dtype=vec3),
        m4: wp.array(dtype=mat44),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        presult = wptype(2) * wp.transform_point(m4[0], v3[0])

        outcomponents[0] = presult[0]
        outcomponents[1] = presult[1]
        outcomponents[2] = presult[2]

    kernel = getkernel(check_mat_transform_point, suffix=dtype.__name__)

    if register_kernels:
        return

    v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)

    outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)

    # A point is promoted with w = 1, so the last column of m4 (translation) applies.
    homogeneous = np.ones(4, dtype=dtype)
    homogeneous[:3] = v3.numpy()[0]
    reference = 2 * np.matmul(m4.numpy()[0], homogeneous)[:3]
    assert_np_equal(outcomponents.numpy(), reference, tol=10 * tol)

    if dtype not in np_float_types:
        return

    for row in range(3):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, row], outputs=[out], device=device)
        tape.backward(loss=out)

        # d(out)/d(p): twice the selected row of m4's upper-left 3x3 block.
        assert_np_equal(2 * m4.numpy()[0, row, :3], tape.gradients[v3].numpy(), tol=tol)

        # d(out)/d(M): non-zero only on the selected row, equal to 2 * [p, 1].
        grad_m4 = np.zeros((4, 4), dtype=dtype)
        grad_m4[row, :3] = 2 * v3.numpy()
        grad_m4[row, 3] = 2
        assert_np_equal(tape.gradients[m4].numpy(), grad_m4, tol=tol)

        tape.zero()
1405
+
1406
+
1407
def test_transform_vector(test, device, dtype, register_kernels=False):
    """Compare ``wp.transform_vector`` with a NumPy homogeneous product.

    ``transform_vector(M, v)`` must equal ``(M @ [v, 0])[:3]``. For float
    dtypes, the gradients of each output component w.r.t. ``v`` and ``M`` are
    checked against their closed forms.

    Args:
        test: Test-case instance (part of the shared signature; unused here).
        device: Warp device to run on.
        dtype: NumPy scalar type used to pick the Warp scalar type.
        register_kernels: If True, only build the kernels and return.
    """
    rng = np.random.default_rng(123)

    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp._src.types.vector(length=3, dtype=wptype)
    mat44 = wp._src.types.matrix(shape=(4, 4), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_transform_vector(
        v3: wp.array(dtype=vec3),
        m4: wp.array(dtype=mat44),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        presult = wptype(2) * wp.transform_vector(m4[0], v3[0])

        outcomponents[0] = presult[0]
        outcomponents[1] = presult[1]
        outcomponents[2] = presult[2]

    kernel = getkernel(check_mat_transform_vector, suffix=dtype.__name__)

    if register_kernels:
        return

    v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)

    outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)

    # A direction is promoted with w = 0, so the last column of m4 is ignored.
    homogeneous = np.zeros(4, dtype=dtype)
    homogeneous[:3] = v3.numpy()[0]
    reference = 2 * np.matmul(m4.numpy()[0], homogeneous)[:3]
    assert_np_equal(outcomponents.numpy(), reference, tol=10 * tol)

    if dtype not in np_float_types:
        return

    for row in range(3):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, row], outputs=[out], device=device)
        tape.backward(loss=out)

        # d(out)/d(v): twice the selected row of m4's upper-left 3x3 block.
        assert_np_equal(2 * m4.numpy()[0, row, :3], tape.gradients[v3].numpy(), tol=tol)

        # d(out)/d(M): non-zero only in the first three columns of the selected row.
        grad_m4 = np.zeros((4, 4), dtype=dtype)
        grad_m4[row, :3] = 2 * v3.numpy()
        assert_np_equal(tape.gradients[m4].numpy(), grad_m4, tol=tol)

        tape.zero()
1465
+
1466
+
1467
# Kernel test: builds a 10x3 matrix in-kernel via element writes (both the
# ``m[i, j]`` and ``m[i][j]`` forms) and whole-row assignment, then checks the
# result equals ``expected``.
@wp.kernel
def test_matrix_mutation(expected: wp._src.types.matrix(shape=(10, 3), dtype=float)):
    m = wp.matrix(shape=(10, 3), dtype=float)

    # test direct element indexing
    m[0, 0] = 1.0
    m[0][1] = 2.0
    m[0][2] = 3.0

    # test setting rows
    # Each row is the previous row plus (1, 2, 3), so row i == (i + 1) * (1, 2, 3).
    for i in range(1, 10):
        m[i] = m[i - 1] + wp.vec3(1.0, 2.0, 3.0)

    wp.expect_eq(m, expected)
1481
+
1482
+
1483
# 2x3 float16 matrix type; used by matrix_len_kernel / test_matrix_len to check
# that len() of a non-square matrix returns its row count (2).
Mat23 = wp.mat((2, 3), dtype=wp.float16)
1484
+
1485
+
1486
# Checks ``len()`` of matrix kernel arguments and of a local matrix, both as a
# plain call and wrapped in ``wp.static``. Each length is written to out[0..4]
# so test_matrix_len can also verify it on the host.
# ``length`` is never read; those assignments only exercise ``len()`` in an
# assignment position.
# NOTE(review): m3 is generic ((Any, Any)); the expected length 4 relies on the
# caller passing a 4x4 matrix (test_matrix_len passes wp.mat44).
@wp.kernel(module="unique")
def matrix_len_kernel(
    m1: wp.mat22, m2: wp.mat((3, 3), float), m3: wp.mat((Any, Any), float), m4: Mat23, out: wp.array(dtype=int)
):
    length = wp.static(len(m1))
    wp.expect_eq(len(m1), 2)
    out[0] = len(m1)

    length = len(m2)
    wp.expect_eq(wp.static(len(m2)), 3)
    out[1] = len(m2)

    length = len(m3)
    wp.expect_eq(len(m3), 4)
    out[2] = wp.static(len(m3))

    # Mat23 is 2x3: len() counts rows.
    length = wp.static(len(m4))
    wp.expect_eq(wp.static(len(m4)), 2)
    out[3] = wp.static(len(m4))

    # len() of a matrix constructed inside the kernel.
    foo = wp.mat22()
    length = len(foo)
    wp.expect_eq(len(foo), 2)
    out[4] = len(foo)
1510
+
1511
+
1512
def test_matrix_len(test, device):
    """``len()`` on matrices returns the row count, both in a kernel and on the host."""
    matrices = (wp.mat22(), wp.mat33(), wp.mat44(), Mat23())
    expected_rows = (2, 3, 4, 2)

    out = wp.empty(5, dtype=int, device=device)
    wp.launch(matrix_len_kernel, dim=(1,), inputs=matrices, outputs=(out,), device=device)

    # The fifth slot comes from a mat22 constructed inside the kernel.
    for actual, expected in zip(out.numpy(), expected_rows + (2,)):
        test.assertEqual(actual, expected)

    for mat, expected in zip(matrices, expected_rows):
        test.assertEqual(len(mat), expected)
1530
+
1531
+
1532
# Reads individual elements of a mat22 held in a local variable and writes their
# weighted sum (weights 1, 2, 3, 4 in row-major order) to y[tid].
@wp.kernel
def mat_extract_element(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=float)):
    tid = wp.tid()

    a = x[tid]
    b = a[0, 0] + 2.0 * a[0, 1] + 3.0 * a[1, 0] + 4.0 * a[1, 1]
    y[tid] = b
1539
+
1540
+
1541
# Reads whole rows of a mat22 held in a local variable and writes
# row0 + 2 * row1 to y[tid].
@wp.kernel
def mat_extract_row(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.vec2)):
    tid = wp.tid()

    a = x[tid]
    b = a[0] + 2.0 * a[1]
    y[tid] = b
1548
+
1549
+
1550
def test_mat_extract(test, device):
    """Reading elements / rows of a local mat22 propagates gradients to the input.

    The input is all ones, so each output is the sum of the kernel's weights,
    and ``x.grad`` reproduces the weights themselves.
    """

    def launch_and_backprop(kernel, out_dtype):
        # Run ``kernel`` on a ones-filled mat22 under a tape, seed the output
        # adjoint with ones, and backpropagate. Returns (input, output).
        src = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
        dst = wp.zeros(1, dtype=out_dtype, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, 1, inputs=[src], outputs=[dst], device=device)

        dst.grad = wp.ones_like(dst)
        tape.backward()
        return src, dst

    # matrix element
    x, y = launch_and_backprop(mat_extract_element, float)
    assert_np_equal(y.numpy(), np.array([10.0], dtype=float))
    assert_np_equal(x.grad.numpy(), np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=float))

    # matrix row
    x, y = launch_and_backprop(mat_extract_row, wp.vec2)
    assert_np_equal(y.numpy(), np.array([[3.0, 3.0]], dtype=float))
    assert_np_equal(x.grad.numpy(), np.array([[[1.0, 1.0], [2.0, 2.0]]], dtype=float))
1578
+
1579
+
1580
# Builds a local mat22 by assigning each element a multiple (1, 2, 3, 4 in
# row-major order) of x[i], then stores it in y[i].
@wp.kernel
def mat_assign_element(x: wp.array(dtype=float), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    a = wp.mat22()
    a[0, 0] = 1.0 * x[i]
    a[0, 1] = 2.0 * x[i]
    a[1, 0] = 3.0 * x[i]
    a[1, 1] = 4.0 * x[i]

    y[i] = a
1591
+
1592
+
1593
# Builds a local mat22 by assigning whole rows (x[i] and 2 * x[i]), then
# stores it in y[i].
@wp.kernel
def mat_assign_row(x: wp.array(dtype=wp.vec2), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    a = wp.mat22()
    a[0] = 1.0 * x[i]
    a[1] = 2.0 * x[i]

    y[i] = a
1602
+
1603
+
1604
def test_mat_assign(test, device):
    """Writing elements / rows into a local mat22 backpropagates to the input.

    Inputs are all ones and the output adjoint is all ones, so ``x.grad`` is
    the sum of the weights the kernel applied to each input component.
    """

    def launch_and_backprop(kernel, in_dtype):
        # Run ``kernel`` on a ones-filled input under a tape, seed the mat22
        # output adjoint with ones, and backpropagate. Returns (input, output).
        src = wp.ones(1, dtype=in_dtype, requires_grad=True, device=device)
        dst = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, 1, inputs=[src], outputs=[dst], device=device)

        dst.grad = wp.ones_like(dst)
        tape.backward()
        return src, dst

    # matrix element
    x, y = launch_and_backprop(mat_assign_element, float)
    assert_np_equal(y.numpy(), np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=float))
    assert_np_equal(x.grad.numpy(), np.array([10.0], dtype=float))

    # matrix row
    x, y = launch_and_backprop(mat_assign_row, wp.vec2)
    assert_np_equal(y.numpy(), np.array([[[1.0, 1.0], [2.0, 2.0]]], dtype=float))
    assert_np_equal(x.grad.numpy(), np.array([[3.0, 3.0]], dtype=float))
1632
+
1633
+
1634
# Reads matrix elements directly from a 2D array element (``x[i, j][r, c]``,
# with no local copy) and writes their weighted sum (weights 1, 2, 3, 4 in
# row-major order) to y[i, j].
@wp.kernel
def mat_array_extract_element(x: wp.array2d(dtype=wp.mat22), y: wp.array2d(dtype=float)):
    i, j = wp.tid()
    a = x[i, j][0, 0]
    b = x[i, j][0, 1]
    c = x[i, j][1, 0]
    d = x[i, j][1, 1]
    y[i, j] = 1.0 * a + 2.0 * b + 3.0 * c + 4.0 * d
1642
+
1643
+
1644
# Reads whole rows directly from a 2D array element (``x[i, j][r]``) and writes
# row0 + 2 * row1 to y[i, j].
@wp.kernel
def mat_array_extract_row(x: wp.array2d(dtype=wp.mat22), y: wp.array2d(dtype=wp.vec2)):
    i, j = wp.tid()
    a = x[i, j][0]
    b = x[i, j][1]
    y[i, j] = 1.0 * a + 2.0 * b
1650
+
1651
+
1652
def test_mat_array_extract(test, device):
    """Indexing into 2D mat22 array elements propagates gradients to the array.

    Same expectations as test_mat_extract, but for (1, 1)-shaped 2D arrays
    launched over a (1, 1) grid.
    """

    def launch_and_backprop(kernel, out_dtype):
        # Run ``kernel`` on a ones-filled (1, 1) mat22 array under a tape, seed
        # the output adjoint with ones, and backpropagate. Returns (input, output).
        src = wp.ones((1, 1), dtype=wp.mat22, requires_grad=True, device=device)
        dst = wp.zeros((1, 1), dtype=out_dtype, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, (1, 1), inputs=[src], outputs=[dst], device=device)

        dst.grad = wp.ones_like(dst)
        tape.backward()
        return src, dst

    # matrix element
    x, y = launch_and_backprop(mat_array_extract_element, float)
    assert_np_equal(y.numpy(), np.array([[10.0]], dtype=float))
    assert_np_equal(x.grad.numpy(), np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float))

    # matrix row
    x, y = launch_and_backprop(mat_array_extract_row, wp.vec2)
    assert_np_equal(y.numpy(), np.array([[[3.0, 3.0]]], dtype=float))
    assert_np_equal(x.grad.numpy(), np.array([[[[1.0, 1.0], [2.0, 2.0]]]], dtype=float))
1680
+
1681
+
1682
+ """ TODO: gradient propagation for in-place array assignment
1683
+ @wp.kernel
1684
+ def mat_array_assign_element(x: wp.array2d(dtype=float), y: wp.array2d(dtype=wp.mat22)):
1685
+ i, j = wp.tid()
1686
+
1687
+ y[i, j][0, 0] = 1.0 * x[i, j]
1688
+ y[i, j][0, 1] = 2.0 * x[i, j]
1689
+ y[i, j][1, 0] = 3.0 * x[i, j]
1690
+ y[i, j][1, 1] = 4.0 * x[i, j]
1691
+
1692
+
1693
+ @wp.kernel
1694
+ def mat_array_assign_row(x: wp.array2d(dtype=wp.vec3), y: wp.array2d(dtype=wp.mat(shape=(2, 3), dtype=float))):
1695
+ i, j = wp.tid()
1696
+
1697
+ y[i, j][0] = 1.0 * x[i, j]
1698
+ y[i, j][1] = 2.0 * x[i, j]
1699
+
1700
+
1701
+ def test_mat_array_assign(test, device):
1702
+ # matrix element
1703
+ x = wp.ones((1, 1), dtype=float, requires_grad=True, device=device)
1704
+ y = wp.zeros((1, 1), dtype=wp.mat22, requires_grad=True, device=device)
1705
+
1706
+ tape = wp.Tape()
1707
+ with tape:
1708
+ wp.launch(mat_array_assign_element, (1, 1), inputs=[x], outputs=[y], device=device)
1709
+
1710
+ y.grad = wp.ones_like(y)
1711
+ tape.backward()
1712
+
1713
+ assert_np_equal(y.numpy(), np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float))
1714
+ assert_np_equal(x.grad.numpy(), np.array([[10.0]], dtype=float))
1715
+
1716
+ # matrix row
1717
+ x = wp.ones((1, 1), dtype=wp.vec3, requires_grad=True, device=device)
1718
+ y = wp.zeros((1, 1), dtype=wp.mat(shape=(2, 3), dtype=float), requires_grad=True, device=device)
1719
+
1720
+ tape = wp.Tape()
1721
+ with tape:
1722
+ wp.launch(mat_array_assign_row, (1, 1), inputs=[x], outputs=[y], device=device)
1723
+
1724
+ y.grad = wp.ones_like(y)
1725
+ tape.backward()
1726
+
1727
+ assert_np_equal(y.numpy(), np.array([[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]], dtype=float))
1728
+ assert_np_equal(x.grad.numpy(), np.array([[[3.0, 3.0, 3.0]]], dtype=float))
1729
+ """
1730
+
1731
+
1732
# Exercises element-wise ``+=`` on a local mat22. ``a`` is freshly constructed;
# test_mat_add_inplace expects y == [[1, 2], [3, 4]] for an all-ones input, so
# ``wp.mat22()`` is expected to start at zero.
@wp.kernel
def mat_add_inplace_element(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    a = wp.mat22()
    b = x[i]

    a[0, 0] += 1.0 * b[0, 0]
    a[0, 1] += 2.0 * b[0, 1]
    a[1, 0] += 3.0 * b[1, 0]
    a[1, 1] += 4.0 * b[1, 1]

    y[i] = a
1745
+
1746
+
1747
# Exercises row-wise ``+=`` on a local mat22: row r accumulates (r + 1) * x[i][r].
@wp.kernel
def mat_add_inplace_row(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    a = wp.mat22()
    b = x[i]

    a[0] += 1.0 * b[0]
    a[1] += 2.0 * b[1]

    y[i] = a
1758
+
1759
+
1760
def test_mat_add_inplace(test, device):
    """Element- and row-wise ``+=`` on a local mat22 are differentiable.

    For an all-ones input and an all-ones output adjoint, both the output and
    the input gradient equal the per-entry weights used by each kernel.
    """

    def launch_and_backprop(kernel):
        # Run ``kernel`` on a ones-filled mat22 under a tape, seed the output
        # adjoint with ones, and backpropagate. Returns (input, output).
        src = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
        dst = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, 1, inputs=[src], outputs=[dst], device=device)

        dst.grad = wp.ones_like(dst)
        tape.backward()
        return src, dst

    x, y = launch_and_backprop(mat_add_inplace_element)
    element_weights = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=float)
    assert_np_equal(y.numpy(), element_weights)
    assert_np_equal(x.grad.numpy(), element_weights)

    x, y = launch_and_backprop(mat_add_inplace_row)
    row_weights = np.array([[[1.0, 1.0], [2.0, 2.0]]], dtype=float)
    assert_np_equal(y.numpy(), row_weights)
    assert_np_equal(x.grad.numpy(), row_weights)
1786
+
1787
+
1788
# Exercises element-wise ``-=`` on a freshly constructed local mat22.
# test_mat_sub_inplace expects y == -[[1, 2], [3, 4]] for an all-ones input.
@wp.kernel
def mat_sub_inplace_element(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    a = wp.mat22()
    b = x[i]

    a[0, 0] -= 1.0 * b[0, 0]
    a[0, 1] -= 2.0 * b[0, 1]
    a[1, 0] -= 3.0 * b[1, 0]
    a[1, 1] -= 4.0 * b[1, 1]

    y[i] = a
1801
+
1802
+
1803
# Exercises row-wise ``-=`` on a local mat22: row r accumulates -(r + 1) * x[i][r].
@wp.kernel
def mat_sub_inplace_row(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    a = wp.mat22()
    b = x[i]

    a[0] -= 1.0 * b[0]
    a[1] -= 2.0 * b[1]

    y[i] = a
1814
+
1815
+
1816
def test_mat_sub_inplace(test, device):
    """Element- and row-wise ``-=`` on a local mat22 are differentiable.

    For an all-ones input and an all-ones output adjoint, both the output and
    the input gradient equal the negated per-entry weights of each kernel.
    """

    def launch_and_backprop(kernel):
        # Run ``kernel`` on a ones-filled mat22 under a tape, seed the output
        # adjoint with ones, and backpropagate. Returns (input, output).
        src = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
        dst = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(kernel, 1, inputs=[src], outputs=[dst], device=device)

        dst.grad = wp.ones_like(dst)
        tape.backward()
        return src, dst

    x, y = launch_and_backprop(mat_sub_inplace_element)
    element_weights = np.array([[[-1.0, -2.0], [-3.0, -4.0]]], dtype=float)
    assert_np_equal(y.numpy(), element_weights)
    assert_np_equal(x.grad.numpy(), element_weights)

    x, y = launch_and_backprop(mat_sub_inplace_row)
    row_weights = np.array([[[-1.0, -1.0], [-2.0, -2.0]]], dtype=float)
    assert_np_equal(y.numpy(), row_weights)
    assert_np_equal(x.grad.numpy(), row_weights)
1842
+
1843
+
1844
# Exercises ``+=`` applied directly to an array element of type mat22.
@wp.kernel
def mat_array_add_inplace(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    y[i] += x[i]
1849
+
1850
+
1851
def test_mat_array_add_inplace(test, device):
    """``y[i] += x[i]`` on mat22 arrays: y becomes x, and x.grad equals y.grad."""
    src = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
    dst = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(mat_array_add_inplace, 1, inputs=[src], outputs=[dst], device=device)

    dst.grad = wp.ones_like(dst)
    tape.backward()

    all_ones = np.ones((1, 2, 2), dtype=float)
    assert_np_equal(dst.numpy(), all_ones)
    assert_np_equal(src.grad.numpy(), all_ones)
1864
+
1865
+
1866
# Exercises ``-=`` applied directly to an array element of type mat22.
@wp.kernel
def mat_array_sub_inplace(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()

    y[i] -= x[i]
1871
+
1872
+
1873
def test_mat_array_sub_inplace(test, device):
    """``y[i] -= x[i]`` on mat22 arrays: y becomes -x, and x.grad equals -y.grad."""
    src = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)
    dst = wp.zeros(1, dtype=wp.mat22, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(mat_array_sub_inplace, 1, inputs=[src], outputs=[dst], device=device)

    dst.grad = wp.ones_like(dst)
    tape.backward()

    all_minus_ones = -np.ones((1, 2, 2), dtype=float)
    assert_np_equal(dst.numpy(), all_minus_ones)
    assert_np_equal(src.grad.numpy(), all_minus_ones)
1886
+
1887
+
1888
# Exercises scalar / matrix division; test_scalar_mat_div expects an
# element-wise reciprocal (y = 1 / x per entry).
@wp.kernel
def scalar_mat_div(x: wp.array(dtype=wp.mat22), y: wp.array(dtype=wp.mat22)):
    i = wp.tid()
    y[i] = 1.0 / x[i]
1892
+
1893
+
1894
def test_scalar_mat_div(test, device):
    """``1.0 / M`` is the element-wise reciprocal, with gradient ``-1 / M**2``."""
    entries = np.array((((1.0, 2.0), (4.0, 8.0)),), dtype=float)

    x = wp.array((wp.mat22(1.0, 2.0, 4.0, 8.0),), dtype=wp.mat22, requires_grad=True, device=device)
    y = wp.ones(1, dtype=wp.mat22, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(scalar_mat_div, 1, inputs=(x,), outputs=(y,), device=device)

    y.grad = wp.ones_like(y)
    tape.backward()

    # All entries are powers of two, so these references are exact.
    assert_np_equal(y.numpy(), 1.0 / entries)
    assert_np_equal(x.grad.numpy(), -1.0 / entries**2)
1907
+
1908
+
1909
def test_mat_from_rows_indexing_assign(test, device):
    """Row/element assignment and compound assignment on a 3x2 matrix.

    Positive and negative indices are both exercised. The same ``@wp.func``
    runs once inside a kernel and once as plain Python, so both execution paths
    must agree on the expected values.
    """

    @wp.func
    def fn():
        # Rows: [1, 2], [3, 4], [5, 6]
        m = wp.matrix_from_rows(
            wp.vec2(1.0, 2.0),
            wp.vec2(3.0, 4.0),
            wp.vec2(5.0, 6.0),
        )

        # Whole-row assignment / compound assignment with positive indices.
        m[0] = wp.vec2(123.0, 234.0)
        m[1] *= 2.0

        wp.expect_eq(m[0], wp.vec2(123.0, 234.0))
        wp.expect_eq(m[1], wp.vec2(6.0, 8.0))
        wp.expect_eq(m[2], wp.vec2(5.0, 6.0))

        # Same with negative row indices (m[-1] is row 2, m[-2] is row 1).
        m[-1] = wp.vec2(123.0, 234.0)
        m[-2] *= 2.0

        wp.expect_eq(m[-1], wp.vec2(123.0, 234.0))
        wp.expect_eq(m[-2], wp.vec2(12.0, 16.0))
        wp.expect_eq(m[-3], wp.vec2(123.0, 234.0))

        # Single-element assignment with m[i, j].
        m[0, 0] = 345.0
        m[1, 0] *= 2.0

        wp.expect_eq(m[0, 0], 345.0)
        wp.expect_eq(m[0, 1], 234.0)
        wp.expect_eq(m[1, 0], 24.0)
        wp.expect_eq(m[1, 1], 16.0)
        wp.expect_eq(m[2, 0], 123.0)
        wp.expect_eq(m[2, 1], 234.0)

        # Single-element assignment with negative indices on both axes.
        m[-1, -1] = 345.0
        m[-2, -1] *= 2.0

        wp.expect_eq(m[-1, -1], 345.0)
        wp.expect_eq(m[-1, -2], 123.0)
        wp.expect_eq(m[-2, -1], 32.0)
        wp.expect_eq(m[-2, -2], 24.0)
        wp.expect_eq(m[-3, -1], 234.0)
        wp.expect_eq(m[-3, -2], 345.0)

        # Read back with chained m[i][j] indexing.
        m[0, 1] = 456.0
        m[1, 1] *= 2.0

        wp.expect_eq(m[0][0], 345.0)
        wp.expect_eq(m[0][1], 456.0)
        wp.expect_eq(m[1][0], 24.0)
        wp.expect_eq(m[1][1], 64.0)
        wp.expect_eq(m[2][0], 123.0)
        wp.expect_eq(m[2][1], 345.0)

        # Chained m[i][j] indexing with negative indices.
        m[-1, -2] = 456.0
        m[-2, -2] *= 2.0

        wp.expect_eq(m[-1][-1], 345.0)
        wp.expect_eq(m[-1][-2], 456.0)
        wp.expect_eq(m[-2][-1], 64.0)
        wp.expect_eq(m[-2][-2], 48.0)
        wp.expect_eq(m[-3][-1], 456.0)
        wp.expect_eq(m[-3][-2], 345.0)

    @wp.kernel(module="unique")
    def kernel():
        fn()

    # Run the checks on the device, then again as plain Python.
    wp.launch(kernel, 1, device=device)
    wp.synchronize()
    fn()
1979
+
1980
+
1981
def test_mat_from_cols_indexing_assign(test, device):
    """Row/element assignment on a 2x3 matrix built from three columns.

    ``matrix_from_cols`` with three vec2 columns yields two rows of length 3,
    so ``m[i]`` is a vec3 row. Positive and negative indices are both
    exercised. The ``@wp.func`` runs once in a kernel and once as plain Python.
    """

    @wp.func
    def fn():
        # Rows: [1, 3, 5], [2, 4, 6]
        m = wp.matrix_from_cols(
            wp.vec2(1.0, 2.0),
            wp.vec2(3.0, 4.0),
            wp.vec2(5.0, 6.0),
        )

        # Whole-row assignment / compound assignment (rows are vec3).
        m[0] = wp.vec3(123.0, 234.0, 345.0)
        m[1] *= 2.0

        wp.expect_eq(m[0], wp.vec3(123.0, 234.0, 345.0))
        wp.expect_eq(m[1], wp.vec3(4.0, 8.0, 12.0))

        # Same with negative row indices (m[-1] is row 1, m[-2] is row 0).
        m[-1] = wp.vec3(123.0, 234.0, 345.0)
        m[-2] *= 2.0

        wp.expect_eq(m[-1], wp.vec3(123.0, 234.0, 345.0))
        wp.expect_eq(m[-2], wp.vec3(246.0, 468.0, 690.0))

        # Single-element assignment with m[i, j].
        m[0, 0] = 456.0
        m[1, 0] *= 2.0

        wp.expect_eq(m[0, 0], 456.0)
        wp.expect_eq(m[0, 1], 468.0)
        wp.expect_eq(m[0, 2], 690.0)
        wp.expect_eq(m[1, 0], 246.0)
        wp.expect_eq(m[1, 1], 234.0)
        wp.expect_eq(m[1, 2], 345.0)

        # Single-element assignment with negative indices on both axes.
        m[-1, -1] = 456.0
        m[-2, -1] *= 2.0

        wp.expect_eq(m[-1, -1], 456.0)
        wp.expect_eq(m[-1, -2], 234.0)
        wp.expect_eq(m[-1, -3], 246.0)
        wp.expect_eq(m[-2, -1], 1380.0)
        wp.expect_eq(m[-2, -2], 468.0)
        wp.expect_eq(m[-2, -3], 456.0)

        # Read back with chained m[i][j] indexing.
        m[0, 1] = 567.0
        m[1, 1] *= 2.0

        wp.expect_eq(m[0][0], 456.0)
        wp.expect_eq(m[0][1], 567.0)
        wp.expect_eq(m[0][2], 1380.0)
        wp.expect_eq(m[1][0], 246.0)
        wp.expect_eq(m[1][1], 468.0)
        wp.expect_eq(m[1][2], 456.0)

        # Chained m[i][j] indexing with negative indices.
        m[-1, -2] = 567.0
        m[-2, -2] *= 2.0

        wp.expect_eq(m[-1][-1], 456.0)
        wp.expect_eq(m[-1][-2], 567.0)
        wp.expect_eq(m[-1][-3], 246.0)
        wp.expect_eq(m[-2][-1], 1380.0)
        wp.expect_eq(m[-2][-2], 1134.0)
        wp.expect_eq(m[-2][-3], 456.0)

    @wp.kernel(module="unique")
    def kernel():
        fn()

    # Run the checks on the device, then again as plain Python.
    wp.launch(kernel, 1, device=device)
    wp.synchronize()
    fn()
2049
+
2050
+
2051
def test_mat_from_rows_slicing_assign(test, device):
    """Exercise read and write slicing of a 4x4 matrix built with ``wp.matrix_from_rows``.

    All checks live in the ``@wp.func`` ``fn`` so the same body is run twice:
    once inside a single-thread kernel launched on ``device`` and once by a
    direct Python call at the end. ``test`` is not referenced in this body.
    """
    # Anonymous matrix/vector types used to build expected values of various
    # shapes; ``mat00`` is the 0x0 type that empty slices are compared against.
    mat00 = wp.mat((0, 0), float)
    vec1 = wp.vec(1, float)
    vec2 = wp.vec(2, float)
    vec3 = wp.vec(3, float)
    vec4 = wp.vec(4, float)

    @wp.func
    def fn():
        m = wp.matrix_from_rows(
            vec4(1.0, 2.0, 3.0, 4.0),
            vec4(5.0, 6.0, 7.0, 8.0),
            vec4(9.0, 10.0, 11.0, 12.0),
            vec4(13.0, 14.0, 15.0, 16.0),
        )

        # --- Read slicing on the first (row) index only. ---
        # Bounds far outside the 4 rows are expected to clamp; the test expects
        # fully out-of-range slices to produce an empty (0x0) matrix.
        wp.expect_eq(
            m[:]
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(5.0, 6.0, 7.0, 8.0),
                vec4(9.0, 10.0, 11.0, 12.0),
                vec4(13.0, 14.0, 15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[-123:123]
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(5.0, 6.0, 7.0, 8.0),
                vec4(9.0, 10.0, 11.0, 12.0),
                vec4(13.0, 14.0, 15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(m[123:] == mat00(), True)
        wp.expect_eq(m[:-123] == mat00(), True)
        wp.expect_eq(
            m[::123]
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
            ),
            True,
        )

        # Start/stop/step combinations, including negative steps.
        wp.expect_eq(
            m[1:]
            == wp.matrix_from_rows(
                vec4(5.0, 6.0, 7.0, 8.0),
                vec4(9.0, 10.0, 11.0, 12.0),
                vec4(13.0, 14.0, 15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[-2:]
            == wp.matrix_from_rows(
                vec4(9.0, 10.0, 11.0, 12.0),
                vec4(13.0, 14.0, 15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:2]
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(5.0, 6.0, 7.0, 8.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:-1]
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(5.0, 6.0, 7.0, 8.0),
                vec4(9.0, 10.0, 11.0, 12.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::2]
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(9.0, 10.0, 11.0, 12.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1::2]
            == wp.matrix_from_rows(
                vec4(5.0, 6.0, 7.0, 8.0),
                vec4(13.0, 14.0, 15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::-1]
            == wp.matrix_from_rows(
                vec4(13.0, 14.0, 15.0, 16.0),
                vec4(9.0, 10.0, 11.0, 12.0),
                vec4(5.0, 6.0, 7.0, 8.0),
                vec4(1.0, 2.0, 3.0, 4.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::-2]
            == wp.matrix_from_rows(
                vec4(13.0, 14.0, 15.0, 16.0),
                vec4(5.0, 6.0, 7.0, 8.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1::-2]
            == wp.matrix_from_rows(
                vec4(5.0, 6.0, 7.0, 8.0),
            ),
            True,
        )

        # --- Read slicing on both (row, column) indices. ---
        wp.expect_eq(
            m[:, :]
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(5.0, 6.0, 7.0, 8.0),
                vec4(9.0, 10.0, 11.0, 12.0),
                vec4(13.0, 14.0, 15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:, 2:]
            == wp.matrix_from_rows(
                vec2(3.0, 4.0),
                vec2(7.0, 8.0),
                vec2(11.0, 12.0),
                vec2(15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1:, 2:]
            == wp.matrix_from_rows(
                vec2(7.0, 8.0),
                vec2(11.0, 12.0),
                vec2(15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[-2:, 2:]
            == wp.matrix_from_rows(
                vec2(11.0, 12.0),
                vec2(15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[2:, -2:]
            == wp.matrix_from_rows(
                vec2(11.0, 12.0),
                vec2(15.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1:, :2]
            == wp.matrix_from_rows(
                vec2(5.0, 6.0),
                vec2(9.0, 10.0),
                vec2(13.0, 14.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:1, 2:]
            == wp.matrix_from_rows(
                vec2(3.0, 4.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::-1, :1]
            == wp.matrix_from_rows(
                vec1(13.0),
                vec1(9.0),
                vec1(5.0),
                vec1(1.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:1, ::-1]
            == wp.matrix_from_rows(
                vec4(4.0, 3.0, 2.0, 1.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:1:-1, 2::-1]
            == wp.matrix_from_rows(
                vec3(15.0, 14.0, 13.0),
                vec3(11.0, 10.0, 9.0),
            ),
            True,
        )

        # A slice on one axis combined with an integer on the other yields a vector.
        wp.expect_eq(m[:2, 0] == vec2(1.0, 5.0), True)
        wp.expect_eq(m[2:, 1] == vec2(10.0, 14.0), True)
        wp.expect_eq(m[0, :3] == vec3(1.0, 2.0, 3.0), True)
        wp.expect_eq(m[1, 1:] == vec3(6.0, 7.0, 8.0), True)

        # --- Slice assignment. ---
        # Each step writes into ``m`` in place, and every expected matrix below
        # includes the effect of all previous steps: the statements are
        # order-dependent and must not be reordered.
        m[1:] = wp.matrix_from_rows(
            vec4(17.0, 18.0, 19.0, 20.0),
            vec4(21.0, 22.0, 23.0, 24.0),
            vec4(25.0, 26.0, 27.0, 28.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(17.0, 18.0, 19.0, 20.0),
                vec4(21.0, 22.0, 23.0, 24.0),
                vec4(25.0, 26.0, 27.0, 28.0),
            ),
            True,
        )

        m[-2:] = wp.matrix_from_rows(
            vec4(29.0, 30.0, 31.0, 32.0),
            vec4(33.0, 34.0, 35.0, 36.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(1.0, 2.0, 3.0, 4.0),
                vec4(17.0, 18.0, 19.0, 20.0),
                vec4(29.0, 30.0, 31.0, 32.0),
                vec4(33.0, 34.0, 35.0, 36.0),
            ),
            True,
        )

        m[:2] = wp.matrix_from_rows(
            vec4(37.0, 38.0, 39.0, 40.0),
            vec4(41.0, 42.0, 43.0, 44.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(37.0, 38.0, 39.0, 40.0),
                vec4(41.0, 42.0, 43.0, 44.0),
                vec4(29.0, 30.0, 31.0, 32.0),
                vec4(33.0, 34.0, 35.0, 36.0),
            ),
            True,
        )

        m[:-1] = wp.matrix_from_rows(
            vec4(45.0, 46.0, 47.0, 48.0),
            vec4(49.0, 50.0, 51.0, 52.0),
            vec4(53.0, 54.0, 55.0, 56.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(45.0, 46.0, 47.0, 48.0),
                vec4(49.0, 50.0, 51.0, 52.0),
                vec4(53.0, 54.0, 55.0, 56.0),
                vec4(33.0, 34.0, 35.0, 36.0),
            ),
            True,
        )

        # Strided and reversed row targets.
        m[::2] = wp.matrix_from_rows(
            vec4(57.0, 58.0, 59.0, 60.0),
            vec4(61.0, 62.0, 63.0, 64.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(57.0, 58.0, 59.0, 60.0),
                vec4(49.0, 50.0, 51.0, 52.0),
                vec4(61.0, 62.0, 63.0, 64.0),
                vec4(33.0, 34.0, 35.0, 36.0),
            ),
            True,
        )

        m[1::2] = wp.matrix_from_rows(
            vec4(65.0, 66.0, 67.0, 68.0),
            vec4(69.0, 70.0, 71.0, 72.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(57.0, 58.0, 59.0, 60.0),
                vec4(65.0, 66.0, 67.0, 68.0),
                vec4(61.0, 62.0, 63.0, 64.0),
                vec4(69.0, 70.0, 71.0, 72.0),
            ),
            True,
        )

        m[::-1] = wp.matrix_from_rows(
            vec4(73.0, 74.0, 75.0, 76.0),
            vec4(77.0, 78.0, 79.0, 80.0),
            vec4(81.0, 82.0, 83.0, 84.0),
            vec4(85.0, 86.0, 87.0, 88.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(85.0, 86.0, 87.0, 88.0),
                vec4(81.0, 82.0, 83.0, 84.0),
                vec4(77.0, 78.0, 79.0, 80.0),
                vec4(73.0, 74.0, 75.0, 76.0),
            ),
            True,
        )

        m[::-2] = wp.matrix_from_rows(
            vec4(89.0, 90.0, 91.0, 92.0),
            vec4(93.0, 94.0, 95.0, 96.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(85.0, 86.0, 87.0, 88.0),
                vec4(93.0, 94.0, 95.0, 96.0),
                vec4(77.0, 78.0, 79.0, 80.0),
                vec4(89.0, 90.0, 91.0, 92.0),
            ),
            True,
        )

        m[1::-2] = wp.matrix_from_rows(
            vec4(97.0, 98.0, 99.0, 100.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(85.0, 86.0, 87.0, 88.0),
                vec4(97.0, 98.0, 99.0, 100.0),
                vec4(77.0, 78.0, 79.0, 80.0),
                vec4(89.0, 90.0, 91.0, 92.0),
            ),
            True,
        )

        # Two-axis (row, column) slice targets.
        m[:, :] = wp.matrix_from_rows(
            vec4(101.0, 102.0, 103.0, 104.0),
            vec4(105.0, 106.0, 107.0, 108.0),
            vec4(109.0, 110.0, 111.0, 112.0),
            vec4(113.0, 114.0, 115.0, 116.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 102.0, 103.0, 104.0),
                vec4(105.0, 106.0, 107.0, 108.0),
                vec4(109.0, 110.0, 111.0, 112.0),
                vec4(113.0, 114.0, 115.0, 116.0),
            ),
            True,
        )

        m[:, 2:] = wp.matrix_from_rows(
            vec2(117.0, 118.0),
            vec2(119.0, 120.0),
            vec2(121.0, 122.0),
            vec2(123.0, 124.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 102.0, 117.0, 118.0),
                vec4(105.0, 106.0, 119.0, 120.0),
                vec4(109.0, 110.0, 121.0, 122.0),
                vec4(113.0, 114.0, 123.0, 124.0),
            ),
            True,
        )

        m[1:, 2:] = wp.matrix_from_rows(
            vec2(125.0, 126.0),
            vec2(127.0, 128.0),
            vec2(129.0, 130.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 102.0, 117.0, 118.0),
                vec4(105.0, 106.0, 125.0, 126.0),
                vec4(109.0, 110.0, 127.0, 128.0),
                vec4(113.0, 114.0, 129.0, 130.0),
            ),
            True,
        )

        m[-2:, 2:] = wp.matrix_from_rows(
            vec2(131.0, 132.0),
            vec2(133.0, 134.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 102.0, 117.0, 118.0),
                vec4(105.0, 106.0, 125.0, 126.0),
                vec4(109.0, 110.0, 131.0, 132.0),
                vec4(113.0, 114.0, 133.0, 134.0),
            ),
            True,
        )

        m[2:, -2:] = wp.matrix_from_rows(
            vec2(135.0, 136.0),
            vec2(137.0, 138.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 102.0, 117.0, 118.0),
                vec4(105.0, 106.0, 125.0, 126.0),
                vec4(109.0, 110.0, 135.0, 136.0),
                vec4(113.0, 114.0, 137.0, 138.0),
            ),
            True,
        )

        m[1:, :2] = wp.matrix_from_rows(
            vec2(139.0, 140.0),
            vec2(141.0, 142.0),
            vec2(143.0, 144.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 102.0, 117.0, 118.0),
                vec4(139.0, 140.0, 125.0, 126.0),
                vec4(141.0, 142.0, 135.0, 136.0),
                vec4(143.0, 144.0, 137.0, 138.0),
            ),
            True,
        )

        m[:1, 2:] = wp.matrix_from_rows(
            vec2(145.0, 146.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 102.0, 145.0, 146.0),
                vec4(139.0, 140.0, 125.0, 126.0),
                vec4(141.0, 142.0, 135.0, 136.0),
                vec4(143.0, 144.0, 137.0, 138.0),
            ),
            True,
        )

        # Vector-shaped targets (slice on one axis, integer on the other),
        # followed by a plain scalar element write.
        m[:2, 0] = vec2(147.0, 148.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(147.0, 102.0, 145.0, 146.0),
                vec4(148.0, 140.0, 125.0, 126.0),
                vec4(141.0, 142.0, 135.0, 136.0),
                vec4(143.0, 144.0, 137.0, 138.0),
            ),
            True,
        )

        m[2:, 1] = vec2(149.0, 150.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(147.0, 102.0, 145.0, 146.0),
                vec4(148.0, 140.0, 125.0, 126.0),
                vec4(141.0, 149.0, 135.0, 136.0),
                vec4(143.0, 150.0, 137.0, 138.0),
            ),
            True,
        )

        m[0, :3] = vec3(151.0, 152.0, 153.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 153.0, 146.0),
                vec4(148.0, 140.0, 125.0, 126.0),
                vec4(141.0, 149.0, 135.0, 136.0),
                vec4(143.0, 150.0, 137.0, 138.0),
            ),
            True,
        )

        m[1, 1:] = vec3(154.0, 155.0, 156.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 153.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(141.0, 149.0, 135.0, 136.0),
                vec4(143.0, 150.0, 137.0, 138.0),
            ),
            True,
        )

        m[0, 2] = 157.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(141.0, 149.0, 135.0, 136.0),
                vec4(143.0, 150.0, 137.0, 138.0),
            ),
            True,
        )

        # --- Augmented assignment through slices (+=, -=, *=, /=). ---
        m[3, 1:] += vec3(158.0, 159.0, 160.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(141.0, 149.0, 135.0, 136.0),
                vec4(143.0, 308.0, 296.0, 298.0),
            ),
            True,
        )

        m[2:, 1] += vec2(161.0, 162.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(141.0, 310.0, 135.0, 136.0),
                vec4(143.0, 470.0, 296.0, 298.0),
            ),
            True,
        )

        m[2:, 3] -= vec2(163.0, 164.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(141.0, 310.0, 135.0, -27.0),
                vec4(143.0, 470.0, 296.0, 134.0),
            ),
            True,
        )

        m[1, :3] -= vec3(165.0, 166.0, 167.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(-17.0, -12.0, -12.0, 156.0),
                vec4(141.0, 310.0, 135.0, -27.0),
                vec4(143.0, 470.0, 296.0, 134.0),
            ),
            True,
        )

        m[:-2, 2:] *= 3.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 471.0, 438.0),
                vec4(-17.0, -12.0, -36.0, 468.0),
                vec4(141.0, 310.0, 135.0, -27.0),
                vec4(143.0, 470.0, 296.0, 134.0),
            ),
            True,
        )

        m[-2:, 1] *= 4.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 471.0, 438.0),
                vec4(-17.0, -12.0, -36.0, 468.0),
                vec4(141.0, 1240.0, 135.0, -27.0),
                vec4(143.0, 1880.0, 296.0, 134.0),
            ),
            True,
        )

        m[3, :1] *= 5.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 471.0, 438.0),
                vec4(-17.0, -12.0, -36.0, 468.0),
                vec4(141.0, 1240.0, 135.0, -27.0),
                vec4(715.0, 1880.0, 296.0, 134.0),
            ),
            True,
        )

        m[:2, :2] /= 2.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(75.5, 76.0, 471.0, 438.0),
                vec4(-8.5, -6.0, -36.0, 468.0),
                vec4(141.0, 1240.0, 135.0, -27.0),
                vec4(715.0, 1880.0, 296.0, 134.0),
            ),
            True,
        )

        m[3:, 3] /= 4.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(75.5, 76.0, 471.0, 438.0),
                vec4(-8.5, -6.0, -36.0, 468.0),
                vec4(141.0, 1240.0, 135.0, -27.0),
                vec4(715.0, 1880.0, 296.0, 33.5),
            ),
            True,
        )

        m[0, :2] /= 4.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(18.875, 19.0, 471.0, 438.0),
                vec4(-8.5, -6.0, -36.0, 468.0),
                vec4(141.0, 1240.0, 135.0, -27.0),
                vec4(715.0, 1880.0, 296.0, 33.5),
            ),
            True,
        )

    # Single-thread kernel that just runs ``fn`` on the target device.
    @wp.kernel(module="unique")
    def kernel():
        fn()

    wp.launch(kernel, 1, device=device)
    wp.synchronize()
    # Run the same checks again through a direct Python call.
    fn()
2700
+
2701
+
2702
def test_mat_from_cols_slicing_assign(test, device):
    """Exercise read and write slicing of a 4x4 matrix built with ``wp.matrix_from_cols``.

    The starting matrix is given as columns, so every expected value below is
    written with ``wp.matrix_from_rows`` using the transposed content.
    Indexing ``m[i]`` still selects row ``i``, as the expected values show.
    All checks live in the ``@wp.func`` ``fn``, which is run once inside a
    single-thread kernel launched on ``device`` and once by a direct Python
    call. ``test`` is not referenced in this body.
    """
    # Anonymous matrix/vector types used to build expected values of various
    # shapes; ``mat00`` is the 0x0 type that empty slices are compared against.
    mat00 = wp.mat((0, 0), float)
    vec1 = wp.vec(1, float)
    vec2 = wp.vec(2, float)
    vec3 = wp.vec(3, float)
    vec4 = wp.vec(4, float)

    @wp.func
    def fn():
        m = wp.matrix_from_cols(
            vec4(1.0, 2.0, 3.0, 4.0),
            vec4(5.0, 6.0, 7.0, 8.0),
            vec4(9.0, 10.0, 11.0, 12.0),
            vec4(13.0, 14.0, 15.0, 16.0),
        )

        # --- Read slicing on the first (row) index only. ---
        # Bounds far outside the 4 rows are expected to clamp; the test expects
        # fully out-of-range slices to produce an empty (0x0) matrix.
        wp.expect_eq(
            m[:]
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(2.0, 6.0, 10.0, 14.0),
                vec4(3.0, 7.0, 11.0, 15.0),
                vec4(4.0, 8.0, 12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[-123:123]
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(2.0, 6.0, 10.0, 14.0),
                vec4(3.0, 7.0, 11.0, 15.0),
                vec4(4.0, 8.0, 12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(m[123:] == mat00(), True)
        wp.expect_eq(m[:-123] == mat00(), True)
        wp.expect_eq(
            m[::123]
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
            ),
            True,
        )

        # Start/stop/step combinations, including negative steps.
        wp.expect_eq(
            m[1:]
            == wp.matrix_from_rows(
                vec4(2.0, 6.0, 10.0, 14.0),
                vec4(3.0, 7.0, 11.0, 15.0),
                vec4(4.0, 8.0, 12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[-2:]
            == wp.matrix_from_rows(
                vec4(3.0, 7.0, 11.0, 15.0),
                vec4(4.0, 8.0, 12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:2]
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(2.0, 6.0, 10.0, 14.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:-1]
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(2.0, 6.0, 10.0, 14.0),
                vec4(3.0, 7.0, 11.0, 15.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::2]
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(3.0, 7.0, 11.0, 15.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1::2]
            == wp.matrix_from_rows(
                vec4(2.0, 6.0, 10.0, 14.0),
                vec4(4.0, 8.0, 12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::-1]
            == wp.matrix_from_rows(
                vec4(4.0, 8.0, 12.0, 16.0),
                vec4(3.0, 7.0, 11.0, 15.0),
                vec4(2.0, 6.0, 10.0, 14.0),
                vec4(1.0, 5.0, 9.0, 13.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::-2]
            == wp.matrix_from_rows(
                vec4(4.0, 8.0, 12.0, 16.0),
                vec4(2.0, 6.0, 10.0, 14.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1::-2]
            == wp.matrix_from_rows(
                vec4(2.0, 6.0, 10.0, 14.0),
            ),
            True,
        )

        # --- Read slicing on both (row, column) indices. ---
        wp.expect_eq(
            m[:, :]
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(2.0, 6.0, 10.0, 14.0),
                vec4(3.0, 7.0, 11.0, 15.0),
                vec4(4.0, 8.0, 12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:, 2:]
            == wp.matrix_from_rows(
                vec2(9.0, 13.0),
                vec2(10.0, 14.0),
                vec2(11.0, 15.0),
                vec2(12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1:, 2:]
            == wp.matrix_from_rows(
                vec2(10.0, 14.0),
                vec2(11.0, 15.0),
                vec2(12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[-2:, 2:]
            == wp.matrix_from_rows(
                vec2(11.0, 15.0),
                vec2(12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[2:, -2:]
            == wp.matrix_from_rows(
                vec2(11.0, 15.0),
                vec2(12.0, 16.0),
            ),
            True,
        )
        wp.expect_eq(
            m[1:, :2]
            == wp.matrix_from_rows(
                vec2(2.0, 6.0),
                vec2(3.0, 7.0),
                vec2(4.0, 8.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:1, 2:]
            == wp.matrix_from_rows(
                vec2(9.0, 13.0),
            ),
            True,
        )
        wp.expect_eq(
            m[::-1, :1]
            == wp.matrix_from_rows(
                vec1(4.0),
                vec1(3.0),
                vec1(2.0),
                vec1(1.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:1, ::-1]
            == wp.matrix_from_rows(
                vec4(13.0, 9.0, 5.0, 1.0),
            ),
            True,
        )
        wp.expect_eq(
            m[:1:-1, 2::-1]
            == wp.matrix_from_rows(
                vec3(12.0, 8.0, 4.0),
                vec3(11.0, 7.0, 3.0),
            ),
            True,
        )

        # A slice on one axis combined with an integer on the other yields a vector.
        wp.expect_eq(m[:2, 0] == vec2(1.0, 2.0), True)
        wp.expect_eq(m[2:, 1] == vec2(7.0, 8.0), True)
        wp.expect_eq(m[0, :3] == vec3(1.0, 5.0, 9.0), True)
        wp.expect_eq(m[1, 1:] == vec3(6.0, 10.0, 14.0), True)

        # --- Slice assignment. ---
        # The right-hand sides are also built from columns, so the target
        # region receives their transpose. Each step writes into ``m`` in place
        # and every expected matrix includes all previous steps, so the
        # statements are order-dependent and must not be reordered.
        m[1:] = wp.matrix_from_cols(
            vec3(17.0, 18.0, 19.0),
            vec3(20.0, 21.0, 22.0),
            vec3(23.0, 24.0, 25.0),
            vec3(26.0, 27.0, 28.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(17.0, 20.0, 23.0, 26.0),
                vec4(18.0, 21.0, 24.0, 27.0),
                vec4(19.0, 22.0, 25.0, 28.0),
            ),
            True,
        )

        m[-2:] = wp.matrix_from_cols(
            vec2(29.0, 30.0),
            vec2(31.0, 32.0),
            vec2(33.0, 34.0),
            vec2(35.0, 36.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(1.0, 5.0, 9.0, 13.0),
                vec4(17.0, 20.0, 23.0, 26.0),
                vec4(29.0, 31.0, 33.0, 35.0),
                vec4(30.0, 32.0, 34.0, 36.0),
            ),
            True,
        )

        m[:2] = wp.matrix_from_cols(
            vec2(37.0, 38.0),
            vec2(39.0, 40.0),
            vec2(41.0, 42.0),
            vec2(43.0, 44.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(37.0, 39.0, 41.0, 43.0),
                vec4(38.0, 40.0, 42.0, 44.0),
                vec4(29.0, 31.0, 33.0, 35.0),
                vec4(30.0, 32.0, 34.0, 36.0),
            ),
            True,
        )

        m[:-1] = wp.matrix_from_cols(
            vec3(45.0, 46.0, 47.0),
            vec3(48.0, 49.0, 50.0),
            vec3(51.0, 52.0, 53.0),
            vec3(54.0, 55.0, 56.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(45.0, 48.0, 51.0, 54.0),
                vec4(46.0, 49.0, 52.0, 55.0),
                vec4(47.0, 50.0, 53.0, 56.0),
                vec4(30.0, 32.0, 34.0, 36.0),
            ),
            True,
        )

        # Strided and reversed row targets.
        m[::2] = wp.matrix_from_cols(
            vec2(57.0, 58.0),
            vec2(59.0, 60.0),
            vec2(61.0, 62.0),
            vec2(63.0, 64.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(57.0, 59.0, 61.0, 63.0),
                vec4(46.0, 49.0, 52.0, 55.0),
                vec4(58.0, 60.0, 62.0, 64.0),
                vec4(30.0, 32.0, 34.0, 36.0),
            ),
            True,
        )

        m[1::2] = wp.matrix_from_cols(
            vec2(65.0, 66.0),
            vec2(67.0, 68.0),
            vec2(69.0, 70.0),
            vec2(71.0, 72.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(57.0, 59.0, 61.0, 63.0),
                vec4(65.0, 67.0, 69.0, 71.0),
                vec4(58.0, 60.0, 62.0, 64.0),
                vec4(66.0, 68.0, 70.0, 72.0),
            ),
            True,
        )

        m[::-1] = wp.matrix_from_cols(
            vec4(73.0, 74.0, 75.0, 76.0),
            vec4(77.0, 78.0, 79.0, 80.0),
            vec4(81.0, 82.0, 83.0, 84.0),
            vec4(85.0, 86.0, 87.0, 88.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(76.0, 80.0, 84.0, 88.0),
                vec4(75.0, 79.0, 83.0, 87.0),
                vec4(74.0, 78.0, 82.0, 86.0),
                vec4(73.0, 77.0, 81.0, 85.0),
            ),
            True,
        )

        m[::-2] = wp.matrix_from_cols(
            vec2(89.0, 90.0),
            vec2(91.0, 92.0),
            vec2(93.0, 94.0),
            vec2(95.0, 96.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(76.0, 80.0, 84.0, 88.0),
                vec4(90.0, 92.0, 94.0, 96.0),
                vec4(74.0, 78.0, 82.0, 86.0),
                vec4(89.0, 91.0, 93.0, 95.0),
            ),
            True,
        )

        m[1::-2] = wp.matrix_from_cols(
            vec1(97.0),
            vec1(98.0),
            vec1(99.0),
            vec1(100.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(76.0, 80.0, 84.0, 88.0),
                vec4(97.0, 98.0, 99.0, 100.0),
                vec4(74.0, 78.0, 82.0, 86.0),
                vec4(89.0, 91.0, 93.0, 95.0),
            ),
            True,
        )

        # Two-axis (row, column) slice targets.
        m[:, :] = wp.matrix_from_cols(
            vec4(101.0, 102.0, 103.0, 104.0),
            vec4(105.0, 106.0, 107.0, 108.0),
            vec4(109.0, 110.0, 111.0, 112.0),
            vec4(113.0, 114.0, 115.0, 116.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 105.0, 109.0, 113.0),
                vec4(102.0, 106.0, 110.0, 114.0),
                vec4(103.0, 107.0, 111.0, 115.0),
                vec4(104.0, 108.0, 112.0, 116.0),
            ),
            True,
        )

        m[:, 2:] = wp.matrix_from_cols(
            vec4(117.0, 118.0, 119.0, 120.0),
            vec4(121.0, 122.0, 123.0, 124.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 105.0, 117.0, 121.0),
                vec4(102.0, 106.0, 118.0, 122.0),
                vec4(103.0, 107.0, 119.0, 123.0),
                vec4(104.0, 108.0, 120.0, 124.0),
            ),
            True,
        )

        m[1:, 2:] = wp.matrix_from_cols(
            vec3(125.0, 126.0, 127.0),
            vec3(128.0, 129.0, 130.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 105.0, 117.0, 121.0),
                vec4(102.0, 106.0, 125.0, 128.0),
                vec4(103.0, 107.0, 126.0, 129.0),
                vec4(104.0, 108.0, 127.0, 130.0),
            ),
            True,
        )

        m[-2:, 2:] = wp.matrix_from_cols(
            vec2(131.0, 132.0),
            vec2(133.0, 134.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 105.0, 117.0, 121.0),
                vec4(102.0, 106.0, 125.0, 128.0),
                vec4(103.0, 107.0, 131.0, 133.0),
                vec4(104.0, 108.0, 132.0, 134.0),
            ),
            True,
        )

        m[2:, -2:] = wp.matrix_from_cols(
            vec2(135.0, 136.0),
            vec2(137.0, 138.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 105.0, 117.0, 121.0),
                vec4(102.0, 106.0, 125.0, 128.0),
                vec4(103.0, 107.0, 135.0, 137.0),
                vec4(104.0, 108.0, 136.0, 138.0),
            ),
            True,
        )

        m[1:, :2] = wp.matrix_from_cols(
            vec3(139.0, 140.0, 141.0),
            vec3(142.0, 143.0, 144.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 105.0, 117.0, 121.0),
                vec4(139.0, 142.0, 125.0, 128.0),
                vec4(140.0, 143.0, 135.0, 137.0),
                vec4(141.0, 144.0, 136.0, 138.0),
            ),
            True,
        )

        m[:1, 2:] = wp.matrix_from_cols(
            vec1(145.0),
            vec1(146.0),
        )
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(101.0, 105.0, 145.0, 146.0),
                vec4(139.0, 142.0, 125.0, 128.0),
                vec4(140.0, 143.0, 135.0, 137.0),
                vec4(141.0, 144.0, 136.0, 138.0),
            ),
            True,
        )

        # Vector-shaped targets (slice on one axis, integer on the other),
        # followed by a plain scalar element write.
        m[:2, 0] = vec2(147.0, 148.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(147.0, 105.0, 145.0, 146.0),
                vec4(148.0, 142.0, 125.0, 128.0),
                vec4(140.0, 143.0, 135.0, 137.0),
                vec4(141.0, 144.0, 136.0, 138.0),
            ),
            True,
        )

        m[2:, 1] = vec2(149.0, 150.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(147.0, 105.0, 145.0, 146.0),
                vec4(148.0, 142.0, 125.0, 128.0),
                vec4(140.0, 149.0, 135.0, 137.0),
                vec4(141.0, 150.0, 136.0, 138.0),
            ),
            True,
        )

        m[0, :3] = vec3(151.0, 152.0, 153.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 153.0, 146.0),
                vec4(148.0, 142.0, 125.0, 128.0),
                vec4(140.0, 149.0, 135.0, 137.0),
                vec4(141.0, 150.0, 136.0, 138.0),
            ),
            True,
        )

        m[1, 1:] = vec3(154.0, 155.0, 156.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 153.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(140.0, 149.0, 135.0, 137.0),
                vec4(141.0, 150.0, 136.0, 138.0),
            ),
            True,
        )

        m[0, 2] = 157.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(140.0, 149.0, 135.0, 137.0),
                vec4(141.0, 150.0, 136.0, 138.0),
            ),
            True,
        )

        # --- Augmented assignment through slices (+=, -=, *=, /=). ---
        m[3, 1:] += vec3(158.0, 159.0, 160.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(140.0, 149.0, 135.0, 137.0),
                vec4(141.0, 308.0, 295.0, 298.0),
            ),
            True,
        )

        m[2:, 1] += vec2(161.0, 162.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(140.0, 310.0, 135.0, 137.0),
                vec4(141.0, 470.0, 295.0, 298.0),
            ),
            True,
        )

        m[2:, 3] -= vec2(163.0, 164.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(148.0, 154.0, 155.0, 156.0),
                vec4(140.0, 310.0, 135.0, -26.0),
                vec4(141.0, 470.0, 295.0, 134.0),
            ),
            True,
        )

        m[1, :3] -= vec3(165.0, 166.0, 167.0)
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 157.0, 146.0),
                vec4(-17.0, -12.0, -12.0, 156.0),
                vec4(140.0, 310.0, 135.0, -26.0),
                vec4(141.0, 470.0, 295.0, 134.0),
            ),
            True,
        )

        m[:-2, 2:] *= 3.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 471.0, 438.0),
                vec4(-17.0, -12.0, -36.0, 468.0),
                vec4(140.0, 310.0, 135.0, -26.0),
                vec4(141.0, 470.0, 295.0, 134.0),
            ),
            True,
        )

        m[-2:, 1] *= 4.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 471.0, 438.0),
                vec4(-17.0, -12.0, -36.0, 468.0),
                vec4(140.0, 1240.0, 135.0, -26.0),
                vec4(141.0, 1880.0, 295.0, 134.0),
            ),
            True,
        )

        m[3, :1] *= 5.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(151.0, 152.0, 471.0, 438.0),
                vec4(-17.0, -12.0, -36.0, 468.0),
                vec4(140.0, 1240.0, 135.0, -26.0),
                vec4(705.0, 1880.0, 295.0, 134.0),
            ),
            True,
        )

        m[:2, :2] /= 2.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(75.5, 76.0, 471.0, 438.0),
                vec4(-8.5, -6.0, -36.0, 468.0),
                vec4(140.0, 1240.0, 135.0, -26.0),
                vec4(705.0, 1880.0, 295.0, 134.0),
            ),
            True,
        )

        m[3:, 3] /= 4.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(75.5, 76.0, 471.0, 438.0),
                vec4(-8.5, -6.0, -36.0, 468.0),
                vec4(140.0, 1240.0, 135.0, -26.0),
                vec4(705.0, 1880.0, 295.0, 33.5),
            ),
            True,
        )

        m[0, :2] /= 4.0
        wp.expect_eq(
            m
            == wp.matrix_from_rows(
                vec4(18.875, 19.0, 471.0, 438.0),
                vec4(-8.5, -6.0, -36.0, 468.0),
                vec4(140.0, 1240.0, 135.0, -26.0),
                vec4(705.0, 1880.0, 295.0, 33.5),
            ),
            True,
        )

    # Single-thread kernel that just runs ``fn`` on the target device.
    @wp.kernel(module="unique")
    def kernel():
        fn()

    wp.launch(kernel, 1, device=device)
    wp.synchronize()
    # Run the same checks again through a direct Python call.
    fn()
3363
+
3364
+
3365
def test_mat_slicing_assign_backward(test, device):
    """Check forward values and gradients of slice (augmented) assignment into a 4x4 matrix.

    A single-thread kernel does the following:
    - copies ``arr_z[i]`` into a local matrix;
    - overwrites, adds into and subtracts from several slices of it, using a
      ``vec2`` from ``arr_x`` and a 2x3 matrix from ``arr_y``;
    - stores the result back.

    The launch is recorded on a ``wp.Tape`` and an all-ones adjoint is seeded
    on ``z``. The final ``z`` and the gradients of ``x`` and ``y`` are then
    compared against hard-coded arrays.
    ``test`` is not referenced in this body.
    """
    mat23 = wp.mat((2, 3), float)

    @wp.kernel(module="unique")
    def kernel(
        arr_x: wp.array(dtype=wp.vec2),
        arr_y: wp.array(dtype=mat23),
        arr_z: wp.array(dtype=wp.mat44),
    ):
        i = wp.tid()

        z = arr_z[i]

        # Plain slice assignment.
        z[0, :2] = arr_x[i]
        z[:2, 1:] = arr_y[i]

        # Accumulation through slices (including sliced right-hand sides).
        z[:2, 3] += arr_x[i][:2]
        z[1:-1, :2] += arr_y[i][::-1, :-1]

        # Subtraction through slices.
        z[2:, 3] -= arr_x[i][0:]
        z[3:, -1:] -= arr_y[i][:1, :1]

        arr_z[i] = z

    # All-ones inputs and a zeroed output, so the final z is fully
    # determined by the kernel's slice writes.
    x_in = wp.ones(1, dtype=wp.vec2, requires_grad=True, device=device)
    y_in = wp.ones(1, dtype=mat23, requires_grad=True, device=device)
    z_out = wp.zeros(1, dtype=wp.mat44, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, 1, inputs=(x_in, y_in), outputs=(z_out,), device=device)

    # Seed d(loss)/dz with ones and propagate back to the inputs.
    z_out.grad = wp.ones_like(z_out)
    tape.backward()

    expected_z = np.array(
        [
            [
                [1.0, 1.0, 1.0, 2.0],
                [1.0, 2.0, 1.0, 2.0],
                [1.0, 1.0, 0.0, -1.0],
                [0.0, 0.0, 0.0, -2.0],
            ],
        ],
        dtype=float,
    )
    # NOTE(review): x[1] is first written to z[0, 1], which the y slice
    # assignment then overwrites. An expected x gradient of (1, 1) therefore
    # assumes that this overwrite does not cancel x[1]'s adjoint contribution;
    # confirm this is the intended adjoint semantics for slice assignment.
    expected_x_grad = np.array([[1.0, 1.0]], dtype=float)
    expected_y_grad = np.array([[[1.0, 2.0, 1.0], [2.0, 2.0, 1.0]]], dtype=float)

    assert_np_equal(z_out.numpy(), expected_z)
    assert_np_equal(x_in.grad.numpy(), expected_x_grad)
    assert_np_equal(y_in.grad.numpy(), expected_y_grad)
3416
+
3417
+
3418
# Devices passed to every kernel/function test registered on ``TestMat`` below.
devices = get_test_devices()
3419
+
3420
+
3421
class TestMat(unittest.TestCase):
    def test_tpl_ops_with_anon(self):
        """Check in-place ``+=`` / ``-=`` between ``wp.mat22f`` and an anonymous 2x2 float matrix type.

        Both operand orders are exercised; each sequence must end at the same values.
        """
        anon_mat22 = wp.mat((2, 2), dtype=float)
        expected = ((0.0, 1.0), (2.0, 3.0))

        # Named type on the left-hand side, anonymous type on the right.
        acc = wp.mat22f(1.0, 2.0, 3.0, 4.0)
        acc += anon_mat22(2.0, 3.0, 4.0, 5.0)
        acc -= anon_mat22(3.0, 4.0, 5.0, 6.0)
        self.assertSequenceEqual(acc, expected)

        # Anonymous type on the left-hand side, named type on the right.
        acc = anon_mat22(1.0, 2.0, 3.0, 4.0)
        acc += wp.mat22f(2.0, 3.0, 4.0, 5.0)
        acc -= wp.mat22f(3.0, 4.0, 5.0, 6.0)
        self.assertSequenceEqual(acc, expected)
3434
+
3435
+
3436
# Non-square 10x3 float matrix type used as the kernel input for
# test_matrix_mutation. Built with the public ``wp.mat`` constructor, as the
# other matrix types in this file are, rather than the private
# ``wp._src.types.matrix``.
mat103 = wp.mat((10, 3), dtype=float)
add_kernel_test(
    TestMat,
    test_matrix_mutation,
    dim=1,
    inputs=[
        mat103(
            1.0, 2.0, 3.0,
            2.0, 4.0, 6.0,
            3.0, 6.0, 9.0,
            4.0, 8.0, 12.0,
            5.0, 10.0, 15.0,
            6.0, 12.0, 18.0,
            7.0, 14.0, 21.0,
            8.0, 16.0, 24.0,
            9.0, 18.0, 27.0,
            10.0, 20.0, 30.0,
        )
    ],
    devices=devices,
)  # fmt: skip
3457
+
3458
# Kernel tests registered once per signed-integer and floating-point dtype,
# named ``test_<suffix>_<dtype name>``.
for dtype in np_signed_int_types + np_float_types:
    for _suffix, _func in (
        ("negation", test_negation),
        ("subtraction", test_subtraction),
        ("matmul", test_matmul),
    ):
        add_function_test_register_kernel(
            TestMat, f"test_{_suffix}_{dtype.__name__}", _func, devices=devices, dtype=dtype
        )

add_function_test(TestMat, "test_shape_mismatch", test_shape_mismatch, devices=devices)
3470
+
3471
+
3472
# Tests registered once per floating-point dtype.
for dtype in np_float_types:
    # Registered with devices=None, unlike the kernel tests below.
    add_function_test(
        TestMat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
    )

    # Every kernel test here is named ``test_<suffix>_<dtype name>``; the 2D SVD
    # test follows the same pattern (``test_svd_2D_<dtype name>``) as its siblings.
    for _suffix, _func in (
        ("inverse", test_inverse),
        ("svd", test_svd),
        ("svd_2D", test_svd_2D),
        ("qr", test_qr),
        ("eig", test_eig),
        ("transform_point", test_transform_point),
        ("transform_vector", test_transform_vector),
        ("determinant", test_determinant),
        ("skew", test_skew),
    ):
        add_function_test_register_kernel(
            TestMat, f"test_{_suffix}_{dtype.__name__}", _func, devices=devices, dtype=dtype
        )
3495
+
3496
# Device-parametrized tests with no dtype axis, registered in this order.
_mat_function_tests = (
    ("test_matrix_len", test_matrix_len),
    ("test_mat_extract", test_mat_extract),
    ("test_mat_assign", test_mat_assign),
    ("test_mat_array_extract", test_mat_array_extract),
    # NOTE(review): ("test_mat_array_assign", test_mat_array_assign) is intentionally
    # not registered; the reason is not visible here — confirm before re-enabling.
    ("test_mat_add_inplace", test_mat_add_inplace),
    ("test_mat_sub_inplace", test_mat_sub_inplace),
    ("test_mat_array_add_inplace", test_mat_array_add_inplace),
    ("test_mat_array_sub_inplace", test_mat_array_sub_inplace),
    ("test_scalar_mat_div", test_scalar_mat_div),
    ("test_mat_from_rows_indexing_assign", test_mat_from_rows_indexing_assign),
    ("test_mat_from_cols_indexing_assign", test_mat_from_cols_indexing_assign),
    ("test_mat_from_rows_slicing_assign", test_mat_from_rows_slicing_assign),
    ("test_mat_from_cols_slicing_assign", test_mat_from_cols_slicing_assign),
    ("test_mat_slicing_assign_backward", test_mat_slicing_assign_backward),
)
for _name, _func in _mat_function_tests:
    add_function_test(TestMat, _name, _func, devices=devices)


if __name__ == "__main__":
    # Clear cached kernels so this standalone run compiles from scratch.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)