warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/tests/test_mat.py ADDED
@@ -0,0 +1,2089 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+ np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]
25
+ np_float_types = [np.float16, np.float32, np.float64]
26
+
27
+
28
+ def randvals(rng, shape, dtype):
29
+ if dtype in np_float_types:
30
+ return rng.standard_normal(size=shape).astype(dtype)
31
+ elif dtype in [np.int8, np.uint8, np.byte, np.ubyte]:
32
+ return rng.integers(1, high=3, size=shape, dtype=dtype)
33
+ return rng.integers(1, high=5, size=shape, dtype=dtype)
34
+
35
+
36
+ kernel_cache = {}
37
+
38
+
39
+ def getkernel(func, suffix=""):
40
+ key = func.__name__ + "_" + suffix
41
+ if key not in kernel_cache:
42
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
43
+ return kernel_cache[key]
44
+
45
+
46
+ def get_select_kernel(dtype):
47
+ def output_select_kernel_fn(input: wp.array(dtype=dtype), index: int, out: wp.array(dtype=dtype)):
48
+ out[0] = input[index]
49
+
50
+ return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
51
+
52
+
53
+ def test_anon_constructor_error_shape_arg_missing(test, device):
54
+ @wp.kernel
55
+ def kernel():
56
+ wp.matrix(1.0, 2.0, 3.0)
57
+
58
+ with test.assertRaisesRegex(
59
+ RuntimeError,
60
+ r"the `shape` argument must be specified when initializing a matrix by value$",
61
+ ):
62
+ wp.launch(kernel, dim=1, inputs=[], device=device)
63
+
64
+
65
+ def test_anon_constructor_error_shape_mismatch(test, device):
66
+ @wp.kernel
67
+ def kernel():
68
+ wp.matrix(wp.matrix(shape=(1, 2), dtype=float), shape=(3, 4), dtype=float)
69
+
70
+ with test.assertRaisesRegex(
71
+ RuntimeError,
72
+ r"incompatible matrix of shape \(3, 4\) given when copy constructing a matrix of shape \(1, 2\)$",
73
+ ):
74
+ wp.launch(kernel, dim=1, inputs=[], device=device)
75
+
76
+
77
+ def test_anon_constructor_error_type_mismatch(test, device):
78
+ @wp.kernel
79
+ def kernel():
80
+ wp.matrix(1.0, shape=(3, 2), dtype=wp.float16)
81
+
82
+ with test.assertRaisesRegex(
83
+ RuntimeError,
84
+ r"the value used to fill this matrix is expected to be of the type `float16`$",
85
+ ):
86
+ wp.launch(kernel, dim=1, inputs=[], device=device)
87
+
88
+
89
+ def test_anon_constructor_error_invalid_arg_count(test, device):
90
+ @wp.kernel
91
+ def kernel():
92
+ wp.matrix(1.0, 2.0, 3.0, shape=(2, 2), dtype=float)
93
+
94
+ with test.assertRaisesRegex(
95
+ RuntimeError,
96
+ r"incompatible number of values given \(3\) when constructing a matrix of shape \(2, 2\)$",
97
+ ):
98
+ wp.launch(kernel, dim=1, inputs=[], device=device)
99
+
100
+
101
+ def test_anon_xform_constructor_error_type_mismatch(test, device):
102
+ @wp.kernel
103
+ def kernel():
104
+ wp.matrix(wp.vec3(1.0, 2.0, 3.0), wp.quat(0.0, 0.0, 0.0, 1.0), wp.vec3(2.0, 2.0, 2.0), wp.float64)
105
+
106
+ with test.assertRaisesRegex(
107
+ RuntimeError,
108
+ r"all values used to initialize this transformation matrix are expected to be of the type `float64`$",
109
+ ):
110
+ wp.launch(
111
+ kernel,
112
+ dim=1,
113
+ inputs=[],
114
+ device=device,
115
+ )
116
+
117
+
118
+ def test_tpl_constructor_error_incompatible_sizes(test, device):
119
+ @wp.kernel
120
+ def kernel():
121
+ wp.mat33(wp.mat22(1.0, 2.0, 3.0, 4.0))
122
+
123
+ with test.assertRaisesRegex(
124
+ RuntimeError,
125
+ r"incompatible matrix of shape \(3, 3\) given when copy constructing a matrix of shape \(2, 2\)$",
126
+ ):
127
+ wp.launch(kernel, dim=1, inputs=[], device=device)
128
+
129
+
130
+ def test_tpl_constructor_error_invalid_arg_count(test, device):
131
+ @wp.kernel
132
+ def kernel():
133
+ wp.mat22(1.0, 2.0, 3.0)
134
+
135
+ with test.assertRaisesRegex(
136
+ RuntimeError,
137
+ r"incompatible number of values given \(3\) when constructing a matrix of shape \(2, 2\)$",
138
+ ):
139
+ wp.launch(kernel, dim=1, inputs=[], device=device)
140
+
141
+
142
+ def test_py_arithmetic_ops(test, device, dtype):
143
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
144
+
145
+ def make_mat(*args):
146
+ if wptype in wp.types.int_types:
147
+ # Cast to the correct integer type to simulate wrapping.
148
+ return tuple(tuple(wptype._type_(x).value for x in row) for row in args)
149
+
150
+ return args
151
+
152
+ def make_vec(*args):
153
+ if wptype in wp.types.int_types:
154
+ # Cast to the correct integer type to simulate wrapping.
155
+ return tuple(wptype._type_(x).value for x in args)
156
+
157
+ return args
158
+
159
+ mat_cls = wp.mat((3, 3), wptype)
160
+ vec_cls = wp.vec(3, wptype)
161
+
162
+ m = mat_cls(((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
163
+ test.assertSequenceEqual(+m, make_mat((-1, 2, 3), (4, -5, 6), (7, 8, -9)))
164
+ test.assertSequenceEqual(-m, make_mat((1, -2, -3), (-4, 5, -6), (-7, -8, 9)))
165
+ test.assertSequenceEqual(m + mat_cls((5, 5, 5) * 3), make_mat((4, 7, 8), (9, 0, 11), (12, 13, -4)))
166
+ test.assertSequenceEqual(m - mat_cls((5, 5, 5) * 3), make_mat((-6, -3, -2), (-1, -10, 1), (2, 3, -14)))
167
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(20, 25, 30))
168
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(20, 25, 30))
169
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(50, 25, 0))
170
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(50, 25, 0))
171
+
172
+ m = mat_cls(((2, 4, 6), (8, 10, 12), (14, 16, 18)))
173
+ test.assertSequenceEqual(m * wptype(2), make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
174
+ test.assertSequenceEqual(wptype(2) * m, make_mat((4, 8, 12), (16, 20, 24), (28, 32, 36)))
175
+ test.assertSequenceEqual(m / wptype(2), make_mat((1, 2, 3), (4, 5, 6), (7, 8, 9)))
176
+ test.assertSequenceEqual(wptype(5040) / m, make_mat((2520, 1260, 840), (630, 504, 420), (360, 315, 280)))
177
+ test.assertSequenceEqual(m * vec_cls(5, 5, 5), make_vec(60, 150, 240))
178
+ test.assertSequenceEqual(m @ vec_cls(5, 5, 5), make_vec(60, 150, 240))
179
+ test.assertSequenceEqual(vec_cls(5, 5, 5) * m, make_vec(120, 150, 180))
180
+ test.assertSequenceEqual(vec_cls(5, 5, 5) @ m, make_vec(120, 150, 180))
181
+
182
+
183
+ def test_quat_constructor(test, device, dtype, register_kernels=False):
184
+ rng = np.random.default_rng(123)
185
+
186
+ tol = {
187
+ np.float16: 1.0e-3,
188
+ np.float32: 1.0e-6,
189
+ np.float64: 1.0e-8,
190
+ }.get(dtype, 0)
191
+
192
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
193
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
194
+ vec4 = wp.types.vector(length=4, dtype=wptype)
195
+ vec3 = wp.types.vector(length=3, dtype=wptype)
196
+ quat = wp.types.quaternion(dtype=wptype)
197
+
198
+ output_select_kernel = get_select_kernel(wptype)
199
+
200
+ def check_mat_quat_constructor(
201
+ p: wp.array(dtype=vec3),
202
+ r: wp.array(dtype=quat),
203
+ s: wp.array(dtype=vec3),
204
+ outcomponents: wp.array(dtype=wptype),
205
+ outcomponents_alt: wp.array(dtype=wptype),
206
+ ):
207
+ m = mat44(p[0], r[0], s[0])
208
+
209
+ R = wp.transpose(wp.quat_to_matrix(r[0]))
210
+ c0 = s[0][0] * R[0]
211
+ c1 = s[0][1] * R[1]
212
+ c2 = s[0][2] * R[2]
213
+ m_alt = wp.matrix_from_cols(
214
+ vec4(c0[0], c0[1], c0[2], wptype(0.0)),
215
+ vec4(c1[0], c1[1], c1[2], wptype(0.0)),
216
+ vec4(c2[0], c2[1], c2[2], wptype(0.0)),
217
+ vec4(p[0][0], p[0][1], p[0][2], wptype(1.0)),
218
+ )
219
+
220
+ idx = 0
221
+ for i in range(4):
222
+ for j in range(4):
223
+ outcomponents[idx] = m[i, j]
224
+ outcomponents_alt[idx] = m_alt[i, j]
225
+ idx = idx + 1
226
+
227
+ kernel = getkernel(check_mat_quat_constructor, suffix=dtype.__name__)
228
+
229
+ if register_kernels:
230
+ return
231
+
232
+ # translation:
233
+ p = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
234
+
235
+ # generate a normalized quaternion for the rotation:
236
+ r = rng.standard_normal(size=(1, 4))
237
+ r /= np.linalg.norm(r)
238
+ r = wp.array(r.astype(dtype), dtype=quat, requires_grad=True, device=device)
239
+
240
+ # scale:
241
+ s = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)
242
+
243
+ # just going to generate the matrix using the constructor, then
244
+ # more manually, and make sure the values/gradients are the same:
245
+ outcomponents = wp.zeros(4 * 4, dtype=wptype, requires_grad=True, device=device)
246
+ outcomponents_alt = wp.zeros(4 * 4, dtype=wptype, requires_grad=True, device=device)
247
+ wp.launch(kernel, dim=1, inputs=[p, r, s], outputs=[outcomponents, outcomponents_alt], device=device)
248
+ assert_np_equal(outcomponents.numpy(), outcomponents_alt.numpy(), tol=1.0e-6)
249
+
250
+ idx = 0
251
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
252
+ out_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
253
+ for _i in range(4):
254
+ for _j in range(4):
255
+ tape = wp.Tape()
256
+ with tape:
257
+ wp.launch(kernel, dim=1, inputs=[p, r, s], outputs=[outcomponents, outcomponents_alt], device=device)
258
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
259
+ wp.launch(
260
+ output_select_kernel, dim=1, inputs=[outcomponents_alt, idx], outputs=[out_alt], device=device
261
+ )
262
+
263
+ tape.backward(loss=out)
264
+ p_grad = 1.0 * tape.gradients[p].numpy()[0]
265
+ r_grad = 1.0 * tape.gradients[r].numpy()[0]
266
+ s_grad = 1.0 * tape.gradients[s].numpy()[0]
267
+ tape.zero()
268
+
269
+ tape.backward(loss=out_alt)
270
+ p_grad_alt = 1.0 * tape.gradients[p].numpy()[0]
271
+ r_grad_alt = 1.0 * tape.gradients[r].numpy()[0]
272
+ s_grad_alt = 1.0 * tape.gradients[s].numpy()[0]
273
+ tape.zero()
274
+
275
+ assert_np_equal(p_grad, p_grad_alt, tol=tol)
276
+ assert_np_equal(r_grad, r_grad_alt, tol=tol)
277
+ assert_np_equal(s_grad, s_grad_alt, tol=tol)
278
+
279
+ idx = idx + 1
280
+
281
+
282
+ def test_negation(test, device, dtype, register_kernels=False):
283
+ rng = np.random.default_rng(123)
284
+
285
+ tol = {
286
+ np.float16: 1.0e-2,
287
+ np.float32: 1.0e-6,
288
+ np.float64: 1.0e-8,
289
+ }.get(dtype, 0)
290
+
291
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
292
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
293
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
294
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
295
+ mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
296
+
297
+ output_select_kernel = get_select_kernel(wptype)
298
+
299
+ def check_mat_negation(
300
+ m2: wp.array(dtype=mat22),
301
+ m3: wp.array(dtype=mat33),
302
+ m4: wp.array(dtype=mat44),
303
+ m5: wp.array(dtype=mat55),
304
+ outcomponents: wp.array(dtype=wptype),
305
+ ):
306
+ mat2 = -m2[0]
307
+ mat3 = -m3[0]
308
+ mat4 = -m4[0]
309
+ mat5 = -m5[0]
310
+
311
+ # multiply outputs by 2 so we've got something to backpropagate:
312
+ idx = 0
313
+ for i in range(2):
314
+ for j in range(2):
315
+ outcomponents[idx] = wptype(2) * mat2[i, j]
316
+ idx = idx + 1
317
+
318
+ for i in range(3):
319
+ for j in range(3):
320
+ outcomponents[idx] = wptype(2) * mat3[i, j]
321
+ idx = idx + 1
322
+
323
+ for i in range(4):
324
+ for j in range(4):
325
+ outcomponents[idx] = wptype(2) * mat4[i, j]
326
+ idx = idx + 1
327
+
328
+ for i in range(5):
329
+ for j in range(5):
330
+ outcomponents[idx] = wptype(2) * mat5[i, j]
331
+ idx = idx + 1
332
+
333
+ kernel = getkernel(check_mat_negation, suffix=dtype.__name__)
334
+
335
+ if register_kernels:
336
+ return
337
+
338
+ m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
339
+ m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
340
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
341
+ m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
342
+ outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
343
+
344
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
345
+
346
+ assert_np_equal(outcomponents.numpy()[:4], -2 * m2.numpy().reshape(-1), tol=tol)
347
+ assert_np_equal(outcomponents.numpy()[4:13], -2 * m3.numpy().reshape(-1), tol=tol)
348
+ assert_np_equal(outcomponents.numpy()[13:29], -2 * m4.numpy().reshape(-1), tol=tol)
349
+ assert_np_equal(outcomponents.numpy()[29:54], -2 * m5.numpy().reshape(-1), tol=tol)
350
+
351
+ if dtype in np_float_types:
352
+ idx = 0
353
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
354
+ for dim, input in [(2, m2), (3, m3), (4, m4), (5, m5)]:
355
+ for i in range(dim):
356
+ for j in range(dim):
357
+ tape = wp.Tape()
358
+ with tape:
359
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
360
+ wp.launch(
361
+ output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
362
+ )
363
+ tape.backward(loss=out)
364
+ expectedresult = np.zeros((dim, dim), dtype=dtype)
365
+ expectedresult[i, j] = -2
366
+ assert_np_equal(tape.gradients[input].numpy()[0], expectedresult)
367
+ tape.zero()
368
+ idx = idx + 1
369
+
370
+
371
+ def test_matmul(test, device, dtype, register_kernels=False):
372
+ rng = np.random.default_rng(123)
373
+
374
+ tol = {
375
+ np.float16: 5.0e-3,
376
+ np.float32: 1.0e-6,
377
+ np.float64: 1.0e-12,
378
+ }.get(dtype, 0)
379
+
380
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
381
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
382
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
383
+ mat23 = wp.types.matrix(shape=(2, 3), dtype=wptype)
384
+ mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
385
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
386
+
387
+ output_select_kernel = get_select_kernel(wptype)
388
+
389
+ def check_mat_mul(
390
+ i23: wp.array(dtype=mat23),
391
+ i32: wp.array(dtype=mat32),
392
+ i44: wp.array(dtype=mat44),
393
+ o22: wp.array(dtype=mat22),
394
+ o33: wp.array(dtype=mat33),
395
+ o44: wp.array(dtype=mat44),
396
+ ):
397
+ i = wp.tid()
398
+ o22[i] = i23[i] @ i32[i]
399
+ o33[i] = i32[i] @ i23[i]
400
+ o44[i] = i44[i] @ i44[i]
401
+
402
+ kernel = getkernel(check_mat_mul, suffix=dtype.__name__)
403
+
404
+ if register_kernels:
405
+ return
406
+
407
+ test_adj = dtype in np_float_types
408
+
409
+ i23 = wp.array(randvals(rng, [1, 2, 3], dtype), dtype=mat23, requires_grad=test_adj, device=device)
410
+ i32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=test_adj, device=device)
411
+ i44 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=test_adj, device=device)
412
+ o22 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=test_adj, device=device)
413
+ o33 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=test_adj, device=device)
414
+ o44 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=test_adj, device=device)
415
+
416
+ tape = wp.Tape()
417
+ with tape:
418
+ wp.launch(
419
+ kernel,
420
+ dim=1,
421
+ inputs=[i23, i32, i44],
422
+ outputs=[o22, o33, o44],
423
+ device=device,
424
+ )
425
+
426
+ assert_np_equal(o22.numpy(), i23.numpy() @ i32.numpy(), tol=tol)
427
+ assert_np_equal(o33.numpy(), i32.numpy() @ i23.numpy(), tol=tol)
428
+ assert_np_equal(o44.numpy(), i44.numpy() @ i44.numpy(), tol=tol)
429
+
430
+ if test_adj:
431
+ o22.grad.assign([np.eye(2)])
432
+ o33.grad.assign([np.eye(3)])
433
+ o44.grad.assign([np.eye(4)])
434
+
435
+ tape.backward()
436
+
437
+ assert_np_equal(i23.grad.numpy(), 2.0 * i32.numpy().T, tol=tol)
438
+ assert_np_equal(i32.grad.numpy(), 2.0 * i23.numpy().T, tol=tol)
439
+ assert_np_equal(i44.grad.numpy(), 2.0 * i44.numpy().T, tol=tol)
440
+
441
+
442
+ def test_subtraction(test, device, dtype, register_kernels=False):
443
+ rng = np.random.default_rng(123)
444
+
445
+ tol = {
446
+ np.float16: 5.0e-3,
447
+ np.float32: 1.0e-6,
448
+ np.float64: 1.0e-8,
449
+ }.get(dtype, 0)
450
+
451
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
452
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
453
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
454
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
455
+ mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
456
+
457
+ output_select_kernel = get_select_kernel(wptype)
458
+
459
+ def check_mat_sub(
460
+ s2: wp.array(dtype=mat22),
461
+ s3: wp.array(dtype=mat33),
462
+ s4: wp.array(dtype=mat44),
463
+ s5: wp.array(dtype=mat55),
464
+ v2: wp.array(dtype=mat22),
465
+ v3: wp.array(dtype=mat33),
466
+ v4: wp.array(dtype=mat44),
467
+ v5: wp.array(dtype=mat55),
468
+ outcomponents: wp.array(dtype=wptype),
469
+ ):
470
+ v2result = v2[0] - s2[0]
471
+ v3result = v3[0] - s3[0]
472
+ v4result = v4[0] - s4[0]
473
+ v5result = v5[0] - s5[0]
474
+
475
+ # multiply outputs by 2 so we've got something to backpropagate:
476
+ idx = 0
477
+ for i in range(2):
478
+ for j in range(2):
479
+ outcomponents[idx] = wptype(2) * v2result[i, j]
480
+ idx = idx + 1
481
+
482
+ for i in range(3):
483
+ for j in range(3):
484
+ outcomponents[idx] = wptype(2) * v3result[i, j]
485
+ idx = idx + 1
486
+
487
+ for i in range(4):
488
+ for j in range(4):
489
+ outcomponents[idx] = wptype(2) * v4result[i, j]
490
+ idx = idx + 1
491
+
492
+ for i in range(5):
493
+ for j in range(5):
494
+ outcomponents[idx] = wptype(2) * v5result[i, j]
495
+ idx = idx + 1
496
+
497
+ kernel = getkernel(check_mat_sub, suffix=dtype.__name__)
498
+
499
+ if register_kernels:
500
+ return
501
+
502
+ s2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
503
+ s3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
504
+ s4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
505
+ s5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
506
+ v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
507
+ v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
508
+ v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
509
+ v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
510
+ outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
511
+
512
+ wp.launch(
513
+ kernel,
514
+ dim=1,
515
+ inputs=[
516
+ s2,
517
+ s3,
518
+ s4,
519
+ s5,
520
+ v2,
521
+ v3,
522
+ v4,
523
+ v5,
524
+ ],
525
+ outputs=[outcomponents],
526
+ device=device,
527
+ )
528
+
529
+ assert_np_equal(outcomponents.numpy()[:4], 2 * (v2.numpy() - s2.numpy()).reshape(-1), tol=tol)
530
+ assert_np_equal(outcomponents.numpy()[4:13], 2 * (v3.numpy() - s3.numpy()).reshape(-1), tol=tol)
531
+ assert_np_equal(outcomponents.numpy()[13:29], 2 * (v4.numpy() - s4.numpy()).reshape(-1), tol=tol)
532
+ assert_np_equal(outcomponents.numpy()[29:54], 2 * (v5.numpy() - s5.numpy()).reshape(-1), tol=10 * tol)
533
+
534
+ if dtype in np_float_types:
535
+ idx = 0
536
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
537
+ for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
538
+ for i in range(dim):
539
+ for j in range(dim):
540
+ tape = wp.Tape()
541
+ with tape:
542
+ wp.launch(
543
+ kernel,
544
+ dim=1,
545
+ inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
546
+ outputs=[outcomponents],
547
+ device=device,
548
+ )
549
+ wp.launch(
550
+ output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
551
+ )
552
+ tape.backward(loss=out)
553
+ expected_result = np.zeros((dim, dim), dtype=dtype)
554
+ expected_result[i, j] = 2
555
+ assert_np_equal(tape.gradients[in2].numpy()[0], expected_result, tol=10 * tol)
556
+ expected_result[i, j] = -2
557
+ assert_np_equal(tape.gradients[in1].numpy()[0], expected_result, tol=10 * tol)
558
+ tape.zero()
559
+
560
+ idx = idx + 1
561
+
562
+
563
+ def test_determinant(test, device, dtype, register_kernels=False):
564
+ rng = np.random.default_rng(123)
565
+
566
+ tol = {
567
+ np.float16: 5.0e-3,
568
+ np.float32: 1.0e-6,
569
+ np.float64: 1.0e-8,
570
+ }.get(dtype, 0)
571
+
572
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
573
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
574
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
575
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
576
+
577
+ def check_mat_det(
578
+ v2: wp.array(dtype=mat22),
579
+ v3: wp.array(dtype=mat33),
580
+ v4: wp.array(dtype=mat44),
581
+ det2: wp.array(dtype=wptype),
582
+ det3: wp.array(dtype=wptype),
583
+ det4: wp.array(dtype=wptype),
584
+ ):
585
+ # multiply outputs by 2 so we've got something to backpropagate:
586
+ det2[0] = wptype(2) * wp.determinant(v2[0])
587
+ det3[0] = wptype(2) * wp.determinant(v3[0])
588
+ det4[0] = wptype(2) * wp.determinant(v4[0])
589
+
590
+ kernel = getkernel(check_mat_det, suffix=dtype.__name__)
591
+ if register_kernels:
592
+ return
593
+
594
+ v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
595
+ v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
596
+ v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
597
+ det2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
598
+ det3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
599
+ det4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
600
+
601
+ tape = wp.Tape()
602
+ with tape:
603
+ wp.launch(kernel, dim=1, inputs=[v2, v3, v4], outputs=[det2, det3, det4], device=device)
604
+
605
+ if dtype in np_float_types:
606
+ assert_np_equal(det2.numpy()[0], 2 * np.linalg.det(v2.numpy()[0].astype(np.float64)), tol=100 * tol)
607
+ assert_np_equal(det3.numpy()[0], 2 * np.linalg.det(v3.numpy()[0].astype(np.float64)), tol=100 * tol)
608
+ assert_np_equal(det4.numpy()[0], 2 * np.linalg.det(v4.numpy()[0].astype(np.float64)), tol=420 * tol)
609
+ else:
610
+ assert_np_equal(det2.numpy()[0], 2 * np.around(np.linalg.det(v2.numpy()[0])).astype(int))
611
+ assert_np_equal(det3.numpy()[0], 2 * np.around(np.linalg.det(v3.numpy()[0])).astype(int))
612
+ assert_np_equal(det4.numpy()[0], 2 * np.around(np.linalg.det(v4.numpy()[0])).astype(int))
613
+
614
+ if dtype in np_float_types:
615
+ # determinant derivative formula is annoying so finite differences?
616
+ tape.backward(loss=det2)
617
+ v2grads = 1.0 * tape.gradients[v2].numpy()[0]
618
+ tape.zero()
619
+
620
+ tape.backward(loss=det3)
621
+ v3grads = 1.0 * tape.gradients[v3].numpy()[0]
622
+ tape.zero()
623
+
624
+ tape.backward(loss=det4)
625
+ v4grads = 1.0 * tape.gradients[v4].numpy()[0]
626
+ tape.zero()
627
+
628
+ # finite differences are also annoying hence the large tolerance...
629
+ # absolute nightmare in float16 too innit...
630
+ dx = 0.01 if dtype == np.float16 else 0.0001
631
+ fdtol = 2.0e-1 if dtype == np.float16 else 2.0e-3
632
+ for i in range(2):
633
+ for j in range(2):
634
+ v2test = v2.numpy()
635
+ v2test[0, i, j] += dx
636
+ wp.launch(
637
+ kernel,
638
+ dim=1,
639
+ inputs=[wp.array(v2test, dtype=v2.dtype, requires_grad=True, device=device), v3, v4],
640
+ outputs=[det2, det3, det4],
641
+ device=device,
642
+ )
643
+ dplus = det2.numpy()[0]
644
+ v2test[0, i, j] -= 2.0 * dx
645
+ wp.launch(
646
+ kernel,
647
+ dim=1,
648
+ inputs=[wp.array(v2test, dtype=v2.dtype, requires_grad=True, device=device), v3, v4],
649
+ outputs=[det2, det3, det4],
650
+ device=device,
651
+ )
652
+ dminus = det2.numpy()[0]
653
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v2grads[i, j] / dplus, tol=fdtol)
654
+
655
+ for i in range(3):
656
+ for j in range(3):
657
+ v3test = v3.numpy()
658
+ v3test[0, i, j] += dx
659
+ wp.launch(
660
+ kernel,
661
+ dim=1,
662
+ inputs=[v2, wp.array(v3test, dtype=v3.dtype, requires_grad=True, device=device), v4],
663
+ outputs=[det2, det3, det4],
664
+ device=device,
665
+ )
666
+ dplus = det3.numpy()[0]
667
+ v3test[0, i, j] -= 2.0 * dx
668
+ wp.launch(
669
+ kernel,
670
+ dim=1,
671
+ inputs=[v2, wp.array(v3test, dtype=v3.dtype, requires_grad=True, device=device), v4],
672
+ outputs=[det2, det3, det4],
673
+ device=device,
674
+ )
675
+ dminus = det3.numpy()[0]
676
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v3grads[i, j] / dplus, tol=fdtol)
677
+
678
+ for i in range(4):
679
+ for j in range(4):
680
+ v4test = v4.numpy()
681
+ v4test[0, i, j] += dx
682
+ wp.launch(
683
+ kernel,
684
+ dim=1,
685
+ inputs=[v2, v3, wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device)],
686
+ outputs=[det2, det3, det4],
687
+ device=device,
688
+ )
689
+ dplus = det4.numpy()[0]
690
+ v4test[0, i, j] -= 2.0 * dx
691
+ wp.launch(
692
+ kernel,
693
+ dim=1,
694
+ inputs=[v2, v3, wp.array(v4test, dtype=v4.dtype, requires_grad=True, device=device)],
695
+ outputs=[det2, det3, det4],
696
+ device=device,
697
+ )
698
+ dminus = det4.numpy()[0]
699
+ assert_np_equal((dplus - dminus) / (2.0 * dx * dplus), v4grads[i, j] / dplus, tol=fdtol)
700
+
701
+
702
+ # Unused. Why?
703
+ # def test_get_diag(test, device, dtype, register_kernels=False):
704
+ # tol = {
705
+ # np.float16: 1.0e-3,
706
+ # np.float32: 1.0e-6,
707
+ # np.float64: 1.0e-8,
708
+ # }.get(dtype, 0)
709
+ #
710
+ # wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
711
+ # mat55 = wp.types.vector(shape=(5, 5), dtype=wptype)
712
+ #
713
+ # output_select_kernel = get_select_kernel(wptype)
714
+ #
715
+ # def check_mat_diag(
716
+ # m55: wp.array(dtype=mat55),
717
+ # outcomponents: wp.array(dtype=wptype),
718
+ # ):
719
+ # # multiply outputs by 2 so we've got something to backpropagate:
720
+ # vec5result = wptype(2) * wp.get_diag(m55[0])
721
+ #
722
+ # idx = 0
723
+ # for i in range(5):
724
+ # outcomponents[idx] = vec5result[i]
725
+ # idx = idx + 1
726
+ #
727
+ # kernel = getkernel(check_mat_diag, suffix=dtype.__name__)
728
+ #
729
+ # if register_kernels:
730
+ # return
731
+ #
732
+ # m55 = wp.array(randvals((1, 5, 5), dtype), dtype=mat55, requires_grad=True, device=device)
733
+ # outcomponents = wp.zeros(5, dtype=wptype, requires_grad=True, device=device)
734
+ # out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
735
+ #
736
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
737
+ #
738
+ # assert_np_equal(outcomponents.numpy(), 2 * np.diag(m55.numpy()[0]), tol=tol)
739
+ #
740
+ # if dtype in np_float_types:
741
+ # idx = 0
742
+ # for i in range(5):
743
+ # tape = wp.Tape()
744
+ # with tape:
745
+ # wp.launch(kernel, dim=1, inputs=[m55], outputs=[outcomponents], device=device)
746
+ # wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
747
+ # tape.backward(loss=out)
748
+ # expectedresult = np.zeros((5, 5), dtype=dtype)
749
+ # expectedresult[i, i] = 2
750
+ # assert_np_equal(tape.gradients[m55].numpy()[0], expectedresult, tol=10 * tol)
751
+ # tape.zero()
752
+ #
753
+ # idx = idx + 1
754
+
755
+
756
+ def test_inverse(test, device, dtype, register_kernels=False):
757
+ rng = np.random.default_rng(123)
758
+
759
+ tol = {
760
+ np.float16: 5.0e-2,
761
+ np.float32: 1.0e-5,
762
+ np.float64: 1.0e-8,
763
+ }.get(dtype, 0)
764
+
765
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
766
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
767
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
768
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
769
+
770
+ output_select_kernel = get_select_kernel(wptype)
771
+
772
+ def check_mat_inverse(
773
+ m2: wp.array(dtype=mat22),
774
+ m3: wp.array(dtype=mat33),
775
+ m4: wp.array(dtype=mat44),
776
+ outcomponents: wp.array(dtype=wptype),
777
+ ):
778
+ m2result = wp.inverse(m2[0])
779
+ m3result = wp.inverse(m3[0])
780
+ m4result = wp.inverse(m4[0])
781
+
782
+ # multiply outputs by 2 so we've got something to backpropagate:
783
+ idx = 0
784
+ for i in range(2):
785
+ for j in range(2):
786
+ outcomponents[idx] = wptype(2) * m2result[i, j]
787
+ idx = idx + 1
788
+
789
+ for i in range(3):
790
+ for j in range(3):
791
+ outcomponents[idx] = wptype(2) * m3result[i, j]
792
+ idx = idx + 1
793
+
794
+ for i in range(4):
795
+ for j in range(4):
796
+ outcomponents[idx] = wptype(2) * m4result[i, j]
797
+ idx = idx + 1
798
+
799
+ kernel = getkernel(check_mat_inverse, suffix=dtype.__name__)
800
+
801
+ if register_kernels:
802
+ return
803
+
804
+ m2 = wp.array(
805
+ 2 * (randvals(rng, [1, 2, 2], dtype) + 0.2 * np.eye(2)), dtype=mat22, requires_grad=True, device=device
806
+ )
807
+ m3 = wp.array(
808
+ 2 * (randvals(rng, [1, 3, 3], dtype) + 0.2 * np.eye(3)), dtype=mat33, requires_grad=True, device=device
809
+ )
810
+ m4 = wp.array(
811
+ 2 * (randvals(rng, [1, 4, 4], dtype) + 0.2 * np.eye(4)), dtype=mat44, requires_grad=True, device=device
812
+ )
813
+
814
+ outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4, dtype=wptype, requires_grad=True, device=device)
815
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
816
+
817
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)
818
+
819
+ assert_np_equal(outcomponents.numpy()[:4], 2 * np.linalg.inv(m2.numpy()[0].astype(np.float64)), tol=tol)
820
+ assert_np_equal(outcomponents.numpy()[4:13], 2 * np.linalg.inv(m3.numpy()[0].astype(np.float64)), tol=5 * tol)
821
+ assert_np_equal(outcomponents.numpy()[13:], 2 * np.linalg.inv(m4.numpy()[0].astype(np.float64)), tol=5 * tol)
822
+
823
+ if dtype in np_float_types:
824
+ # check gradients:
825
+ idx = 0
826
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
827
+ for dim, input in [(2, m2), (3, m3), (4, m4)]:
828
+ minv = np.linalg.inv(input.numpy()[0].astype(np.float64))
829
+ for i in range(dim):
830
+ for j in range(dim):
831
+ tape = wp.Tape()
832
+ with tape:
833
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)
834
+ wp.launch(
835
+ output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
836
+ )
837
+ tape.backward(loss=out)
838
+ d = np.zeros((dim, dim))
839
+ d[j, i] = 2
840
+ assert_np_equal(
841
+ tape.gradients[input].numpy()[0], -np.matmul(minv, np.matmul(d, minv)).T, tol=10 * tol
842
+ )
843
+ tape.zero()
844
+
845
+ idx = idx + 1
846
+
847
+ # let's check 2x2 using different formulae just for (in)sanity's sake:
848
+ m = m2.numpy()[0]
849
+
850
+ det = m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]
851
+ expected = 2 * np.array([[m[1, 1], -m[0, 1]], [-m[1, 0], m[0, 0]]], dtype=dtype) / det
852
+ assert_np_equal(expected, outcomponents.numpy()[:4], tol=tol)
853
+
854
+ # 0,0 component is this:
855
+ # 2 * m[1,1] / (m[0,0]*m[1,1] - m[1,0] * m[0,1])
856
+ assert_np_equal(2 * m[1, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]), outcomponents.numpy()[0], tol=tol)
857
+
858
+ tape = wp.Tape()
859
+ with tape:
860
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)
861
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, 0], outputs=[out], device=device)
862
+
863
+ if dtype in np_float_types:
864
+ tape.backward(loss=out)
865
+ g = tape.gradients[m2].numpy()[0]
866
+ assert_np_equal(-2 * m[1, 1] * m[1, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 0], tol=tol)
867
+ assert_np_equal(2 * m[1, 1] * m[0, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 0], tol=tol)
868
+ assert_np_equal(-2 * m[0, 1] * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 1], tol=tol)
869
+ assert_np_equal(2 * m[1, 1] * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 1], tol=tol)
870
+ tape.zero()
871
+
872
+ # 0,1 component is this:
873
+ # -2 * m[0,1] / (m[0,0]*m[1,1] - m[1,0] * m[0,1])
874
+ assert_np_equal(-2 * m[0, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]), outcomponents.numpy()[1], tol=tol)
875
+
876
+ tape = wp.Tape()
877
+ with tape:
878
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)
879
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, 1], outputs=[out], device=device)
880
+ if dtype in np_float_types:
881
+ tape.backward(loss=out)
882
+ g = tape.gradients[m2].numpy()[0]
883
+ assert_np_equal(2 * m[0, 1] * m[1, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 0], tol=tol)
884
+ assert_np_equal(-2 * m[0, 1] * m[0, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 0], tol=tol)
885
+ assert_np_equal(2 * m[0, 0] * m[0, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 1], tol=tol)
886
+ assert_np_equal(-2 * m[1, 1] * m[0, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 1], tol=tol)
887
+ tape.zero()
888
+
889
+ # 1,0 component is this:
890
+ # -2 * m[1,0] / (m[0,0]*m[1,1] - m[1,0] * m[0,1])
891
+ assert_np_equal(-2 * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]), outcomponents.numpy()[2], tol=tol)
892
+
893
+ tape = wp.Tape()
894
+ with tape:
895
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)
896
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, 2], outputs=[out], device=device)
897
+
898
+ if dtype in np_float_types:
899
+ tape.backward(loss=out)
900
+ g = tape.gradients[m2].numpy()[0]
901
+ assert_np_equal(2 * m[1, 1] * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 0], tol=tol)
902
+ assert_np_equal(-2 * m[0, 0] * m[1, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 0], tol=tol)
903
+ assert_np_equal(2 * m[0, 0] * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 1], tol=tol)
904
+ assert_np_equal(-2 * m[1, 0] * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 1], tol=tol)
905
+ tape.zero()
906
+
907
+ # 1,1 component is this:
908
+ # 2 * m[0,0] / (m[0,0]*m[1,1] - m[1,0] * m[0,1])
909
+ assert_np_equal(2 * m[0, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]), outcomponents.numpy()[3], tol=tol)
910
+
911
+ tape = wp.Tape()
912
+ with tape:
913
+ wp.launch(kernel, dim=1, inputs=[m2, m3, m4], outputs=[outcomponents], device=device)
914
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, 3], outputs=[out], device=device)
915
+
916
+ if dtype in np_float_types:
917
+ tape.backward(loss=out)
918
+ g = tape.gradients[m2].numpy()[0]
919
+ assert_np_equal(-2 * m[0, 1] * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 0], tol=tol)
920
+ assert_np_equal(2 * m[0, 0] * m[0, 1] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 0], tol=tol)
921
+ assert_np_equal(2 * m[0, 0] * m[1, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[0, 1], tol=tol)
922
+ assert_np_equal(-2 * m[0, 0] * m[0, 0] / (m[0, 0] * m[1, 1] - m[1, 0] * m[0, 1]) ** 2, g[1, 1], tol=tol)
923
+ tape.zero()
924
+
925
+
926
+ def test_svd(test, device, dtype, register_kernels=False):
927
+ rng = np.random.default_rng(123)
928
+
929
+ tol = {
930
+ np.float16: 1.0e-3,
931
+ np.float32: 1.0e-6,
932
+ np.float64: 1.0e-12,
933
+ }.get(dtype, 0)
934
+
935
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
936
+ vec3 = wp.types.vector(length=3, dtype=wptype)
937
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
938
+
939
+ def check_mat_svd(
940
+ m3: wp.array(dtype=mat33),
941
+ Uout: wp.array(dtype=mat33),
942
+ sigmaout: wp.array(dtype=vec3),
943
+ Vout: wp.array(dtype=mat33),
944
+ outcomponents: wp.array(dtype=wptype),
945
+ ):
946
+ U = mat33()
947
+ sigma = vec3()
948
+ V = mat33()
949
+
950
+ wp.svd3(m3[0], U, sigma, V)
951
+
952
+ Uout[0] = U
953
+ sigmaout[0] = sigma
954
+ Vout[0] = V
955
+
956
+ # multiply outputs by 2 so we've got something to backpropagate:
957
+ idx = 0
958
+ for i in range(3):
959
+ for j in range(3):
960
+ outcomponents[idx] = wptype(2) * U[i, j]
961
+ idx = idx + 1
962
+
963
+ for i in range(3):
964
+ outcomponents[idx] = wptype(2) * sigma[i]
965
+ idx = idx + 1
966
+
967
+ for i in range(3):
968
+ for j in range(3):
969
+ outcomponents[idx] = wptype(2) * V[i, j]
970
+ idx = idx + 1
971
+
972
+ kernel = getkernel(check_mat_svd, suffix=dtype.__name__)
973
+
974
+ output_select_kernel = get_select_kernel(wptype)
975
+
976
+ if register_kernels:
977
+ return
978
+
979
+ m3 = wp.array(randvals(rng, [1, 3, 3], dtype) + np.eye(3), dtype=mat33, requires_grad=True, device=device)
980
+
981
+ outcomponents = wp.zeros(2 * 3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
982
+ Uout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
983
+ sigmaout = wp.zeros(1, dtype=vec3, requires_grad=True, device=device)
984
+ Vout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
985
+
986
+ wp.launch(kernel, dim=1, inputs=[m3], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
987
+
988
+ Uout_np = Uout.numpy()[0].astype(np.float64)
989
+ sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
990
+ Vout_np = Vout.numpy()[0].astype(np.float64)
991
+
992
+ assert_np_equal(
993
+ np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m3.numpy()[0].astype(np.float64), tol=30 * tol
994
+ )
995
+
996
+ if dtype == np.float16:
997
+ # I'm not even going to bother testing the gradients for float16
998
+ # because the rounding errors are terrible...
999
+ return
1000
+
1001
+ # check gradients:
1002
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1003
+ idx = 0
1004
+ for idx in range(3 * 3 + 3 + 3 * 3):
1005
+ tape = wp.Tape()
1006
+ with tape:
1007
+ wp.launch(kernel, dim=1, inputs=[m3], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1008
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1009
+ tape.backward(out)
1010
+ m3grads = 1.0 * tape.gradients[m3].numpy()[0]
1011
+
1012
+ tape.zero()
1013
+
1014
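+ # finite-difference check: perturb each input component by +/- dx and compare
+ # (f(m + dx) - f(m - dx)) / (2 * dx) against the tape gradient computed above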
+ dx = 0.0001
1015
+ fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
1016
+ for ii in range(3):
1017
+ for jj in range(3):
1018
+ m3test = 1.0 * m3.numpy()
1019
+ m3test[0, ii, jj] += dx
1020
+ wp.launch(
1021
+ kernel,
1022
+ dim=1,
1023
+ inputs=[wp.array(m3test, dtype=mat33, device=device)],
1024
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1025
+ device=device,
1026
+ )
1027
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1028
+ plusval = out.numpy()[0]
1029
+
1030
+ m3test = 1.0 * m3.numpy()
1031
+ m3test[0, ii, jj] -= dx
1032
+ wp.launch(
1033
+ kernel,
1034
+ dim=1,
1035
+ inputs=[wp.array(m3test, dtype=mat33, device=device)],
1036
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1037
+ device=device,
1038
+ )
1039
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1040
+ minusval = out.numpy()[0]
1041
+
1042
+ assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
1043
+
1044
+
1045
+ def test_svd_2D(test, device, dtype, register_kernels=False):
1046
+ rng = np.random.default_rng(123)
1047
+
1048
+ tol = {
1049
+ np.float16: 1.0e-3,
1050
+ np.float32: 1.0e-6,
1051
+ np.float64: 1.0e-12,
1052
+ }.get(dtype, 0)
1053
+
1054
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1055
+ vec2 = wp.types.vector(length=2, dtype=wptype)
1056
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1057
+
1058
+ def check_mat_svd2(
1059
+ m2: wp.array(dtype=mat22),
1060
+ Uout: wp.array(dtype=mat22),
1061
+ sigmaout: wp.array(dtype=vec2),
1062
+ Vout: wp.array(dtype=mat22),
1063
+ outcomponents: wp.array(dtype=wptype),
1064
+ ):
1065
+ U = mat22()
1066
+ sigma = vec2()
1067
+ V = mat22()
1068
+
1069
+ wp.svd2(m2[0], U, sigma, V)  # 2x2 SVD, analogous to wp.svd3 above
1070
+
1071
+ Uout[0] = U
1072
+ sigmaout[0] = sigma
1073
+ Vout[0] = V
1074
+
1075
+ # multiply outputs by 2 so we've got something to backpropagate:
1076
+ idx = 0
1077
+ for i in range(2):
1078
+ for j in range(2):
1079
+ outcomponents[idx] = wptype(2) * U[i, j]
1080
+ idx = idx + 1
1081
+
1082
+ for i in range(2):
1083
+ outcomponents[idx] = wptype(2) * sigma[i]
1084
+ idx = idx + 1
1085
+
1086
+ for i in range(2):
1087
+ for j in range(2):
1088
+ outcomponents[idx] = wptype(2) * V[i, j]
1089
+ idx = idx + 1
1090
+
1091
+ kernel = getkernel(check_mat_svd2, suffix=dtype.__name__)
1092
+
1093
+ output_select_kernel = get_select_kernel(wptype)
1094
+
1095
+ if register_kernels:
1096
+ return
1097
+
1098
+ m2 = wp.array(randvals(rng, [1, 2, 2], dtype) + np.eye(2), dtype=mat22, requires_grad=True, device=device)
1099
+
1100
+ outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
1101
+ Uout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1102
+ sigmaout = wp.zeros(1, dtype=vec2, requires_grad=True, device=device)
1103
+ Vout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1104
+
1105
+ wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1106
+
1107
+ Uout_np = Uout.numpy()[0].astype(np.float64)
1108
+ sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
1109
+ Vout_np = Vout.numpy()[0].astype(np.float64)
1110
+
1111
+ assert_np_equal(
1112
+ np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m2.numpy()[0].astype(np.float64), tol=30 * tol
1113
+ )
1114
+
1115
+ if dtype == np.float16:
1116
+ # Skip gradient check for float16 due to rounding errors
1117
+ return
1118
+
1119
+ # Check gradients:
1120
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1121
+ idx = 0
1122
+ for idx in range(2 * 2 + 2 + 2 * 2):
1123
+ tape = wp.Tape()
1124
+ with tape:
1125
+ wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1126
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1127
+ tape.backward(out)
1128
+ m2grads = 1.0 * tape.gradients[m2].numpy()[0]
1129
+
1130
+ tape.zero()
1131
+
1132
+ dx = 0.0001
1133
+ fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
1134
+ for ii in range(2):
1135
+ for jj in range(2):
1136
+ m2test = 1.0 * m2.numpy()
1137
+ m2test[0, ii, jj] += dx
1138
+ wp.launch(
1139
+ kernel,
1140
+ dim=1,
1141
+ inputs=[wp.array(m2test, dtype=mat22, device=device)],
1142
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1143
+ device=device,
1144
+ )
1145
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1146
+ plusval = out.numpy()[0]
1147
+
1148
+ m2test = 1.0 * m2.numpy()
1149
+ m2test[0, ii, jj] -= dx
1150
+ wp.launch(
1151
+ kernel,
1152
+ dim=1,
1153
+ inputs=[wp.array(m2test, dtype=mat22, device=device)],
1154
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1155
+ device=device,
1156
+ )
1157
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1158
+ minusval = out.numpy()[0]
1159
+
1160
+ assert_np_equal((plusval - minusval) / (2 * dx), m2grads[ii, jj], tol=fdtol)
1161
+
1162
+
1163
+ def test_qr(test, device, dtype, register_kernels=False):
1164
+ rng = np.random.default_rng(123)
1165
+
1166
+ tol = {
1167
+ np.float16: 2.0e-3,
1168
+ np.float32: 1.0e-6,
1169
+ np.float64: 1.0e-6,
1170
+ }.get(dtype, 0)
1171
+
1172
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1173
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1174
+
1175
+ def check_mat_qr(
1176
+ m3: wp.array(dtype=mat33),
1177
+ Qout: wp.array(dtype=mat33),
1178
+ Rout: wp.array(dtype=mat33),
1179
+ outcomponents: wp.array(dtype=wptype),
1180
+ ):
1181
+ Q = mat33()
1182
+ R = mat33()
1183
+
1184
+ wp.qr3(m3[0], Q, R)
1185
+
1186
+ Qout[0] = Q
1187
+ Rout[0] = R
1188
+
1189
+ # multiply outputs by 2 so we've got something to backpropagate:
1190
+ idx = 0
1191
+ for i in range(3):
1192
+ for j in range(3):
1193
+ outcomponents[idx] = wptype(2) * Q[i, j]
1194
+ idx = idx + 1
1195
+
1196
+ for i in range(3):
1197
+ for j in range(3):
1198
+ outcomponents[idx] = wptype(2) * R[i, j]
1199
+ idx = idx + 1
1200
+
1201
+ kernel = getkernel(check_mat_qr, suffix=dtype.__name__)
1202
+ output_select_kernel = get_select_kernel(wptype)
1203
+
1204
+ if register_kernels:
1205
+ return
1206
+
1207
+ m3 = wp.array(0.5 * (randvals(rng, [1, 3, 3], dtype) + np.eye(3)), dtype=mat33, requires_grad=True, device=device)
1208
+
1209
+ outcomponents = wp.zeros(2 * 3 * 3, dtype=wptype, requires_grad=True, device=device)
1210
+ Qout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
1211
+ Rout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
1212
+
1213
+ wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, Rout, outcomponents], device=device)
1214
+
1215
+ Qout_np = Qout.numpy()[0].astype(np.float64)
1216
+ Rout_np = Rout.numpy()[0].astype(np.float64)
1217
+
1218
+ # check it's actually a q and an r:
1219
+ assert_np_equal(np.matmul(Qout_np.T, Qout_np), np.eye(3, dtype=np.float64), tol=tol)
1220
+ assert_np_equal(Rout_np[1, [0]], np.zeros(1, dtype=np.float64), tol=tol)
1221
+ assert_np_equal(Rout_np[2, [0, 1]], np.zeros(2, dtype=np.float64), tol=tol)
1222
+
1223
+ # check it's a factorization:
1224
+ assert_np_equal(np.matmul(Qout_np, Rout_np), m3.numpy()[0].astype(np.float64), tol=30 * tol)
1225
+
1226
+ if dtype == np.float16:
1227
+ # I'm not even going to bother testing the gradients for float16
1228
+ # because the rounding errors are terrible...
1229
+ return
1230
+
1231
+ # check gradients:
1232
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1233
+ idx = 0
1234
+ for idx in range(len(outcomponents)):
1235
+ tape = wp.Tape()
1236
+ with tape:
1237
+ wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, Rout, outcomponents], device=device)
1238
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1239
+ tape.backward(out)
1240
+ m3grads = 1.0 * tape.gradients[m3].numpy()[0]
1241
+
1242
+ tape.zero()
1243
+
1244
+ dx = 0.0001
1245
+ fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
1246
+ for ii in range(3):
1247
+ for jj in range(3):
1248
+ m3test = 1.0 * m3.numpy()
1249
+ m3test[0, ii, jj] += dx
1250
+ wp.launch(
1251
+ kernel,
1252
+ dim=1,
1253
+ inputs=[wp.array(m3test, dtype=mat33, device=device)],
1254
+ outputs=[Qout, Rout, outcomponents],
1255
+ device=device,
1256
+ )
1257
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1258
+ plusval = out.numpy()[0]
1259
+
1260
+ m3test = 1.0 * m3.numpy()
1261
+ m3test[0, ii, jj] -= dx
1262
+ wp.launch(
1263
+ kernel,
1264
+ dim=1,
1265
+ inputs=[wp.array(m3test, dtype=mat33, device=device)],
1266
+ outputs=[Qout, Rout, outcomponents],
1267
+ device=device,
1268
+ )
1269
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1270
+ minusval = out.numpy()[0]
1271
+
1272
+ assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
1273
+
1274
+
1275
+ def test_eig(test, device, dtype, register_kernels=False):
1276
+ rng = np.random.default_rng(123)
1277
+
1278
+ tol = {
1279
+ np.float16: 4.0e-2,
1280
+ np.float32: 1.0e-5,
1281
+ np.float64: 1.0e-5,
1282
+ }.get(dtype, 0)
1283
+
1284
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1285
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1286
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
1287
+
1288
+ def check_mat_eig(
1289
+ m3: wp.array(dtype=mat33),
1290
+ Qout: wp.array(dtype=mat33),
1291
+ dout: wp.array(dtype=vec3),
1292
+ outcomponents: wp.array(dtype=wptype),
1293
+ ):
1294
+ Q = mat33()
1295
+ d = vec3()
1296
+
1297
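+ # symmetrize the input (m + m^T): wp.eig3 computes the eigendecomposition of a
+ # symmetric matrix, and the reconstruction check below compares against m + m^T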
+ wp.eig3(m3[0] + wp.transpose(m3[0]), Q, d)
1298
+
1299
+ Qout[0] = Q
1300
+ dout[0] = d
1301
+
1302
+ # multiply outputs by 2 so we've got something to backpropagate:
1303
+ idx = 0
1304
+ for i in range(3):
1305
+ for j in range(3):
1306
+ outcomponents[idx] = wptype(2) * Q[i, j]
1307
+ idx = idx + 1
1308
+
1309
+ for i in range(3):
1310
+ outcomponents[idx] = wptype(2) * d[i]
1311
+ idx = idx + 1
1312
+
1313
+ kernel = getkernel(check_mat_eig, suffix=dtype.__name__)
1314
+ output_select_kernel = get_select_kernel(wptype)
1315
+
1316
+ if register_kernels:
1317
+ return
1318
+
1319
+ m3_np = randvals(rng, [1, 3, 3], dtype) + np.eye(3, dtype=dtype)
1320
+ m3 = wp.array(m3_np, dtype=mat33, requires_grad=True, device=device)
1321
+
1322
+ outcomponents = wp.zeros(3 * 3 + 3, dtype=wptype, requires_grad=True, device=device)
1323
+ Qout = wp.zeros(1, dtype=mat33, requires_grad=True, device=device)
1324
+ dout = wp.zeros(1, dtype=vec3, requires_grad=True, device=device)
1325
+
1326
+ wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, dout, outcomponents], device=device)
1327
+
1328
+ Qout_np = Qout.numpy()[0].astype(np.float64)
1329
+ dout_np = dout.numpy()[0].astype(np.float64)
1330
+ Dout_np = np.diag(dout_np)
1331
+
1332
+ # check Q is orthogonal:
1333
+ assert_np_equal(np.matmul(Qout_np.T, Qout_np), np.eye(3), tol=tol)
1334
+
1335
+ # check Q contains eigenvectors:
1336
+ assert_np_equal(np.matmul(Qout_np, np.matmul(Dout_np, Qout_np.T)), (m3_np[0] + m3_np[0].transpose()), tol=tol)
1337
+
1338
+ if dtype == np.float16:
1339
+ # I'm not even going to bother testing the gradients for float16
1340
+ # because the rounding errors are terrible...
1341
+ return
1342
+
1343
+ # check gradients:
1344
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1345
+ idx = 0
1346
+ for idx in range(len(outcomponents)):
1347
+ tape = wp.Tape()
1348
+ with tape:
1349
+ wp.launch(kernel, dim=1, inputs=[m3], outputs=[Qout, dout, outcomponents], device=device)
1350
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1351
+ tape.backward(out)
1352
+ m3grads = 1.0 * tape.gradients[m3].numpy()[0]
1353
+
1354
+ tape.zero()
1355
+
1356
+ dx = 0.0001
1357
+ fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
1358
+ for ii in range(3):
1359
+ for jj in range(3):
1360
+ m3test = 1.0 * m3.numpy()
1361
+ m3test[0, ii, jj] += dx
1362
+ wp.launch(
1363
+ kernel,
1364
+ dim=1,
1365
+ inputs=[wp.array(m3test, dtype=mat33, device=device)],
1366
+ outputs=[Qout, dout, outcomponents],
1367
+ device=device,
1368
+ )
1369
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1370
+ plusval = out.numpy()[0]
1371
+
1372
+ m3test = 1.0 * m3.numpy()
1373
+ m3test[0, ii, jj] -= dx
1374
+ wp.launch(
1375
+ kernel,
1376
+ dim=1,
1377
+ inputs=[wp.array(m3test, dtype=mat33, device=device)],
1378
+ outputs=[Qout, dout, outcomponents],
1379
+ device=device,
1380
+ )
1381
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1382
+ minusval = out.numpy()[0]
1383
+
1384
+ assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
1385
+
1386
+
1387
+ def test_skew(test, device, dtype, register_kernels=False):
1388
+ rng = np.random.default_rng(123)
1389
+
1390
+ tol = {
1391
+ np.float16: 1.0e-3,
1392
+ np.float32: 1.0e-6,
1393
+ np.float64: 1.0e-8,
1394
+ }.get(dtype, 0)
1395
+
1396
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1397
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1398
+
1399
+ output_select_kernel = get_select_kernel(wptype)
1400
+
1401
+ def check_mat_skew(
1402
+ v3: wp.array(dtype=vec3),
1403
+ outcomponents: wp.array(dtype=wptype),
1404
+ ):
1405
+ m3result = wp.skew(v3[0])
1406
+
1407
+ # multiply outputs by 2 so we've got something to backpropagate:
1408
+ idx = 0
1409
+ for i in range(3):
1410
+ for j in range(3):
1411
+ outcomponents[idx] = wptype(2) * m3result[i, j]
1412
+ idx = idx + 1
1413
+
1414
+ kernel = getkernel(check_mat_skew, suffix=dtype.__name__)
1415
+
1416
+ if register_kernels:
1417
+ return
1418
+
1419
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1420
+
1421
+ outcomponents = wp.zeros(3 * 3, dtype=wptype, requires_grad=True, device=device)
1422
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1423
+
1424
+ wp.launch(kernel, dim=1, inputs=[v3], outputs=[outcomponents], device=device)
1425
+
1426
+ # make sure it gives you a cross product matrix:
1427
+ crossprodmat = outcomponents.numpy().reshape(3, 3)
1428
+ v = np.array([1, 0, 0])
1429
+ assert_np_equal(
1430
+ np.matmul(crossprodmat, np.array([1, 0, 0])).reshape(-1),
1431
+ 2 * np.cross(v3.numpy()[0], np.array([1, 0, 0])),
1432
+ tol=tol,
1433
+ )
1434
+ assert_np_equal(
1435
+ np.matmul(crossprodmat, np.array([0, 1, 0])).reshape(-1),
1436
+ 2 * np.cross(v3.numpy()[0], np.array([0, 1, 0])),
1437
+ tol=tol,
1438
+ )
1439
+ assert_np_equal(
1440
+ np.matmul(crossprodmat, np.array([0, 0, 1])).reshape(-1),
1441
+ 2 * np.cross(v3.numpy()[0], np.array([0, 0, 1])),
1442
+ tol=tol,
1443
+ )
1444
+
1445
+ # check it another way:
1446
+ x0 = v3.numpy()[0, 0]
1447
+ x1 = v3.numpy()[0, 1]
1448
+ x2 = v3.numpy()[0, 2]
1449
+ crossprodmat_expected = np.array(
1450
+ [
1451
+ [0, -x2, x1],
1452
+ [x2, 0, -x0],
1453
+ [-x1, x0, 0],
1454
+ ],
1455
+ dtype=dtype,
1456
+ )
1457
+ assert_np_equal(crossprodmat, 2 * crossprodmat_expected, tol=tol)
1458
+
1459
+ if dtype in np_float_types:
1460
+ idx = 0
1461
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1462
+
1463
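+ # each entry of 2 * skew(v) is either zero or +/- 2 * v[k], so the gradient of
+ # any single output component with respect to v is a one-hot vector scaled by +/- 2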
+ for i in range(3):
1464
+ for j in range(3):
1465
+ tape = wp.Tape()
1466
+ with tape:
1467
+ wp.launch(kernel, dim=1, inputs=[v3], outputs=[outcomponents], device=device)
1468
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1469
+ tape.backward(loss=out)
1470
+ if i == j:
1471
+ assert_np_equal(tape.gradients[v3].numpy()[0], np.zeros(3))
1472
+ elif [i, j] == [0, 1]:
1473
+ assert_np_equal(tape.gradients[v3].numpy()[0], np.array([0, 0, -2]))
1474
+ elif [i, j] == [1, 0]:
1475
+ assert_np_equal(tape.gradients[v3].numpy()[0], np.array([0, 0, 2]))
1476
+ elif [i, j] == [0, 2]:
1477
+ assert_np_equal(tape.gradients[v3].numpy()[0], np.array([0, 2, 0]))
1478
+ elif [i, j] == [2, 0]:
1479
+ assert_np_equal(tape.gradients[v3].numpy()[0], np.array([0, -2, 0]))
1480
+ elif [i, j] == [1, 2]:
1481
+ assert_np_equal(tape.gradients[v3].numpy()[0], np.array([-2, 0, 0]))
1482
+ elif [i, j] == [2, 1]:
1483
+ assert_np_equal(tape.gradients[v3].numpy()[0], np.array([2, 0, 0]))
1484
+ tape.zero()
1485
+
1486
+ idx = idx + 1
1487
+
1488
+
1489
+ def test_transform_point(test, device, dtype, register_kernels=False):
1490
+ rng = np.random.default_rng(123)
1491
+
1492
+ tol = {
1493
+ np.float16: 5.0e-3,
1494
+ np.float32: 1.0e-6,
1495
+ np.float64: 1.0e-8,
1496
+ }.get(dtype, 0)
1497
+
1498
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1499
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1500
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1501
+
1502
+ output_select_kernel = get_select_kernel(wptype)
1503
+
1504
+ def check_mat_transform_point(
1505
+ v3: wp.array(dtype=vec3),
1506
+ m4: wp.array(dtype=mat44),
1507
+ outcomponents: wp.array(dtype=wptype),
1508
+ ):
1509
+ # multiply outputs by 2 so we've got something to backpropagate:
1510
+ presult = wptype(2) * wp.transform_point(m4[0], v3[0])
1511
+
1512
+ outcomponents[0] = presult[0]
1513
+ outcomponents[1] = presult[1]
1514
+ outcomponents[2] = presult[2]
1515
+
1516
+ kernel = getkernel(check_mat_transform_point, suffix=dtype.__name__)
1517
+
1518
+ if register_kernels:
1519
+ return
1520
+
1521
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1522
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1523
+
1524
+ outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1525
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1526
+
1527
+ wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)
1528
+
1529
+ v3homog = np.ones(4, dtype=dtype)
1530
+ v3homog[:3] = v3.numpy()[0]
1531
+ assert_np_equal(outcomponents.numpy(), 2 * np.matmul(m4.numpy()[0], v3homog)[:3], tol=10 * tol)
1532
+
1533
+ if dtype in np_float_types:
1534
+ for j in range(3):
1535
+ tape = wp.Tape()
1536
+ with tape:
1537
+ wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)
1538
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, j], outputs=[out], device=device)
1539
+ tape.backward(loss=out)
1540
+
1541
+ assert_np_equal(2 * m4.numpy()[0, j, :3], tape.gradients[v3].numpy(), tol=tol)
1542
+ expected = np.zeros((4, 4), dtype=dtype)
1543
+ expected[j, :3] = 2 * v3.numpy()
1544
+ expected[j, 3] = 2
1545
+ assert_np_equal(tape.gradients[m4].numpy(), expected, tol=tol)
1546
+
1547
+ tape.zero()
1548
+
1549
+
1550
+ def test_transform_vector(test, device, dtype, register_kernels=False):
1551
+ rng = np.random.default_rng(123)
1552
+
1553
+ tol = {
1554
+ np.float16: 5.0e-3,
1555
+ np.float32: 1.0e-6,
1556
+ np.float64: 1.0e-8,
1557
+ }.get(dtype, 0)
1558
+
1559
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1560
+ vec3 = wp.types.vector(length=3, dtype=wptype)
1561
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
1562
+
1563
+ output_select_kernel = get_select_kernel(wptype)
1564
+
1565
+ def check_mat_transform_vector(
1566
+ v3: wp.array(dtype=vec3),
1567
+ m4: wp.array(dtype=mat44),
1568
+ outcomponents: wp.array(dtype=wptype),
1569
+ ):
1570
+ # multiply outputs by 2 so we've got something to backpropagate:
1571
+ presult = wptype(2) * wp.transform_vector(m4[0], v3[0])
1572
+
1573
+ outcomponents[0] = presult[0]
1574
+ outcomponents[1] = presult[1]
1575
+ outcomponents[2] = presult[2]
1576
+
1577
+ kernel = getkernel(check_mat_transform_vector, suffix=dtype.__name__)
1578
+
1579
+ if register_kernels:
1580
+ return
1581
+
1582
+ v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
1583
+ m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
1584
+
1585
+ outcomponents = wp.zeros(3, dtype=wptype, requires_grad=True, device=device)
1586
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1587
+
1588
+ wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)
1589
+
1590
+ v3homog = np.zeros(4, dtype=dtype)
1591
+ v3homog[:3] = v3.numpy()[0]
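+ # the homogeneous w component is 0 here (vs. 1 in test_transform_point), so the
+ # translation column of m4 does not contribute to the expected result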
1592
+ assert_np_equal(outcomponents.numpy(), 2 * np.matmul(m4.numpy()[0], v3homog)[:3], tol=10 * tol)
1593
+
1594
+ if dtype in np_float_types:
1595
+ for j in range(3):
1596
+ tape = wp.Tape()
1597
+ with tape:
1598
+ wp.launch(kernel, dim=1, inputs=[v3, m4], outputs=[outcomponents], device=device)
1599
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, j], outputs=[out], device=device)
1600
+ tape.backward(loss=out)
1601
+
1602
+ assert_np_equal(2 * m4.numpy()[0, j, :3], tape.gradients[v3].numpy(), tol=tol)
1603
+ expected = np.zeros((4, 4), dtype=dtype)
1604
+ expected[j, :3] = 2 * v3.numpy()
1605
+ assert_np_equal(tape.gradients[m4].numpy(), expected, tol=tol)
1606
+
1607
+ tape.zero()
1608
+
1609
+
1610
+ def test_matrix_assign_inplace(test, device, dtype, register_kernels=False):
1611
+ np_type = np.dtype(dtype)
1612
+ wp_type = wp.types.np_dtype_to_warp_type[np_type]
1613
+
1614
+ vec2 = wp.types.vector(length=2, dtype=wp_type)
1615
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wp_type)
1616
+
1617
+ def mattest_read_write_store(x: wp.array(dtype=wp_type), a: wp.array(dtype=mat22)):
1618
+ tid = wp.tid()
1619
+
1620
+ t = a[tid]
1621
+ t[0, 0] = x[tid]
1622
+ a[tid] = t
1623
+
1624
+ def mattest_in_register(x: wp.array2d(dtype=mat22), y: wp.array(dtype=vec2)):
1625
+ i, j = wp.tid()
1626
+
1627
+ a = mat22(wp_type(0.0))
1628
+ a[0] = y[i]
1629
+ a[1, 1] = wp_type(3.0)
1630
+ x[i, j] = a
1631
+
1632
+ kernel_read_write_store = getkernel(mattest_read_write_store, suffix=dtype.__name__)
1633
+ kernel_in_register = getkernel(mattest_in_register, suffix=dtype.__name__)
1634
+
1635
+ if register_kernels:
1636
+ return
1637
+
1638
+ a = wp.ones(1, dtype=mat22, device=device, requires_grad=True)
1639
+ x = wp.full(1, value=2.0, dtype=wp_type, device=device, requires_grad=True)
1640
+
1641
+ tape = wp.Tape()
1642
+ with tape:
1643
+ wp.launch(kernel_read_write_store, dim=1, inputs=[x, a], device=device)
1644
+
1645
+ tape.backward(grads={a: wp.ones_like(a, requires_grad=False)})
1646
+
1647
+ assert_np_equal(a.numpy(), np.array([[[2.0, 1.0], [1.0, 1.0]]], dtype=np_type))
1648
+ assert_np_equal(x.grad.numpy(), np.array([1.0], dtype=np_type))
1649
+
1650
+ tape.reset()
1651
+
1652
+ x = wp.zeros((1, 1), dtype=mat22, device=device, requires_grad=True)
1653
+ y = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
1654
+
1655
+ with tape:
1656
+ wp.launch(kernel_in_register, dim=(1, 1), inputs=[x, y], device=device)
1657
+
1658
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1659
+
1660
+ assert_np_equal(x.numpy(), np.array([[[[1.0, 1.0], [0.0, 3.0]]]], dtype=np_type))
1661
+ assert_np_equal(y.grad.numpy(), np.array([[1.0, 1.0]], dtype=np_type))
1662
+
1663
+
1664
+ # Test matrix constructors using explicit type (float16)
1665
+ # note that these tests are specifically not using generics / closure
1666
+ # args to create kernels dynamically (like the rest of this file)
1667
+ # as those use different code paths to resolve arg types which
1668
+ # has led to regressions.
1669
+ @wp.kernel
1670
+ def test_constructors_explicit_precision():
1671
+ # construction for custom matrix types
1672
+ eye = wp.identity(dtype=wp.float16, n=2)
1673
+ zeros = wp.matrix(shape=(2, 2), dtype=wp.float16)
1674
+ custom = wp.matrix(wp.float16(0.0), wp.float16(1.0), wp.float16(2.0), wp.float16(3.0), shape=(2, 2))
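+ # components are given in row-major order, so custom[i, j] == i * 2 + j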
1675
+
1676
+ for i in range(2):
1677
+ for j in range(2):
1678
+ if i == j:
1679
+ wp.expect_eq(eye[i, j], wp.float16(1.0))
1680
+ else:
1681
+ wp.expect_eq(eye[i, j], wp.float16(0.0))
1682
+
1683
+ wp.expect_eq(zeros[i, j], wp.float16(0.0))
1684
+ wp.expect_eq(custom[i, j], wp.float16(i) * wp.float16(2.0) + wp.float16(j))
1685
+
1686
+
1687
+ mat32d = wp.mat(shape=(3, 2), dtype=wp.float64)
1688
+
1689
+
1690
+ @wp.kernel
1691
+ def test_matrix_constructor_value_func():
1692
+ a = wp.mat22()
1693
+ b = wp.matrix(a, shape=(2, 2))
1694
+ c = mat32d()
1695
+ d = mat32d(c, shape=(3, 2))
1696
+ e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))
1697
+
1698
+
1699
+ @wp.kernel
1700
+ def test_matrix_from_vecs():
1701
+ m1 = wp.matrix_from_cols(
1702
+ wp.vec3(1.0, 2.0, 3.0),
1703
+ wp.vec3(4.0, 5.0, 6.0),
1704
+ wp.vec3(7.0, 8.0, 9.0),
1705
+ )
1706
+ wp.expect_eq(m1[0, 0], 1.0)
1707
+ wp.expect_eq(m1[0, 1], 4.0)
1708
+ wp.expect_eq(m1[0, 2], 7.0)
1709
+ wp.expect_eq(m1[1, 0], 2.0)
1710
+ wp.expect_eq(m1[1, 1], 5.0)
1711
+ wp.expect_eq(m1[1, 2], 8.0)
1712
+ wp.expect_eq(m1[2, 0], 3.0)
1713
+ wp.expect_eq(m1[2, 1], 6.0)
1714
+ wp.expect_eq(m1[2, 2], 9.0)
1715
+
1716
+ m2 = wp.matrix_from_rows(
1717
+ wp.vec3(1.0, 2.0, 3.0),
1718
+ wp.vec3(4.0, 5.0, 6.0),
1719
+ wp.vec3(7.0, 8.0, 9.0),
1720
+ )
1721
+ wp.expect_eq(m2[0, 0], 1.0)
1722
+ wp.expect_eq(m2[0, 1], 2.0)
1723
+ wp.expect_eq(m2[0, 2], 3.0)
1724
+ wp.expect_eq(m2[1, 0], 4.0)
1725
+ wp.expect_eq(m2[1, 1], 5.0)
1726
+ wp.expect_eq(m2[1, 2], 6.0)
1727
+ wp.expect_eq(m2[2, 0], 7.0)
1728
+ wp.expect_eq(m2[2, 1], 8.0)
1729
+ wp.expect_eq(m2[2, 2], 9.0)
1730
+
1731
+ m3 = wp.matrix_from_cols(
1732
+ wp.vec3(1.0, 2.0, 3.0),
1733
+ wp.vec3(4.0, 5.0, 6.0),
1734
+ )
1735
+ wp.expect_eq(m3[0, 0], 1.0)
1736
+ wp.expect_eq(m3[0, 1], 4.0)
1737
+ wp.expect_eq(m3[1, 0], 2.0)
1738
+ wp.expect_eq(m3[1, 1], 5.0)
1739
+ wp.expect_eq(m3[2, 0], 3.0)
1740
+ wp.expect_eq(m3[2, 1], 6.0)
1741
+
1742
+ m4 = wp.matrix_from_rows(
1743
+ wp.vec3(1.0, 2.0, 3.0),
1744
+ wp.vec3(4.0, 5.0, 6.0),
1745
+ )
1746
+ wp.expect_eq(m4[0, 0], 1.0)
1747
+ wp.expect_eq(m4[0, 1], 2.0)
1748
+ wp.expect_eq(m4[0, 2], 3.0)
1749
+ wp.expect_eq(m4[1, 0], 4.0)
1750
+ wp.expect_eq(m4[1, 1], 5.0)
1751
+ wp.expect_eq(m4[1, 2], 6.0)
1752
+
1753
+
1754
+ # Same as above but with a default (float/int) type
1755
+ # which tests some different code paths that
1756
+ # need to ensure types are correctly canonicalized
1757
+ # during codegen
1758
+ @wp.kernel
1759
+ def test_constructors_default_precision():
1760
+ # construction for default (float) matrix types
1761
+ eye = wp.identity(dtype=float, n=2)
1762
+ zeros = wp.matrix(shape=(2, 2), dtype=float)
1763
+ custom = wp.matrix(0.0, 1.0, 2.0, 3.0, shape=(2, 2))
1764
+
1765
+ for i in range(2):
1766
+ for j in range(2):
1767
+ if i == j:
1768
+ wp.expect_eq(eye[i, j], 1.0)
1769
+ else:
1770
+ wp.expect_eq(eye[i, j], 0.0)
1771
+
1772
+ wp.expect_eq(zeros[i, j], 0.0)
1773
+ wp.expect_eq(custom[i, j], float(i) * 2.0 + float(j))
1774
+
1775
+
1776
+ @wp.kernel
1777
+ def test_matrix_mutation(expected: wp.types.matrix(shape=(10, 3), dtype=float)):
1778
+ m = wp.matrix(shape=(10, 3), dtype=float)
1779
+
1780
+ # test direct element indexing
1781
+ m[0, 0] = 1.0
1782
+ m[0, 1] = 2.0
1783
+ m[0, 2] = 3.0
1784
+
1785
+ # The nested indexing (matrix->vector->scalar) below does not
1786
+ # currently modify m because m[0] returns row vector by
1787
+ # value rather than reference, this is different from NumPy
1788
+ # which always returns by ref. Not clear how we can support
1789
+ # this as well as auto-diff.
1790
+
1791
+ # m[0][1] = 2.0
1792
+ # m[0][2] = 3.0
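+ #
+ # A pattern that does work (a sketch, not exercised by this kernel) is to copy
+ # the row out, modify the copy, and assign it back:
+ #   row = m[0]
+ #   row[1] = 2.0
+ #   m[0] = row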
1793
+
1794
+ # test setting rows
1795
+ for i in range(1, 10):
1796
+ m[i] = m[i - 1] + wp.vec3(1.0, 2.0, 3.0)
1797
+
1798
+ wp.expect_eq(m, expected)
1799
+
1800
+
1801
+ # NOTE: Compile time is highly sensitive to shape, so we use small values here
1802
+ CONSTANT_SHAPE_ROWS = wp.constant(2)
1803
+ CONSTANT_SHAPE_COLS = wp.constant(2)
1804
+
1805
+
1806
+ # tests that we can use global constants in shape keyword argument
1807
+ # for matrix constructor
1808
+ @wp.kernel
1809
+ def test_constructors_constant_shape():
1810
+ m = wp.matrix(shape=(CONSTANT_SHAPE_ROWS, CONSTANT_SHAPE_COLS), dtype=float)
1811
+
1812
+ for i in range(CONSTANT_SHAPE_ROWS):
1813
+ for j in range(CONSTANT_SHAPE_COLS):
1814
+ m[i, j] = float(i * j)
1815
+
1816
+
1817
+ Mat23 = wp.mat((2, 3), dtype=wp.float16)
1818
+
1819
+
1820
+ @wp.kernel
1821
+ def matrix_len_kernel(
1822
+ m1: wp.mat22, m2: wp.mat((3, 3), float), m3: wp.mat((Any, Any), float), m4: Mat23, out: wp.array(dtype=int)
1823
+ ):
1824
+ length = wp.static(len(m1))
1825
+ wp.expect_eq(len(m1), 2)
1826
+ out[0] = len(m1)
1827
+
1828
+ length = len(m2)
1829
+ wp.expect_eq(wp.static(len(m2)), 3)
1830
+ out[1] = len(m2)
1831
+
1832
+ length = len(m3)
1833
+ wp.expect_eq(len(m3), 4)
1834
+ out[2] = wp.static(len(m3))
1835
+
1836
+ length = wp.static(len(m4))
1837
+ wp.expect_eq(wp.static(len(m4)), 2)
1838
+ out[3] = wp.static(len(m4))
1839
+
1840
+ foo = wp.mat22()
1841
+ length = len(foo)
1842
+ wp.expect_eq(len(foo), 2)
1843
+ out[4] = len(foo)
1844
+
1845
+
1846
+ def test_matrix_len(test, device):
1847
+ m1 = wp.mat22()
1848
+ m2 = wp.mat33()
1849
+ m3 = wp.mat44()
1850
+ m4 = Mat23()
1851
+ out = wp.empty(5, dtype=int, device=device)
1852
+ wp.launch(matrix_len_kernel, dim=(1,), inputs=(m1, m2, m3, m4), outputs=(out,), device=device)
1853
+
1854
+ test.assertEqual(out.numpy()[0], 2)
1855
+ test.assertEqual(out.numpy()[1], 3)
1856
+ test.assertEqual(out.numpy()[2], 4)
1857
+ test.assertEqual(out.numpy()[3], 2)
1858
+ test.assertEqual(out.numpy()[4], 2)
1859
+
1860
+ test.assertEqual(len(m1), 2)
1861
+ test.assertEqual(len(m2), 3)
1862
+ test.assertEqual(len(m3), 4)
1863
+ test.assertEqual(len(m4), 2)
1864
+
1865
+
1866
+ @wp.kernel
1867
+ def matrix_augassign_kernel(
1868
+ a: wp.array(dtype=wp.mat22),
1869
+ b: wp.array(dtype=wp.mat22),
1870
+ x: wp.array(dtype=wp.vec2),
1871
+ c: wp.array(dtype=wp.mat22),
1872
+ d: wp.array(dtype=wp.mat22),
1873
+ y: wp.array(dtype=wp.vec2),
1874
+ ):
1875
+ i = wp.tid()
1876
+
1877
+ m1 = wp.mat22()
1878
+ m2 = b[i]
1879
+ v2 = x[i]
1880
+
1881
+ m1[0] += v2
1882
+ m1[1, 0] += m2[1, 0]
1883
+ m1[1, 1] += m2[1, 1]
1884
+
1885
+ a[i] = m1
1886
+
1887
+ m3 = wp.mat22()
1888
+ m4 = d[i]
1889
+ v4 = y[i]
1890
+
1891
+ m3[0] -= v4
1892
+ m3[1, 0] -= m4[1, 0]
1893
+ m3[1, 1] -= m4[1, 1]
1894
+
1895
+ c[i] = m3
1896
+
1897
+
1898
+ def test_matrix_augassign(test, device):
1899
+ N = 1
1900
+
1901
+ a = wp.zeros(N, dtype=wp.mat22, requires_grad=True, device=device)
1902
+ b = wp.ones(N, dtype=wp.mat22, requires_grad=True, device=device)
1903
+ x = wp.ones(N, dtype=wp.vec2, requires_grad=True, device=device)
1904
+
1905
+ c = wp.zeros(N, dtype=wp.mat22, requires_grad=True, device=device)
1906
+ d = wp.ones(N, dtype=wp.mat22, requires_grad=True, device=device)
1907
+ y = wp.ones(N, dtype=wp.vec2, requires_grad=True, device=device)
1908
+
1909
+ tape = wp.Tape()
1910
+ with tape:
1911
+ wp.launch(matrix_augassign_kernel, N, inputs=[a, b, x, c, d, y], device=device)
1912
+
1913
+ tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
1914
+
1915
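+ # the kernel only reads b[1, 0] and b[1, 1], so b.grad is zero in row 0; x feeds
+ # row 0 of a directly, while d and y reach c through -= updates, hence the
+ # negated gradients below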
+ assert_np_equal(a.numpy(), wp.ones_like(a).numpy())
1916
+ assert_np_equal(a.grad.numpy(), wp.ones_like(a).numpy())
1917
+ assert_np_equal(b.grad.numpy(), np.array([[[0, 0], [1, 1]]], dtype=float))
1918
+ assert_np_equal(x.grad.numpy(), np.array([[1, 1]], dtype=float))
1919
+
1920
+ assert_np_equal(c.numpy(), -wp.ones_like(c).numpy())
1921
+ assert_np_equal(c.grad.numpy(), wp.ones_like(c).numpy())
1922
+ assert_np_equal(d.grad.numpy(), np.array([[[0, 0], [-1, -1]]], dtype=float))
1923
+ assert_np_equal(y.grad.numpy(), np.array([[-1, -1]], dtype=float))
1924
+
1925
+
1926
+ def test_matrix_assign_copy(test, device):
1927
+ saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
1928
+ try:
1929
+ wp.config.enable_vector_component_overwrites = True
1930
+
1931
+ @wp.kernel
1932
+ def mat_in_register_overwrite(x: wp.array2d(dtype=wp.mat22), y: wp.array(dtype=wp.vec2)):
1933
+ i, j = wp.tid()
1934
+
1935
+ a = wp.mat22()
1936
+ a[0] = y[i]
1937
+ a[0, 1] = 3.0
1938
+ x[i, j] = a
1939
+
1940
+ x = wp.zeros((1, 1), dtype=wp.mat22, device=device, requires_grad=True)
1941
+ y = wp.ones(1, dtype=wp.vec2, device=device, requires_grad=True)
1942
+
1943
+ tape = wp.Tape()
1944
+ with tape:
1945
+ wp.launch(mat_in_register_overwrite, dim=(1, 1), inputs=[x, y], device=device)
1946
+
1947
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1948
+
1949
+ assert_np_equal(x.numpy(), np.array([[[[1.0, 3.0], [0.0, 0.0]]]], dtype=float))
1950
+ assert_np_equal(y.grad.numpy(), np.array([[1.0, 0.0]], dtype=float))
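+ # y.grad[1] is zero because a[0, 1] is overwritten with 3.0 after the row
+ # assignment from y, so only y[0] reaches the output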
1951
+
1952
+ finally:
1953
+ wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
1954
+
1955
+
1956
+ devices = get_test_devices()
1957
+
1958
+
1959
+ class TestMat(unittest.TestCase):
1960
+ def test_tpl_ops_with_anon(self):
1961
+ mat22f = wp.mat((2, 2), dtype=float)
1962
+
1963
+ m = wp.mat22f(1.0, 2.0, 3.0, 4.0)
1964
+ m += mat22f(2.0, 3.0, 4.0, 5.0)
1965
+ m -= mat22f(3.0, 4.0, 5.0, 6.0)
1966
+ self.assertSequenceEqual(m, ((0.0, 1.0), (2.0, 3.0)))
1967
+
1968
+ m = mat22f(1.0, 2.0, 3.0, 4.0)
1969
+ m += wp.mat22f(2.0, 3.0, 4.0, 5.0)
1970
+ m -= wp.mat22f(3.0, 4.0, 5.0, 6.0)
1971
+ self.assertSequenceEqual(m, ((0.0, 1.0), (2.0, 3.0)))
1972
+
1973
+
1974
+ add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=devices)
1975
+ add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
1976
+ add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
1977
+ add_kernel_test(TestMat, test_matrix_constructor_value_func, dim=1, devices=devices)
1978
+ add_kernel_test(TestMat, test_matrix_from_vecs, dim=1, devices=devices)
1979
+
1980
+ mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
1981
+ add_kernel_test(
1982
+ TestMat,
1983
+ test_matrix_mutation,
1984
+ dim=1,
1985
+ inputs=[
1986
+ mat103(
1987
+ 1.0, 2.0, 3.0,
1988
+ 2.0, 4.0, 6.0,
1989
+ 3.0, 6.0, 9.0,
1990
+ 4.0, 8.0, 12.0,
1991
+ 5.0, 10.0, 15.0,
1992
+ 6.0, 12.0, 18.0,
1993
+ 7.0, 14.0, 21.0,
1994
+ 8.0, 16.0, 24.0,
1995
+ 9.0, 18.0, 27.0,
1996
+ 10.0, 20.0, 30.0,
1997
+ )
1998
+ ],
1999
+ devices=devices,
2000
+ ) # fmt: skip
2001
+
2002
+ for dtype in np_signed_int_types + np_float_types:
2003
+ add_function_test_register_kernel(
2004
+ TestMat, f"test_negation_{dtype.__name__}", test_negation, devices=devices, dtype=dtype
2005
+ )
2006
+ add_function_test_register_kernel(
2007
+ TestMat, f"test_subtraction_{dtype.__name__}", test_subtraction, devices=devices, dtype=dtype
2008
+ )
2009
+ add_function_test_register_kernel(
2010
+ TestMat, f"test_matmul_{dtype.__name__}", test_matmul, devices=devices, dtype=dtype
2011
+ )
2012
+
2013
+ add_function_test(
2014
+ TestMat,
2015
+ "test_anon_constructor_error_shape_arg_missing",
2016
+ test_anon_constructor_error_shape_arg_missing,
2017
+ devices=devices,
2018
+ )
2019
+ add_function_test(
2020
+ TestMat, "test_anon_constructor_error_shape_mismatch", test_anon_constructor_error_shape_mismatch, devices=devices
2021
+ )
2022
+ add_function_test(
2023
+ TestMat, "test_anon_constructor_error_type_mismatch", test_anon_constructor_error_type_mismatch, devices=devices
2024
+ )
2025
+ add_function_test(
2026
+ TestMat,
2027
+ "test_anon_constructor_error_invalid_arg_count",
2028
+ test_anon_constructor_error_invalid_arg_count,
2029
+ devices=devices,
2030
+ )
2031
+ add_function_test(
2032
+ TestMat,
2033
+ "test_anon_xform_constructor_error_type_mismatch",
2034
+ test_anon_xform_constructor_error_type_mismatch,
2035
+ devices=devices,
2036
+ )
2037
+ add_function_test(
2038
+ TestMat,
2039
+ "test_tpl_constructor_error_incompatible_sizes",
2040
+ test_tpl_constructor_error_incompatible_sizes,
2041
+ devices=devices,
2042
+ )
2043
+ add_function_test(
2044
+ TestMat,
2045
+ "test_tpl_constructor_error_invalid_arg_count",
2046
+ test_tpl_constructor_error_invalid_arg_count,
2047
+ devices=devices,
2048
+ )
2049
+
2050
+ for dtype in np_float_types:
2051
+ add_function_test(
2052
+ TestMat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
2053
+ )
2054
+ add_function_test_register_kernel(
2055
+ TestMat, f"test_quat_constructor_{dtype.__name__}", test_quat_constructor, devices=devices, dtype=dtype
2056
+ )
2057
+ add_function_test_register_kernel(
2058
+ TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
2059
+ )
2060
+ add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
2061
+ add_function_test_register_kernel(
2062
+ TestMat, f"test_svd_2D{dtype.__name__}", test_svd_2D, devices=devices, dtype=dtype
2063
+ )
2064
+ add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
2065
+ add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
2066
+ add_function_test_register_kernel(
2067
+ TestMat, f"test_transform_point_{dtype.__name__}", test_transform_point, devices=devices, dtype=dtype
2068
+ )
2069
+ add_function_test_register_kernel(
2070
+ TestMat, f"test_transform_vector_{dtype.__name__}", test_transform_vector, devices=devices, dtype=dtype
2071
+ )
2072
+ add_function_test_register_kernel(
2073
+ TestMat, f"test_determinant_{dtype.__name__}", test_determinant, devices=devices, dtype=dtype
2074
+ )
2075
+ add_function_test_register_kernel(TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype)
2076
+ add_function_test_register_kernel(
2077
+ TestMat,
2078
+ f"test_matrix_assign_inplace_{dtype.__name__}",
2079
+ test_matrix_assign_inplace,
2080
+ devices=devices,
2081
+ dtype=dtype,
2082
+ )
2083
+ add_function_test(TestMat, "test_matrix_len", test_matrix_len, devices=devices)
2084
+ add_function_test(TestMat, "test_matrix_augassign", test_matrix_augassign, devices=devices)
2085
+ add_function_test(TestMat, "test_matrix_assign_copy", test_matrix_assign_copy, devices=devices)
2086
+
2087
+ if __name__ == "__main__":
2088
+ wp.clear_kernel_cache()
2089
+ unittest.main(verbosity=2, failfast=True)