warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2913 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# NumPy scalar types the tests in this module are parameterized over,
# grouped by kind so individual tests can pick the subset they support.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]
np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]

# signed types first, then unsigned
np_int_types = [*np_signed_int_types, *np_unsigned_int_types]

np_float_types = [np.float16, np.float32, np.float64]

# every scalar type: all integer types followed by all float types
np_scalar_types = [*np_int_types, *np_float_types]
44
+
45
+
46
def randvals(rng, shape, dtype):
    """Return an array of random test values of the given shape and dtype.

    Args:
        rng: Random generator exposing ``standard_normal`` and ``integers``
            (callers in this file pass ``np.random.default_rng(...)``).
        shape: Shape of the array to generate.
        dtype: NumPy scalar type of the result.

    Returns:
        For types listed in ``np_float_types``, standard-normal samples cast
        to ``dtype``. For every other type, integers drawn via
        ``rng.integers`` with lower bound 1 and ``high=3`` for the 8-bit
        types or ``high=5`` otherwise.
    """
    if dtype in np_float_types:
        samples = rng.standard_normal(size=shape)
        return samples.astype(dtype)

    # 8-bit types get a smaller upper bound than the wider integer types
    is_8bit = dtype in [np.int8, np.uint8, np.byte, np.ubyte]
    upper = 3 if is_8bit else 5
    return rng.integers(1, high=upper, size=shape, dtype=dtype)
52
+
53
+
54
# Module-level cache of kernels built by getkernel(), keyed by
# "<function name>_<suffix>" so each kernel is only constructed once.
kernel_cache = {}
55
+
56
+
57
def getkernel(func, suffix=""):
    """Return the ``wp.Kernel`` for ``func``, building it on first request.

    Kernels are memoized in the module-level ``kernel_cache`` under the key
    ``"<func.__name__>_<suffix>"``; the same string is passed to ``wp.Kernel``
    as its ``key``.

    Args:
        func: Python function to wrap as a kernel.
        suffix: String appended to the function name to form the cache key
            (callers in this file pass the dtype name).

    Returns:
        The cached or newly created ``wp.Kernel``.
    """
    cache_key = "_".join((func.__name__, suffix))

    # fast path: this function/suffix pair has already been wrapped
    if cache_key in kernel_cache:
        return kernel_cache[cache_key]

    built = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = built
    return built
62
+
63
+
64
def get_select_kernel(dtype):
    """Return a kernel that copies one element of an array into ``out[0]``.

    The kernel's array arguments are declared with the given ``dtype``. It is
    obtained through ``getkernel`` with ``dtype.__name__`` as the suffix, so
    one kernel is cached per dtype name.

    Args:
        dtype: Element type for the kernel's ``input`` and ``out`` arrays;
            must have a ``__name__`` attribute.

    Returns:
        The kernel object returned by ``getkernel``.
    """

    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        # write the selected element to the first slot of the output array
        out[0] = input[index]

    return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
73
+
74
+
75
def test_arrays(test, device, dtype):
    """Check that matrix-typed ``wp.array`` objects round-trip NumPy data.

    For each matrix shape, a batch of 10 random matrices is generated with
    ``randvals``, wrapped in a ``wp.array`` of the matching Warp matrix type,
    read back with ``.numpy()`` and compared against the source data.

    Args:
        test: Test-case object supplied by the test harness (unused here).
        device: Device on which the Warp arrays are created.
        dtype: NumPy scalar type of the matrix components.
    """
    rng = np.random.default_rng(123)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Square 2x2..5x5 matrices plus one non-square case. The order matters:
    # it fixes the sequence in which values are drawn from ``rng``.
    mat_shapes = [(2, 2), (3, 3), (4, 4), (5, 5), (3, 2)]

    for rows, cols in mat_shapes:
        mat_type = wp.types.matrix(shape=(rows, cols), dtype=wptype)

        values_np = randvals(rng, [10, rows, cols], dtype)
        values = wp.array(values_np, dtype=mat_type, requires_grad=True, device=device)

        assert_np_equal(values.numpy(), values_np, tol=1.0e-6)
115
+
116
+
117
def test_components(test, device, dtype):
    """Exercise Python-side component access on a 2x3 Warp matrix.

    Covers ``__getitem__`` and ``__setitem__`` for whole rows and for
    individual ``[i, j]`` components. This is especially important for
    float16, which requires special handling internally.

    Args:
        test: Test-case object providing ``assertEqual``.
        device: Unused; present for the common test signature.
        dtype: NumPy scalar type of the matrix components.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat23 = wp.types.matrix(shape=(2, 3), dtype=wptype)

    m = mat23(1, 2, 3, 4, 5, 6)

    # all (row, col) pairs of the 2x3 matrix in row-major order
    cells = [(i, j) for i in range(2) for j in range(3)]

    def expect_sequence(first):
        # component (i, j) must hold first + 3*i + j, i.e. consecutive
        # values laid out row-major starting at ``first``
        for i, j in cells:
            test.assertEqual(m[i, j], first + 3 * i + j)

    # __getitem__ for row vectors
    rows = (m[0], m[1])
    for i, j in cells:
        test.assertEqual(rows[i][j], 1 + 3 * i + j)

    # __getitem__ for individual components
    expect_sequence(1)

    # __setitem__ for row vectors
    m[0] = [7, 8, 9]
    m[1] = [10, 11, 12]
    expect_sequence(7)

    # __setitem__ for individual components: write all six, then verify
    for i, j in cells:
        m[i, j] = 13 + 3 * i + j
    expect_sequence(13)
167
+
168
+
169
def test_constants(test, device, dtype, register_kernels=False):
    """Check that matrix values wrapped in ``wp.constant`` are usable in a kernel.

    Builds one constant matrix per shape (2x2, 3x3, 4x4, 5x5, 3x2) from a
    single scalar and compares each, inside a kernel, against a matrix
    constructed in-kernel from the same scalar.

    Args:
        test: Test-case object supplied by the test harness (unused here).
        device: Device on which the kernel is launched.
        dtype: NumPy scalar type of the matrix components.
        register_kernels: When true, only create the kernel and return
            without launching it.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)

    cm22 = wp.constant(mat22(22))
    cm33 = wp.constant(mat33(33))
    cm44 = wp.constant(mat44(44))
    cm55 = wp.constant(mat55(55))
    cm32 = wp.constant(mat32(32))

    def check_matrix_constants():
        # each captured constant must equal the matrix built in-kernel
        wp.expect_eq(cm22, mat22(wptype(22)))
        wp.expect_eq(cm33, mat33(wptype(33)))
        wp.expect_eq(cm44, mat44(wptype(44)))
        wp.expect_eq(cm55, mat55(wptype(55)))
        wp.expect_eq(cm32, mat32(wptype(32)))

    kernel = getkernel(check_matrix_constants, suffix=dtype.__name__)

    if register_kernels:
        return

    # Launch the kernel so the wp.expect_eq checks above actually execute;
    # merely creating the kernel does not run them.
    wp.launch(kernel, dim=1, inputs=[], device=device)
194
+
195
+
196
+ def test_constructors(test, device, dtype, register_kernels=False):
197
+ rng = np.random.default_rng(123)
198
+
199
+ tol = {
200
+ np.float16: 1.0e-3,
201
+ np.float32: 1.0e-6,
202
+ np.float64: 1.0e-8,
203
+ }.get(dtype, 0)
204
+
205
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
206
+ vec2 = wp.types.vector(length=2, dtype=wptype)
207
+ vec3 = wp.types.vector(length=3, dtype=wptype)
208
+ vec4 = wp.types.vector(length=4, dtype=wptype)
209
+ vec5 = wp.types.vector(length=5, dtype=wptype)
210
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
211
+ mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
212
+ mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
213
+ mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)
214
+
215
+ output_select_kernel = get_select_kernel(wptype)
216
+
217
+ def check_scalar_mat_constructor(
218
+ input: wp.array(dtype=wptype),
219
+ outcomponents: wp.array(dtype=wptype),
220
+ ):
221
+ # multiply outputs by 2 so we've got something to backpropagate:
222
+ m2result = wptype(2) * mat22(input[0])
223
+ m3result = wptype(2) * mat33(input[0])
224
+ m4result = wptype(2) * mat44(input[0])
225
+ m5result = wptype(2) * mat55(input[0])
226
+
227
+ idx = 0
228
+ for i in range(2):
229
+ for j in range(2):
230
+ outcomponents[idx] = m2result[i, j]
231
+ idx = idx + 1
232
+
233
+ for i in range(3):
234
+ for j in range(3):
235
+ outcomponents[idx] = m3result[i, j]
236
+ idx = idx + 1
237
+
238
+ for i in range(4):
239
+ for j in range(4):
240
+ outcomponents[idx] = m4result[i, j]
241
+ idx = idx + 1
242
+
243
+ for i in range(5):
244
+ for j in range(5):
245
+ outcomponents[idx] = m5result[i, j]
246
+ idx = idx + 1
247
+
248
+ def check_component_mat_constructor(
249
+ input: wp.array(dtype=wptype),
250
+ outcomponents: wp.array(dtype=wptype),
251
+ ):
252
+ # multiply outputs by 2 so we've got something to backpropagate:
253
+ m2result = wptype(2) * mat22(input[0], input[1], input[2], input[3])
254
+ m3result = wptype(2) * mat33(
255
+ input[4],
256
+ input[5],
257
+ input[6],
258
+ input[7],
259
+ input[8],
260
+ input[9],
261
+ input[10],
262
+ input[11],
263
+ input[12],
264
+ )
265
+ m4result = wptype(2) * mat44(
266
+ input[13],
267
+ input[14],
268
+ input[15],
269
+ input[16],
270
+ input[17],
271
+ input[18],
272
+ input[19],
273
+ input[20],
274
+ input[21],
275
+ input[22],
276
+ input[23],
277
+ input[24],
278
+ input[25],
279
+ input[26],
280
+ input[27],
281
+ input[28],
282
+ )
283
+ m5result = wptype(2) * mat55(
284
+ input[29],
285
+ input[30],
286
+ input[31],
287
+ input[32],
288
+ input[33],
289
+ input[34],
290
+ input[35],
291
+ input[36],
292
+ input[37],
293
+ input[38],
294
+ input[39],
295
+ input[40],
296
+ input[41],
297
+ input[42],
298
+ input[43],
299
+ input[44],
300
+ input[45],
301
+ input[46],
302
+ input[47],
303
+ input[48],
304
+ input[49],
305
+ input[50],
306
+ input[51],
307
+ input[52],
308
+ input[53],
309
+ )
310
+
311
+ idx = 0
312
+ for i in range(2):
313
+ for j in range(2):
314
+ outcomponents[idx] = m2result[i, j]
315
+ idx = idx + 1
316
+
317
+ for i in range(3):
318
+ for j in range(3):
319
+ outcomponents[idx] = m3result[i, j]
320
+ idx = idx + 1
321
+
322
+ for i in range(4):
323
+ for j in range(4):
324
+ outcomponents[idx] = m4result[i, j]
325
+ idx = idx + 1
326
+
327
+ for i in range(5):
328
+ for j in range(5):
329
+ outcomponents[idx] = m5result[i, j]
330
+ idx = idx + 1
331
+
332
+ def check_vector_mat_constructor(
333
+ input: wp.array(dtype=wptype),
334
+ outcomponents: wp.array(dtype=wptype),
335
+ ):
336
+ # multiply outputs by 2 so we've got something to backpropagate:
337
+ m2result = wptype(2) * wp.matrix_from_cols(vec2(input[0], input[2]), vec2(input[1], input[3]))
338
+ m3result = wptype(2) * wp.matrix_from_cols(
339
+ vec3(input[4], input[7], input[10]),
340
+ vec3(input[5], input[8], input[11]),
341
+ vec3(input[6], input[9], input[12]),
342
+ )
343
+ m4result = wptype(2) * wp.matrix_from_cols(
344
+ vec4(input[13], input[17], input[21], input[25]),
345
+ vec4(input[14], input[18], input[22], input[26]),
346
+ vec4(input[15], input[19], input[23], input[27]),
347
+ vec4(input[16], input[20], input[24], input[28]),
348
+ )
349
+ m5result = wptype(2) * wp.matrix_from_cols(
350
+ vec5(input[29], input[34], input[39], input[44], input[49]),
351
+ vec5(input[30], input[35], input[40], input[45], input[50]),
352
+ vec5(input[31], input[36], input[41], input[46], input[51]),
353
+ vec5(input[32], input[37], input[42], input[47], input[52]),
354
+ vec5(input[33], input[38], input[43], input[48], input[53]),
355
+ )
356
+
357
+ idx = 0
358
+ for i in range(2):
359
+ for j in range(2):
360
+ outcomponents[idx] = m2result[i, j]
361
+ idx = idx + 1
362
+
363
+ for i in range(3):
364
+ for j in range(3):
365
+ outcomponents[idx] = m3result[i, j]
366
+ idx = idx + 1
367
+
368
+ for i in range(4):
369
+ for j in range(4):
370
+ outcomponents[idx] = m4result[i, j]
371
+ idx = idx + 1
372
+
373
+ for i in range(5):
374
+ for j in range(5):
375
+ outcomponents[idx] = m5result[i, j]
376
+ idx = idx + 1
377
+
378
+ kernel = getkernel(check_scalar_mat_constructor, suffix=dtype.__name__)
379
+ compkernel = getkernel(check_component_mat_constructor, suffix=dtype.__name__)
380
+ veckernel = getkernel(check_vector_mat_constructor, suffix=dtype.__name__)
381
+
382
+ if register_kernels:
383
+ return
384
+
385
+ input = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
386
+ val = input.numpy()[0]
387
+ outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
388
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
389
+
390
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
391
+
392
+ assert_np_equal(outcomponents.numpy()[:4], 2 * val * np.ones(2 * 2), tol=tol)
393
+ assert_np_equal(outcomponents.numpy()[4:13], 2 * val * np.ones(3 * 3), tol=tol)
394
+ assert_np_equal(outcomponents.numpy()[13:29], 2 * val * np.ones(4 * 4), tol=tol)
395
+ assert_np_equal(outcomponents.numpy()[29:54], 2 * val * np.ones(5 * 5), tol=tol)
396
+
397
+ if dtype in np_float_types:
398
+ for idx in range(len(outcomponents)):
399
+ tape = wp.Tape()
400
+ with tape:
401
+ wp.launch(kernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
402
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
403
+ tape.backward(loss=out)
404
+ test.assertEqual(tape.gradients[input].numpy()[0], 2)
405
+ tape.zero()
406
+
407
+ input = wp.array(randvals(rng, [2 * 2 + 3 * 3 + 4 * 4 + 5 * 5], dtype), requires_grad=True, device=device)
408
+
409
+ wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
410
+ assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
411
+
412
+ if dtype in np_float_types:
413
+ for idx in range(len(outcomponents)):
414
+ tape = wp.Tape()
415
+ with tape:
416
+ wp.launch(compkernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
417
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
418
+ tape.backward(loss=out)
419
+ expectedgrads = np.zeros(len(input))
420
+ expectedgrads[idx] = 2
421
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
422
+ tape.zero()
423
+
424
+ wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
425
+ assert_np_equal(2 * input.numpy(), outcomponents.numpy(), tol=10 * tol)
426
+
427
+ if dtype in np_float_types:
428
+ for idx in range(len(outcomponents)):
429
+ tape = wp.Tape()
430
+ with tape:
431
+ wp.launch(veckernel, dim=1, inputs=[input], outputs=[outcomponents], device=device)
432
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
433
+ tape.backward(loss=out)
434
+ expectedgrads = np.zeros(len(input))
435
+ expectedgrads[idx] = 2
436
+ assert_np_equal(tape.gradients[input].numpy(), expectedgrads)
437
+ tape.zero()
438
+
439
+
440
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Exercise ``wp.matrix(..., shape=...)`` built from one scalar and from explicit components.

    Two kernels are created: one fills 2x2, 3x3, 4x4, 5x5 and 3x2 matrices from
    a single scalar each, the other builds the same shapes from one input value
    per entry. Both write twice every matrix entry into a flat output array.
    The host code compares that array with the inputs and, for float dtypes,
    back-propagates from each output entry to check the input gradients.

    When ``register_kernels`` is true the function returns right after the
    kernels have been created, without launching anything.
    """
    rng = np.random.default_rng(123)

    float_tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = float_tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_scalar_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        # one input scalar per matrix shape
        m2result = wp.matrix(input[0], shape=(2, 2))
        m3result = wp.matrix(input[1], shape=(3, 3))
        m4result = wp.matrix(input[2], shape=(4, 4))
        m5result = wp.matrix(input[3], shape=(5, 5))
        m32result = wp.matrix(input[4], shape=(3, 2))

        # entries are written back to back, doubled
        idx = 0
        for i in range(2):
            for j in range(2):
                output[idx] = wptype(2) * m2result[i, j]
                idx = idx + 1
        for i in range(3):
            for j in range(3):
                output[idx] = wptype(2) * m3result[i, j]
                idx = idx + 1
        for i in range(4):
            for j in range(4):
                output[idx] = wptype(2) * m4result[i, j]
                idx = idx + 1
        for i in range(5):
            for j in range(5):
                output[idx] = wptype(2) * m5result[i, j]
                idx = idx + 1
        for i in range(3):
            for j in range(2):
                output[idx] = wptype(2) * m32result[i, j]
                idx = idx + 1

    def check_component_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        # one input value per matrix entry, consumed in order
        m2result = wp.matrix(input[0], input[1], input[2], input[3], shape=(2, 2))
        m3result = wp.matrix(
            input[4], input[5], input[6],
            input[7], input[8], input[9],
            input[10], input[11], input[12],
            shape=(3, 3),
        )
        m4result = wp.matrix(
            input[13], input[14], input[15], input[16],
            input[17], input[18], input[19], input[20],
            input[21], input[22], input[23], input[24],
            input[25], input[26], input[27], input[28],
            shape=(4, 4),
        )
        m5result = wp.matrix(
            input[29], input[30], input[31], input[32], input[33],
            input[34], input[35], input[36], input[37], input[38],
            input[39], input[40], input[41], input[42], input[43],
            input[44], input[45], input[46], input[47], input[48],
            input[49], input[50], input[51], input[52], input[53],
            shape=(5, 5),
        )
        m32result = wp.matrix(input[54], input[55], input[56], input[57], input[58], input[59], shape=(3, 2))

        # entries are written back to back, doubled
        idx = 0
        for i in range(2):
            for j in range(2):
                output[idx] = wptype(2) * m2result[i, j]
                idx = idx + 1
        for i in range(3):
            for j in range(3):
                output[idx] = wptype(2) * m3result[i, j]
                idx = idx + 1
        for i in range(4):
            for j in range(4):
                output[idx] = wptype(2) * m4result[i, j]
                idx = idx + 1
        for i in range(5):
            for j in range(5):
                output[idx] = wptype(2) * m5result[i, j]
                idx = idx + 1
        for i in range(3):
            for j in range(2):
                output[idx] = wptype(2) * m32result[i, j]
                idx = idx + 1

    scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
    component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    if register_kernels:
        return

    # End offset of each matrix block in the flat output: 2x2, 3x3, 4x4, 5x5, 3x2.
    block_ends = (4, 13, 29, 54, 60)
    total = block_ends[-1]

    # --- one scalar per matrix -------------------------------------------
    source = wp.array(randvals(rng, [5], dtype), requires_grad=True, device=device)
    output = wp.zeros(total, dtype=wptype, requires_grad=True, device=device)

    wp.launch(scalar_kernel, dim=1, inputs=[source], outputs=[output], device=device)

    produced = output.numpy()
    seeds = source.numpy()
    begin = 0
    for slot, end in enumerate(block_ends):
        assert_np_equal(produced[begin:end], 2 * np.array([seeds[slot]] * (end - begin)), tol=1.0e-6)
        begin = end

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(len(output)):
            tape = wp.Tape()
            with tape:
                wp.launch(scalar_kernel, dim=1, inputs=[source], outputs=[output], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)

            tape.backward(loss=out)

            # Only the scalar that filled the block containing entry ``i``
            # is expected to receive a gradient.
            expected = np.zeros_like(source.numpy())
            expected[sum(i >= end for end in block_ends[:-1])] = 2

            # Compare before resetting the tape.
            assert_np_equal(tape.gradients[source].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()

    # --- one value per matrix entry ----------------------------------------
    source = wp.array(randvals(rng, [total], dtype), requires_grad=True, device=device)
    output = wp.zeros(total, dtype=wptype, requires_grad=True, device=device)

    wp.launch(component_kernel, dim=1, inputs=[source], outputs=[output], device=device)

    assert_np_equal(output.numpy(), 2 * source.numpy(), tol=1.0e-6)

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for i in range(len(output)):
            tape = wp.Tape()
            with tape:
                wp.launch(component_kernel, dim=1, inputs=[source], outputs=[output], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[output, i], outputs=[out], device=device)

            tape.backward(loss=out)

            # Output entry ``i`` depends on input entry ``i`` alone.
            expected = np.zeros_like(source.numpy())
            expected[i] = 2

            assert_np_equal(tape.gradients[source].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()
629
+
630
+
631
def test_identity(test, device, dtype, register_kernels=False):
    """Check ``wp.identity`` for sizes 2 through 5 of the given scalar dtype.

    A kernel builds the four identity matrices and writes twice each entry
    into a flat output array, one matrix after another in row-major order.
    The host then compares every block with ``2 * np.eye(n)``.

    Args:
        test: Test-case instance (not used by this check).
        device: Device on which the array is allocated and the kernel launched.
        dtype: NumPy scalar type; mapped to the matching Warp scalar type.
        register_kernels: When true, only create the kernel and return
            without launching it.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_identity_mat(
        output: wp.array(dtype=wptype),
    ):
        m2result = wp.identity(dtype=wptype, n=2)
        m3result = wp.identity(dtype=wptype, n=3)
        m4result = wp.identity(dtype=wptype, n=4)
        m5result = wp.identity(dtype=wptype, n=5)

        # entries are written back to back, doubled
        idx = 0
        for i in range(2):
            for j in range(2):
                output[idx] = wptype(2) * m2result[i, j]
                idx = idx + 1
        for i in range(3):
            for j in range(3):
                output[idx] = wptype(2) * m3result[i, j]
                idx = idx + 1
        for i in range(4):
            for j in range(4):
                output[idx] = wptype(2) * m4result[i, j]
                idx = idx + 1
        for i in range(5):
            for j in range(5):
                output[idx] = wptype(2) * m5result[i, j]
                idx = idx + 1

    id_kernel = getkernel(check_identity_mat, suffix=dtype.__name__)

    if register_kernels:
        return

    output = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)
    wp.launch(id_kernel, dim=1, inputs=[], outputs=[output], device=device)

    result = output.numpy()
    start = 0
    for n in (2, 3, 4, 5):
        stop = start + n * n
        # The output slice is 1-D, so flatten the expected matrix as well;
        # this keeps the comparison shape-consistent instead of depending on
        # how the assertion helper treats mismatched shapes.
        assert_np_equal(result[start:stop], 2 * np.eye(n).reshape(-1), tol=1.0e-6)
        start = stop
671
+
672
+
673
def test_indexing(test, device, dtype, register_kernels=False):
    """Read every entry of 2x2 to 5x5 matrices with ``m[i, j]`` inside a kernel.

    The kernel writes twice each entry of the first element of four matrix
    arrays into a flat output array. The host compares that output with the
    inputs and, for float dtypes, back-propagates from each output entry and
    checks that only the matching matrix entry receives a gradient of 2.

    When ``register_kernels`` is true the function returns right after the
    kernel has been created.
    """
    rng = np.random.default_rng(123)

    float_tolerances = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = float_tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_indexing(
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * m2[0][i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * m3[0][i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * m4[0][i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * m5[0][i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_indexing, suffix=dtype.__name__)

    if register_kernels:
        return

    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)

    # Forward check: walk the flat output one matrix block at a time.
    flat = outcomponents.numpy()
    cursor = 0
    for mats in (m2, m3, m4, m5):
        reference = 2 * mats.numpy().reshape(-1)
        assert_np_equal(flat[cursor : cursor + reference.size], reference, tol=tol)
        cursor += reference.size

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        # One probe per output entry, listed in the order the kernel writes them.
        probes = []
        for dim, mats in ((2, m2), (3, m3), (4, m4), (5, m5)):
            for i in range(dim):
                for j in range(dim):
                    probes.append((mats, dim, i, j))

        for idx, (mats, dim, i, j) in enumerate(probes):
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5], outputs=[outcomponents], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
            tape.backward(loss=out)

            expectedresult = np.zeros((dim, dim), dtype=dtype)
            expectedresult[i, j] = 2
            assert_np_equal(tape.gradients[mats].numpy()[0], expectedresult)
            tape.zero()
755
+
756
+
757
def test_equality(test, device, dtype, register_kernels=False):
    """Check matrix equality and inequality inside a kernel for 2x2 to 5x5 matrices.

    For each size the kernel asserts with ``wp.expect_eq`` that two matrices
    built from the same values compare equal, and with ``wp.expect_neq`` that
    two matrices differing in the sign of a single entry do not.

    When ``register_kernels`` is true the function returns right after the
    kernel has been created.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    def check_mat_equality():
        # ---- 2x2: equal pair, then a pair differing in the last entry ----
        eq22_a = mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0))
        eq22_b = mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0))
        wp.expect_eq(eq22_a, eq22_b)

        neq22_a = mat22(wptype(1.0), wptype(2.0), wptype(3.0), -wptype(4.0))
        neq22_b = mat22(wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0))
        wp.expect_neq(neq22_a, neq22_b)

        # ---- 3x3: equal pair, then a pair differing in the fourth entry ----
        eq33_a = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        eq33_b = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        wp.expect_eq(eq33_a, eq33_b)

        neq33_a = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        neq33_b = mat33(
            wptype(1.0), wptype(2.0), wptype(3.0),
            -wptype(4.0), wptype(5.0), wptype(6.0),
            wptype(7.0), wptype(8.0), wptype(9.0),
        )
        wp.expect_neq(neq33_a, neq33_b)

        # ---- 4x4: equal pair, then a pair differing in the first entry ----
        eq44_a = mat44(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        eq44_b = mat44(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        wp.expect_eq(eq44_a, eq44_b)

        neq44_a = mat44(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        neq44_b = mat44(
            -wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0),
            wptype(5.0), wptype(6.0), wptype(7.0), wptype(8.0),
            wptype(9.0), wptype(10.0), wptype(11.0), wptype(12.0),
            wptype(13.0), wptype(14.0), wptype(15.0), wptype(16.0),
        )
        wp.expect_neq(neq44_a, neq44_b)

        # ---- 5x5: equal pair, then a pair differing in the 17th entry ----
        eq55_a = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        eq55_b = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        wp.expect_eq(eq55_a, eq55_b)

        neq55_a = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        neq55_b = mat55(
            wptype(1.0), wptype(2.0), wptype(3.0), wptype(4.0), wptype(5.0),
            wptype(6.0), wptype(7.0), wptype(8.0), wptype(9.0), wptype(10.0),
            wptype(11.0), wptype(12.0), wptype(13.0), wptype(14.0), wptype(15.0),
            wptype(16.0), -wptype(17.0), wptype(18.0), wptype(19.0), wptype(20.0),
            wptype(21.0), wptype(22.0), wptype(23.0), wptype(24.0), wptype(25.0),
        )
        wp.expect_neq(neq55_a, neq55_b)

    kernel = getkernel(check_mat_equality, suffix=dtype.__name__)

    if register_kernels:
        return

    wp.launch(kernel, dim=1, inputs=[], outputs=[], device=device)
1021
+
1022
+
1023
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Check scalar * matrix and matrix * scalar for 2x2 to 5x5 matrices.

    The kernel computes every product twice and writes twice each entry into
    two flat arrays: ``outcomponents`` for scalar-on-the-left products and
    ``outcomponents_rightmul`` for scalar-on-the-right products. The host
    compares both arrays with ``2 * s * m`` and, for float dtypes,
    back-propagates from each entry of the first half of both arrays to check
    the gradients with respect to the matrix and to the scalar.

    When ``register_kernels`` is true the function returns right after the
    kernel has been created.
    """
    rng = np.random.default_rng(123)

    float_tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = float_tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_scalar_mul(
        s: wp.array(dtype=wptype),
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
        outcomponents_rightmul: wp.array(dtype=wptype),
    ):
        # scalar on the left
        m2result = s[0] * m2[0]
        m3result = s[0] * m3[0]
        m4result = s[0] * m4[0]
        m5result = s[0] * m5[0]

        # scalar on the right
        m2resultright = m2[0] * s[0]
        m3resultright = m3[0] * s[0]
        m4resultright = m4[0] * s[0]
        m5resultright = m5[0] * s[0]

        # second evaluation of the same left products
        m2result_2 = s[0] * m2[0]
        m3result_2 = s[0] * m3[0]
        m4result_2 = s[0] * m4[0]
        m5result_2 = s[0] * m5[0]

        # second evaluation of the same right products
        m2resultright_2 = m2[0] * s[0]
        m3resultright_2 = m3[0] * s[0]
        m4resultright_2 = m4[0] * s[0]
        m5resultright_2 = m5[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * m2result[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m2resultright[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * m3result[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m3resultright[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * m4result[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m4resultright[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * m5result[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m5resultright[i, j]
                idx = idx + 1

        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * m2result_2[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m2resultright_2[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * m3result_2[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m3resultright_2[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * m4result_2[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m4resultright_2[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * m5result_2[i, j]
                outcomponents_rightmul[idx] = wptype(2) * m5resultright_2[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_scalar_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)

    half = 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5
    outcomponents = wp.zeros(2 * half, dtype=wptype, requires_grad=True, device=device)
    outcomponents_rightmul = wp.zeros(2 * half, dtype=wptype, requires_grad=True, device=device)

    launch_inputs = [s, m2, m3, m4, m5]
    launch_outputs = [outcomponents, outcomponents_rightmul]
    wp.launch(kernel, dim=1, inputs=launch_inputs, outputs=launch_outputs, device=device)

    sval = s.numpy()[0]

    # Forward check. The 2x2 block uses ``tol``; the larger ones ``10 * tol``.
    # Each array holds the four products twice, so the same references apply
    # at offset 0 and at offset ``half``.
    left_values = outcomponents.numpy()
    right_values = outcomponents_rightmul.numpy()
    operands = ((m2, tol), (m3, 10 * tol), (m4, 10 * tol), (m5, 10 * tol))
    for base in (0, half):
        for produced in (left_values, right_values):
            cursor = base
            for mats, bound in operands:
                reference = 2 * sval * mats.numpy().reshape(-1)
                assert_np_equal(produced[cursor : cursor + reference.size], reference, tol=bound)
                cursor += reference.size

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        # One probe per entry of the first half, in the order the kernel writes them.
        probes = []
        for dim, mats in ((2, m2), (3, m3), (4, m4), (5, m5)):
            for i in range(dim):
                for j in range(dim):
                    probes.append((mats, dim, i, j))

        for idx, (mats, dim, i, j) in enumerate(probes):
            # left-multiplication gradient first, then right-multiplication
            for selected in (outcomponents, outcomponents_rightmul):
                tape = wp.Tape()
                with tape:
                    wp.launch(kernel, dim=1, inputs=launch_inputs, outputs=launch_outputs, device=device)
                    wp.launch(output_select_kernel, dim=1, inputs=[selected, idx], outputs=[out], device=device)
                tape.backward(loss=out)

                expectedresult = np.zeros((dim, dim), dtype=dtype)
                expectedresult[i, j] = 2 * sval
                assert_np_equal(tape.gradients[mats].numpy()[0], expectedresult, tol=10 * tol)
                assert_np_equal(tape.gradients[s].numpy()[0], 2 * mats.numpy()[0, i, j], tol=10 * tol)
                tape.zero()
1208
+
1209
+
1210
def test_matvec_multiplication(test, device, dtype, register_kernels=False):
    """Check matrix-vector products written with ``*`` and with ``@``.

    The kernel multiplies 2x2, 3x3, 4x4, 5x5 and 3x2 matrices by vectors of
    matching length, once with ``*`` and once with ``@``, and writes twice
    each result component into a flat output array. The host compares the
    output with ``2 * np.matmul(m, v)`` and, for float dtypes, back-propagates
    from each component of the ``*`` results to check the gradients with
    respect to the vector and to the matrix.

    When ``register_kernels`` is true the function returns right after the
    kernel has been created.
    """
    rng = np.random.default_rng(123)

    float_tolerances = {
        np.float16: 2.0e-2,
        np.float32: 5.0e-6,
        np.float64: 1.0e-8,
    }
    tol = float_tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_vec_mul(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v32: wp.array(dtype=vec2),
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        m32: wp.array(dtype=mat32),
        outcomponents: wp.array(dtype=wptype),
    ):
        # products written with ``*``
        v2result = m2[0] * v2[0]
        v3result = m3[0] * v3[0]
        v4result = m4[0] * v4[0]
        v5result = m5[0] * v5[0]
        v32result = m32[0] * v32[0]
        # the same products written with ``@``
        v2result_2 = m2[0] @ v2[0]
        v3result_2 = m3[0] @ v3[0]
        v4result_2 = m4[0] @ v4[0]
        v5result_2 = m5[0] @ v5[0]
        v32result_2 = m32[0] @ v32[0]

        idx = 0

        # multiply outputs by 2 so we've got something to backpropagate:
        for i in range(2):
            outcomponents[idx] = wptype(2) * v2result[i]
            idx = idx + 1

        for i in range(3):
            outcomponents[idx] = wptype(2) * v3result[i]
            idx = idx + 1

        for i in range(4):
            outcomponents[idx] = wptype(2) * v4result[i]
            idx = idx + 1

        for i in range(5):
            outcomponents[idx] = wptype(2) * v5result[i]
            idx = idx + 1

        for i in range(3):
            outcomponents[idx] = wptype(2) * v32result[i]
            idx = idx + 1

        for i in range(2):
            outcomponents[idx] = wptype(2) * v2result_2[i]
            idx = idx + 1

        for i in range(3):
            outcomponents[idx] = wptype(2) * v3result_2[i]
            idx = idx + 1

        for i in range(4):
            outcomponents[idx] = wptype(2) * v4result_2[i]
            idx = idx + 1

        for i in range(5):
            outcomponents[idx] = wptype(2) * v5result_2[i]
            idx = idx + 1

        for i in range(3):
            outcomponents[idx] = wptype(2) * v32result_2[i]
            idx = idx + 1

    kernel = getkernel(check_mat_vec_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    v2 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3], dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4], dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
    v32 = wp.array(randvals(rng, [1, 2], dtype), dtype=vec2, requires_grad=True, device=device)
    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    m32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * (2 + 3 + 4 + 5 + 3), dtype=wptype, requires_grad=True, device=device)

    launch_inputs = [v2, v3, v4, v5, v32, m2, m3, m4, m5, m32]
    wp.launch(kernel, dim=1, inputs=launch_inputs, outputs=[outcomponents], device=device)

    # Forward check: (matrix, vector, tolerance multiplier) in output order.
    # The list is walked twice: first the ``*`` results, then the ``@`` results.
    cases = ((m2, v2, 1), (m3, v3, 1), (m4, v4, 5), (m5, v5, 5), (m32, v32, 5))
    flat = outcomponents.numpy()
    cursor = 0
    for _ in range(2):
        for mats, vecs, scale in cases:
            reference = 2 * np.matmul(mats.numpy()[0], vecs.numpy()[0])
            stop = cursor + len(reference)
            assert_np_equal(flat[cursor:stop], reference, tol=scale * tol)
            cursor = stop

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        # One probe per component of the ``*`` results, in output order.
        probes = []
        for dim, invec, inmat in ((2, v2, m2), (3, v3, m3), (4, v4, m4), (5, v5, m5), (3, v32, m32)):
            for i in range(dim):
                probes.append((invec, inmat, i))

        for idx, (invec, inmat, i) in enumerate(probes):
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=launch_inputs, outputs=[outcomponents], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
            tape.backward(loss=out)

            # d(2 * (M v)[i]) / dv is twice row i of M ...
            assert_np_equal(tape.gradients[invec].numpy()[0], 2 * inmat.numpy()[0, i, :], tol=2 * tol)
            # ... and the gradient w.r.t. M is 2 * v in row i, zero elsewhere.
            expectedresult = np.zeros(inmat.dtype._shape_, dtype=dtype)
            expectedresult[i, :] = 2 * invec.numpy()[0]
            assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult, tol=2 * tol)

            tape.zero()
1355
+
1356
+
1357
def test_vecmat_multiplication(test, device, dtype, register_kernels=False):
    """Check vector-times-matrix products (``*`` and ``@``) and their gradients.

    A kernel multiplies 2-, 3-, 4- and 5-component vectors with square
    matrices of matching size, plus a 2-component vector with a 2x3 matrix,
    once with ``*`` and once with ``@``.  Every result is scaled by 2 and
    written component by component into a flat output array, which is then
    compared against ``np.matmul``.  For dtypes in ``np_float_types`` the
    gradients of the ``*`` results are also verified, one output component at
    a time, through a ``wp.Tape``.

    Args:
        test: Unused in this function.
        device: Device used for array allocation and kernel launches.
        dtype: NumPy scalar type under test.
        register_kernels: If true, stop right after the kernel has been built.
    """
    rng = np.random.default_rng(123)

    # comparison tolerance per dtype; dtypes that are not listed use 0
    tol = {
        np.float16: 2.0e-2,
        np.float32: 5.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat23 = wp.types.matrix(shape=(2, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_vec_mat_mul(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v32: wp.array(dtype=vec2),
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        m23: wp.array(dtype=mat23),
        outcomponents: wp.array(dtype=wptype),
    ):
        # products written with the `*` operator
        v2result = v2[0] * m2[0]
        v3result = v3[0] * m3[0]
        v4result = v4[0] * m4[0]
        v5result = v5[0] * m5[0]
        v32result = v32[0] * m23[0]
        # the same products written with the `@` operator
        v2result_2 = v2[0] @ m2[0]
        v3result_2 = v3[0] @ m3[0]
        v4result_2 = v4[0] @ m4[0]
        v5result_2 = v5[0] @ m5[0]
        v32result_2 = v32[0] @ m23[0]

        slot = 0

        # multiply outputs by 2 so we've got something to backpropagate:
        for k in range(2):
            outcomponents[slot] = wptype(2) * v2result[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * v3result[k]
            slot = slot + 1

        for k in range(4):
            outcomponents[slot] = wptype(2) * v4result[k]
            slot = slot + 1

        for k in range(5):
            outcomponents[slot] = wptype(2) * v5result[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * v32result[k]
            slot = slot + 1

        for k in range(2):
            outcomponents[slot] = wptype(2) * v2result_2[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * v3result_2[k]
            slot = slot + 1

        for k in range(4):
            outcomponents[slot] = wptype(2) * v4result_2[k]
            slot = slot + 1

        for k in range(5):
            outcomponents[slot] = wptype(2) * v5result_2[k]
            slot = slot + 1

        for k in range(3):
            outcomponents[slot] = wptype(2) * v32result_2[k]
            slot = slot + 1

    kernel = getkernel(check_vec_mat_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    def make_array(shape, wp_dtype):
        # one random, differentiable, single-element array on the test device
        return wp.array(randvals(rng, shape, dtype), dtype=wp_dtype, requires_grad=True, device=device)

    # NOTE: creation order fixes the order of draws from `rng`
    v2 = make_array([1, 2], vec2)
    v3 = make_array([1, 3], vec3)
    v4 = make_array([1, 4], vec4)
    v5 = make_array([1, 5], vec5)
    v32 = make_array([1, 2], vec2)
    m2 = make_array([1, 2, 2], mat22)
    m3 = make_array([1, 3, 3], mat33)
    m4 = make_array([1, 4, 4], mat44)
    m5 = make_array([1, 5, 5], mat55)
    m23 = make_array([1, 2, 3], mat23)
    outcomponents = wp.zeros(2 * (2 + 3 + 4 + 5 + 3), dtype=wptype, requires_grad=True, device=device)

    def launch_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m23],
            outputs=[outcomponents],
            device=device,
        )

    launch_forward()

    outcomponents_np = outcomponents.numpy()

    # (vector, matrix, tolerance multiplier) in the order the kernel writes the results
    products = [(v2, m2, 1), (v3, m3, 1), (v4, m4, 5), (v5, m5, 5), (v32, m23, 5)]
    begin = 0
    for _ in range(2):  # first pass covers the `*` results, second pass the `@` results
        for vec, mat, tol_scale in products:
            expected = 2 * np.matmul(vec.numpy()[0], mat.numpy()[0])
            end = begin + expected.shape[0]
            assert_np_equal(outcomponents_np[begin:end], expected, tol=tol_scale * tol)
            begin = end

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        # one entry per output component of the `*` results, in output order
        targets = [
            (mat, vec, col)
            for dim, mat, vec in [(2, m2, v2), (3, m3, v3), (4, m4, v4), (5, m5, v5), (3, m23, v32)]
            for col in range(dim)
        ]
        for idx, (inmat, invec, col) in enumerate(targets):
            tape = wp.Tape()
            with tape:
                launch_forward()
                wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
            tape.backward(loss=out)

            # component `col` of 2 * v * M depends on v through column `col` of M ...
            assert_np_equal(tape.gradients[invec].numpy()[0], 2 * inmat.numpy()[0, :, col], tol=2 * tol)
            # ... and on M only through that same column, with weights 2 * v
            expectedresult = np.zeros(inmat.dtype._shape_, dtype=dtype)
            expectedresult[:, col] = 2 * invec.numpy()[0]
            assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult, tol=2 * tol)

            tape.zero()
1504
+
1505
+
1506
def test_matmat_multiplication(test, device, dtype, register_kernels=False):
    """Check matrix-times-matrix products (``*`` and ``@``) and their gradients.

    The kernel forms ``b * a`` and ``b @ a`` for 2x2, 3x3, 4x4 and 5x5 pairs
    as well as the rectangular products (3x2)(2x2) and (3x3)(3x2).  Every
    result is scaled by 2 and flattened row by row into one output array,
    which is compared against ``np.matmul``.  For dtypes in
    ``np_float_types`` the gradients of the ``*`` results are verified one
    output entry at a time through a ``wp.Tape``.

    Args:
        test: Unused in this function.
        device: Device used for array allocation and kernel launches.
        dtype: NumPy scalar type under test.
        register_kernels: If true, stop right after the kernel has been built.
    """
    rng = np.random.default_rng(123)

    # comparison tolerance per dtype; dtypes that are not listed use 0
    tol = {
        np.float16: 2.0e-2,
        np.float32: 5.0e-6,
        np.float64: 5.0e-7,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_mat_mul(
        a2: wp.array(dtype=mat22),
        a3: wp.array(dtype=mat33),
        a4: wp.array(dtype=mat44),
        a5: wp.array(dtype=mat55),
        a32: wp.array(dtype=mat32),
        b2: wp.array(dtype=mat22),
        b3: wp.array(dtype=mat33),
        b4: wp.array(dtype=mat44),
        b5: wp.array(dtype=mat55),
        b32: wp.array(dtype=mat32),
        outcomponents: wp.array(dtype=wptype),
    ):
        # products written with the `*` operator (b on the left, a on the right)
        c2result = b2[0] * a2[0]
        c3result = b3[0] * a3[0]
        c4result = b4[0] * a4[0]
        c5result = b5[0] * a5[0]
        c32result = b32[0] * a2[0]
        c32result2 = b3[0] * a32[0]
        # the same products written with the `@` operator
        c2result_2 = b2[0] @ a2[0]
        c3result_2 = b3[0] @ a3[0]
        c4result_2 = b4[0] @ a4[0]
        c5result_2 = b5[0] @ a5[0]
        c32result_2 = b32[0] @ a2[0]
        c32result2_2 = b3[0] @ a32[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        slot = 0
        for r in range(2):
            for c in range(2):
                outcomponents[slot] = wptype(2) * c2result[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(3):
                outcomponents[slot] = wptype(2) * c3result[r, c]
                slot = slot + 1

        for r in range(4):
            for c in range(4):
                outcomponents[slot] = wptype(2) * c4result[r, c]
                slot = slot + 1

        for r in range(5):
            for c in range(5):
                outcomponents[slot] = wptype(2) * c5result[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(2):
                outcomponents[slot] = wptype(2) * c32result[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(2):
                outcomponents[slot] = wptype(2) * c32result2[r, c]
                slot = slot + 1

        for r in range(2):
            for c in range(2):
                outcomponents[slot] = wptype(2) * c2result_2[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(3):
                outcomponents[slot] = wptype(2) * c3result_2[r, c]
                slot = slot + 1

        for r in range(4):
            for c in range(4):
                outcomponents[slot] = wptype(2) * c4result_2[r, c]
                slot = slot + 1

        for r in range(5):
            for c in range(5):
                outcomponents[slot] = wptype(2) * c5result_2[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(2):
                outcomponents[slot] = wptype(2) * c32result_2[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(2):
                outcomponents[slot] = wptype(2) * c32result2_2[r, c]
                slot = slot + 1

    kernel = getkernel(check_mat_mat_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    def make_array(shape, wp_dtype):
        # one random, differentiable, single-element array on the test device
        return wp.array(randvals(rng, shape, dtype), dtype=wp_dtype, requires_grad=True, device=device)

    # NOTE: creation order fixes the order of draws from `rng`.
    # The v* arrays feed the kernel's a* (right-hand) arguments, the m* arrays its b* (left-hand) ones.
    v2 = make_array([1, 2, 2], mat22)
    v3 = make_array([1, 3, 3], mat33)
    v4 = make_array([1, 4, 4], mat44)
    v5 = make_array([1, 5, 5], mat55)
    v32 = make_array([1, 3, 2], mat32)
    m2 = make_array([1, 2, 2], mat22)
    m3 = make_array([1, 3, 3], mat33)
    m4 = make_array([1, 4, 4], mat44)
    m5 = make_array([1, 5, 5], mat55)
    m32 = make_array([1, 3, 2], mat32)
    outcomponents = wp.zeros(
        2 * (2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 3 * 2 + 3 * 2), dtype=wptype, requires_grad=True, device=device
    )

    def launch_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[v2, v3, v4, v5, v32, m2, m3, m4, m5, m32],
            outputs=[outcomponents],
            device=device,
        )

    launch_forward()

    outcomponents_np = outcomponents.numpy()

    # (left matrix, right matrix, tolerance multiplier) in the order the kernel writes the results
    products = [(m2, v2, 1), (m3, v3, 1), (m4, v4, 2), (m5, v5, 10), (m32, v2, 5), (m3, v32, 5)]
    begin = 0
    for _ in range(2):  # first pass covers the `*` results, second pass the `@` results
        for left, right, tol_scale in products:
            expected = 2 * np.matmul(left.numpy()[0], right.numpy()[0])
            end = begin + expected.size
            assert_np_equal(outcomponents_np[begin:end].reshape(expected.shape), expected, tol=tol_scale * tol)
            begin = end

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        idx = 0
        for right, left in [(v2, m2), (v3, m3), (v4, m4), (v5, m5), (v2, m32), (v32, m3)]:
            n_rows = left.dtype._shape_[0]
            n_cols = right.dtype._shape_[1]
            for row in range(n_rows):
                for col in range(n_cols):
                    tape = wp.Tape()
                    with tape:
                        launch_forward()
                        wp.launch(
                            output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                        )
                    tape.backward(loss=out)

                    # entry (row, col) of 2 * L * R touches column `col` of R, weighted by row `row` of L
                    grad_right = np.zeros(right.dtype._shape_, dtype=dtype)
                    grad_right[:, col] = 2 * left.numpy()[0, row, :]
                    assert_np_equal(tape.gradients[right].numpy()[0], grad_right, tol=10 * tol)

                    # ... and row `row` of L, weighted by column `col` of R
                    grad_left = np.zeros(left.dtype._shape_, dtype=dtype)
                    grad_left[row, :] = 2 * right.numpy()[0, :, col]
                    assert_np_equal(tape.gradients[left].numpy()[0], grad_left, tol=10 * tol)

                    tape.zero()
                    idx += 1
1683
+
1684
+
1685
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_mul`` on square matrices and its gradients.

    The kernel computes ``2 * wp.cw_mul(v, s)`` for 2x2, 3x3, 4x4 and 5x5
    matrices and flattens the results row by row into one output array,
    which is compared with the element-wise NumPy product.  For dtypes in
    ``np_float_types`` the gradient of every output entry with respect to
    both operands is verified through a ``wp.Tape``.

    Args:
        test: Unused in this function.
        device: Device used for array allocation and kernel launches.
        dtype: NumPy scalar type under test.
        register_kernels: If true, stop right after the kernel has been built.
    """
    rng = np.random.default_rng(123)

    # comparison tolerance per dtype; dtypes that are not listed use 0
    tol = {
        np.float16: 5.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_cw_mul(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        v2result = wptype(2) * wp.cw_mul(v2[0], s2[0])
        v3result = wptype(2) * wp.cw_mul(v3[0], s3[0])
        v4result = wptype(2) * wp.cw_mul(v4[0], s4[0])
        v5result = wptype(2) * wp.cw_mul(v5[0], s5[0])

        # flatten every result row by row into the output array
        slot = 0
        for r in range(2):
            for c in range(2):
                outcomponents[slot] = v2result[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(3):
                outcomponents[slot] = v3result[r, c]
                slot = slot + 1

        for r in range(4):
            for c in range(4):
                outcomponents[slot] = v4result[r, c]
                slot = slot + 1

        for r in range(5):
            for c in range(5):
                outcomponents[slot] = v5result[r, c]
                slot = slot + 1

    kernel = getkernel(check_mat_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    def make_array(shape, wp_dtype):
        # one random, differentiable, single-element array on the test device
        return wp.array(randvals(rng, shape, dtype), dtype=wp_dtype, requires_grad=True, device=device)

    # NOTE: creation order fixes the order of draws from `rng`
    s2 = make_array([1, 2, 2], mat22)
    s3 = make_array([1, 3, 3], mat33)
    s4 = make_array([1, 4, 4], mat44)
    s5 = make_array([1, 5, 5], mat55)
    v2 = make_array([1, 2, 2], mat22)
    v3 = make_array([1, 3, 3], mat33)
    v4 = make_array([1, 4, 4], mat44)
    v5 = make_array([1, 5, 5], mat55)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    def launch_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[outcomponents],
            device=device,
        )

    launch_forward()

    outcomponents_np = outcomponents.numpy()

    # forward values, in the order the kernel writes them
    begin = 0
    for s, v in [(s2, v2), (s3, v3), (s4, v4), (s5, v5)]:
        expected = 2 * (v.numpy() * s.numpy()).reshape(-1)
        end = begin + expected.shape[0]
        assert_np_equal(outcomponents_np[begin:end], expected, tol=50 * tol)
        begin = end

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        idx = 0
        for dim, in1, in2 in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
            for row in range(dim):
                for col in range(dim):
                    tape = wp.Tape()
                    with tape:
                        launch_forward()
                        wp.launch(
                            output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                        )
                    tape.backward(loss=out)

                    # output entry (row, col) is 2 * in2[row, col] * in1[row, col], so each
                    # gradient has a single nonzero entry equal to twice the other operand
                    expectedresult = np.zeros((dim, dim), dtype=dtype)
                    expectedresult[row, col] = 2 * in1.numpy()[0][row, col]
                    assert_np_equal(tape.gradients[in2].numpy()[0], expectedresult, tol=5 * tol)
                    expectedresult[row, col] = 2 * in2.numpy()[0][row, col]
                    assert_np_equal(tape.gradients[in1].numpy()[0], expectedresult, tol=5 * tol)
                    tape.zero()

                    idx += 1
1815
+
1816
+
1817
def test_cw_division(test, device, dtype, register_kernels=False):
    """Check ``wp.cw_div`` on square matrices and its gradients.

    The kernel computes ``2 * wp.cw_div(v, s)`` for 2x2, 3x3, 4x4 and 5x5
    matrices and flattens the results row by row into one output array.
    The values are compared with NumPy true division for dtypes in
    ``np_float_types`` and with NumPy floor division otherwise.  For the
    floating-point dtypes the gradient of every output entry with respect to
    numerator and denominator is verified through a ``wp.Tape``.

    Args:
        test: Unused in this function.
        device: Device used for array allocation and kernel launches.
        dtype: NumPy scalar type under test.
        register_kernels: If true, stop right after the kernel has been built.
    """
    rng = np.random.default_rng(123)

    # comparison tolerance per dtype; dtypes that are not listed use 0
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_cw_div(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        v2result = wptype(2) * wp.cw_div(v2[0], s2[0])
        v3result = wptype(2) * wp.cw_div(v3[0], s3[0])
        v4result = wptype(2) * wp.cw_div(v4[0], s4[0])
        v5result = wptype(2) * wp.cw_div(v5[0], s5[0])

        # flatten every result row by row into the output array
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = v2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = v3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = v4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = v5result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_cw_div, suffix=dtype.__name__)

    if register_kernels:
        return

    # NOTE: the order of the randvals() calls fixes the order of draws from `rng`
    s2 = randvals(rng, [1, 2, 2], dtype)
    s3 = randvals(rng, [1, 3, 3], dtype)
    s4 = randvals(rng, [1, 4, 4], dtype)
    s5 = randvals(rng, [1, 5, 5], dtype)

    # set denominators to 1 if their magnitudes are small
    # to prevent divide by zero, or overflows if we're testing
    # float16:
    for denom in (s2, s3, s4, s5):
        denom[np.abs(denom) < 1.0e-2] = 1

    s2 = wp.array(s2, dtype=mat22, requires_grad=True, device=device)
    s3 = wp.array(s3, dtype=mat33, requires_grad=True, device=device)
    s4 = wp.array(s4, dtype=mat44, requires_grad=True, device=device)
    s5 = wp.array(s5, dtype=mat55, requires_grad=True, device=device)

    v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    def launch_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[outcomponents],
            device=device,
        )

    launch_forward()

    # Convert the outputs once instead of once per assertion, matching the
    # other tests in this file that cache `outcomponents_np`.
    outcomponents_np = outcomponents.numpy()

    is_float = dtype in np_float_types

    # `/` for floating-point dtypes, `//` for everything else
    divide = np.true_divide if is_float else np.floor_divide

    # forward values, in the order the kernel writes them
    begin = 0
    for s, v in [(s2, v2), (s3, v3), (s4, v4), (s5, v5)]:
        expected = 2 * divide(v.numpy(), s.numpy()).reshape(-1)
        end = begin + expected.shape[0]
        assert_np_equal(outcomponents_np[begin:end], expected, tol=50 * tol)
        begin = end

    if is_float:
        idx = 0
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for dim, s, v in [(2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)]:
            for i in range(dim):
                for j in range(dim):
                    tape = wp.Tape()
                    with tape:
                        launch_forward()
                        wp.launch(
                            output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                        )
                    tape.backward(loss=out)

                    # y = v/s
                    # dy/dv = 1.0/s
                    # dy/ds = -v/s^2

                    expectedresult = np.zeros((dim, dim), dtype=dtype)
                    expectedresult[i, j] = 2.0 / (s.numpy()[0, i, j])
                    assert_np_equal(tape.gradients[v].numpy()[0], expectedresult, tol=50 * tol)
                    expectedresult[i, j] = -2.0 * v.numpy()[0, i, j] / (s.numpy()[0, i, j] ** 2)
                    # The tolerance scales with the magnitude of the selected output.
                    # Every launch recomputes the same forward values from the same
                    # inputs, so the cached copy is used here.
                    assert_np_equal(
                        tape.gradients[s].numpy()[0], expectedresult, tol=abs(outcomponents_np[idx]) * 50 * tol
                    )
                    tape.zero()

                    idx = idx + 1
1972
+
1973
+
1974
def test_outer_product(test, device, dtype, register_kernels=False):
    """Check ``wp.outer`` on vector pairs and its gradients.

    The kernel computes ``2 * wp.outer(s, v)`` for vector pairs of length
    2, 3, 4 and 5 and for a length-2 vector with a length-5 vector.  The
    resulting matrices are flattened row by row into one output array, which
    is compared with the broadcast NumPy product.  For dtypes in
    ``np_float_types`` the gradient of every output entry with respect to
    both vectors is verified through a ``wp.Tape``.

    Args:
        test: Unused in this function.
        device: Device used for array allocation and kernel launches.
        dtype: NumPy scalar type under test.
        register_kernels: If true, stop right after the kernel has been built.
    """
    rng = np.random.default_rng(123)

    # comparison tolerance per dtype; dtypes that are not listed use 0
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_outer_product(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        m22result = wptype(2) * wp.outer(s2[0], v2[0])
        m33result = wptype(2) * wp.outer(s3[0], v3[0])
        m44result = wptype(2) * wp.outer(s4[0], v4[0])
        m55result = wptype(2) * wp.outer(s5[0], v5[0])
        m25result = wptype(2) * wp.outer(s2[0], v5[0])

        # flatten every result row by row into the output array
        slot = 0
        for r in range(2):
            for c in range(2):
                outcomponents[slot] = m22result[r, c]
                slot = slot + 1

        for r in range(3):
            for c in range(3):
                outcomponents[slot] = m33result[r, c]
                slot = slot + 1

        for r in range(4):
            for c in range(4):
                outcomponents[slot] = m44result[r, c]
                slot = slot + 1

        for r in range(5):
            for c in range(5):
                outcomponents[slot] = m55result[r, c]
                slot = slot + 1

        for r in range(2):
            for c in range(5):
                outcomponents[slot] = m25result[r, c]
                slot = slot + 1

    kernel = getkernel(check_mat_outer_product, suffix=dtype.__name__)

    if register_kernels:
        return

    def make_vector(length, wp_dtype):
        # one random, differentiable, single-element vector array on the test device
        return wp.array(randvals(rng, [1, length], dtype), dtype=wp_dtype, requires_grad=True, device=device)

    # NOTE: creation order fixes the order of draws from `rng`
    s2 = make_vector(2, vec2)
    s3 = make_vector(3, vec3)
    s4 = make_vector(4, vec4)
    s5 = make_vector(5, vec5)
    v2 = make_vector(2, vec2)
    v3 = make_vector(3, vec3)
    v4 = make_vector(4, vec4)
    v5 = make_vector(5, vec5)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 5, dtype=wptype, requires_grad=True, device=device)

    def launch_forward():
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[outcomponents],
            device=device,
        )

    launch_forward()

    outcomponents_np = outcomponents.numpy()

    # (left vector, right vector, tolerance multiplier) in the order the kernel writes the results
    begin = 0
    for s, v, tol_scale in [(s2, v2, 1), (s3, v3, 10), (s4, v4, 10), (s5, v5, 10), (s2, v5, 10)]:
        expected = 2 * s.numpy()[0, :, None] * v.numpy()[0, None, :]
        end = begin + expected.size
        assert_np_equal(outcomponents_np[begin:end].reshape(expected.shape), expected, tol=tol_scale * tol)
        begin = end

    if dtype in np_float_types:
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

        idx = 0
        for s, v in [(s2, v2), (s3, v3), (s4, v4), (s5, v5), (s2, v5)]:
            n_rows = s.dtype._length_
            n_cols = v.dtype._length_
            for row in range(n_rows):
                for col in range(n_cols):
                    tape = wp.Tape()
                    with tape:
                        launch_forward()
                        wp.launch(
                            output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                        )
                    tape.backward(loss=out)

                    # this component is 2 * s[row] * v[col], so its s gradient is nonzero only
                    # at component `row` and its v gradient is nonzero only at component `col`:

                    grad_s = np.zeros(n_rows, dtype=dtype)
                    grad_s[row] = 2 * v.numpy()[0, col]
                    assert_np_equal(tape.gradients[s].numpy()[0], grad_s, tol=10 * tol)

                    grad_v = np.zeros(n_cols, dtype=dtype)
                    grad_v[col] = 2 * s.numpy()[0, row]
                    assert_np_equal(tape.gradients[v].numpy()[0], grad_v, tol=10 * tol)
                    tape.zero()

                    idx += 1
2112
+
2113
+
2114
def test_transpose(test, device, dtype, register_kernels=False):
    """Check ``wp.transpose`` on matrices and its gradients.

    The kernel computes ``2 * wp.transpose(m)`` for 2x2, 3x3, 4x4, 5x5 and
    3x2 matrices and flattens the results row by row into one output array,
    which is compared with the NumPy transpose.  For dtypes in
    ``np_float_types`` the gradient of every output entry of the four square
    cases is verified through a ``wp.Tape``.

    Args:
        test: Unused in this function.
        device: Device used for array allocation and kernel launches.
        dtype: NumPy scalar type under test.
        register_kernels: If true, stop right after the kernel has been built.
    """
    rng = np.random.default_rng(123)

    # comparison tolerance per dtype; dtypes that are not listed use 0
    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat32 = wp.types.matrix(shape=(3, 2), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_transpose(
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        m32: wp.array(dtype=mat32),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        # (locals are named t* so they do not shadow the enclosing matrix type `mat32`)
        t2 = wptype(2) * wp.transpose(m2[0])
        t3 = wptype(2) * wp.transpose(m3[0])
        t4 = wptype(2) * wp.transpose(m4[0])
        t5 = wptype(2) * wp.transpose(m5[0])
        t32 = wptype(2) * wp.transpose(m32[0])

        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = t2[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = t3[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = t4[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = t5[i, j]
                idx = idx + 1

        # the transpose of the 3x2 input is 2x3
        for i in range(2):
            for j in range(3):
                outcomponents[idx] = t32[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_transpose, suffix=dtype.__name__)

    if register_kernels:
        return

    # NOTE: creation order fixes the order of draws from `rng`
    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    m32 = wp.array(randvals(rng, [1, 3, 2], dtype), dtype=mat32, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 2 * 3, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)

    # Convert the outputs once instead of once per assertion, matching the
    # other tests in this file that cache `outcomponents_np`.
    outcomponents_np = outcomponents.numpy()

    # forward values, in the order the kernel writes them
    begin = 0
    for inmat in [m2, m3, m4, m5, m32]:
        expected = 2 * inmat.numpy()[0].T.reshape(-1)
        end = begin + expected.shape[0]
        assert_np_equal(outcomponents_np[begin:end], expected, tol=tol)
        begin = end

    if dtype in np_float_types:
        idx = 0
        out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        # NOTE(review): only the square inputs are differentiated here; the 3x2
        # input `m32` is covered by the forward check above only.
        for inmat in [m2, m3, m4, m5]:
            n_rows, n_cols = inmat.dtype._shape_
            for i in range(n_rows):
                for j in range(n_cols):
                    tape = wp.Tape()
                    with tape:
                        wp.launch(kernel, dim=1, inputs=[m2, m3, m4, m5, m32], outputs=[outcomponents], device=device)
                        wp.launch(
                            output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device
                        )
                    tape.backward(loss=out)
                    # output entry (i, j) is 2 * m[j, i], so the only nonzero gradient entry is (j, i)
                    expectedresult = np.zeros((n_cols, n_rows), dtype=dtype)
                    expectedresult[j, i] = 2
                    assert_np_equal(tape.gradients[inmat].numpy()[0], expectedresult)
                    tape.zero()
                    idx = idx + 1
2211
+
2212
+
2213
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Check ``matrix / scalar`` for 2x2, 3x3, 4x4 and 5x5 matrices.

    The kernel divides each matrix by the scalar, doubles every component and
    writes the results into one flat output array.  Forward values are compared
    against NumPy (true division for float dtypes, floor division otherwise);
    for float dtypes the gradients w.r.t. each matrix and the scalar are also
    checked, one output component at a time.

    Args:
        test: Test-case instance supplied by the registration helper (unused here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only build the kernel and return.
    """
    rng = np.random.default_rng(123)

    # Tolerance per floating-point type; any other dtype gets 0.
    tolerances = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    # Kernel source: kept in the same simple form as before.
    def check_mat_scalar_div(
        s: wp.array(dtype=wptype),
        m2: wp.array(dtype=mat22),
        m3: wp.array(dtype=mat33),
        m4: wp.array(dtype=mat44),
        m5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        m2result = m2[0] / s[0]
        m3result = m3[0] / s[0]
        m4result = m4[0] / s[0]
        m5result = m5[0] / s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * m2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * m3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * m4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * m5result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_scalar_div, suffix=dtype.__name__)

    if register_kernels:
        return

    # Inputs are created in a fixed order so the seeded generator is consumed
    # the same way on every run: scalar first, then matrices by size.
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    m2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    m3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    m4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    m5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)

    divisor = s.numpy()[0]
    is_float = dtype in np_float_types

    # (slice of the flat output, source matrix array, tolerance multiplier)
    forward_checks = (
        (slice(0, 4), m2, 1),
        (slice(4, 13), m3, 10),
        (slice(13, 29), m4, 10),
        (slice(29, 54), m5, 10),
    )
    for span, mat, scale in forward_checks:
        flat = mat.numpy().reshape(-1)
        if is_float:
            expected = 2 * flat / divisor
        else:
            expected = 2 * (flat // divisor)
        assert_np_equal(outcomponents.numpy()[span], expected, tol=scale * tol)

    # Gradients are only checked for floating-point types.
    if not is_float:
        return

    idx = 0
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for dim, mat in ((2, m2), (3, m3), (4, m4), (5, m5)):
        for i, j in np.ndindex(dim, dim):
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=[s, m2, m3, m4, m5], outputs=[outcomponents], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
            tape.backward(loss=out)

            # d(2 * m[i, j] / s) / dm is 2 / s at (i, j) and zero elsewhere.
            expected_grad = np.zeros((dim, dim), dtype=dtype)
            expected_grad[i, j] = 2.0 / divisor
            assert_np_equal(tape.gradients[mat].numpy()[0], expected_grad, tol=10 * tol)

            # d(2 * m[i, j] / s) / ds is -2 * m[i, j] / s**2.
            assert_np_equal(
                tape.gradients[s].numpy()[0], -2 * mat.numpy()[0, i, j] / (divisor * divisor), tol=10 * tol
            )
            tape.zero()

            idx += 1
2313
+
2314
+
2315
def test_addition(test, device, dtype, register_kernels=False):
    """Check matrix + matrix addition for 2x2, 3x3, 4x4 and 5x5 matrices.

    The kernel adds pairs of matrices, doubles each component of the sums and
    stores them in one flat output array.  Forward values are compared against
    NumPy; for float dtypes the gradient of every output component with respect
    to both operands is checked as well.

    Args:
        test: Test-case instance supplied by the registration helper (unused here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only build the kernel and return.
    """
    rng = np.random.default_rng(123)

    # Tolerance per floating-point type; any other dtype gets 0.
    tolerances = {
        np.float16: 2.0e-2,
        np.float32: 5.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    # Kernel source: kept in the same simple form as before.
    def check_mat_add(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        outcomponents: wp.array(dtype=wptype),
    ):
        v2result = v2[0] + s2[0]
        v3result = v3[0] + s3[0]
        v4result = v4[0] + s4[0]
        v5result = v5[0] + s5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        idx = 0
        for i in range(2):
            for j in range(2):
                outcomponents[idx] = wptype(2) * v2result[i, j]
                idx = idx + 1

        for i in range(3):
            for j in range(3):
                outcomponents[idx] = wptype(2) * v3result[i, j]
                idx = idx + 1

        for i in range(4):
            for j in range(4):
                outcomponents[idx] = wptype(2) * v4result[i, j]
                idx = idx + 1

        for i in range(5):
            for j in range(5):
                outcomponents[idx] = wptype(2) * v5result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_add, suffix=dtype.__name__)

    if register_kernels:
        return

    # Operands are created in a fixed order (all "s" arrays, then all "v"
    # arrays) so the seeded generator is consumed identically on every run.
    sized_types = ((2, mat22), (3, mat33), (4, mat44), (5, mat55))
    s2, s3, s4, s5 = (
        wp.array(randvals(rng, [1, n, n], dtype), dtype=mat_type, requires_grad=True, device=device)
        for n, mat_type in sized_types
    )
    v2, v3, v4, v5 = (
        wp.array(randvals(rng, [1, n, n], dtype), dtype=mat_type, requires_grad=True, device=device)
        for n, mat_type in sized_types
    )
    outcomponents = wp.zeros(2 * 2 + 3 * 3 + 4 * 4 + 5 * 5, dtype=wptype, requires_grad=True, device=device)

    def launch_forward():
        # Single place that launches the addition kernel on all operands.
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[outcomponents],
            device=device,
        )

    launch_forward()

    # (slice of the flat output, left operand, right operand)
    forward_checks = (
        (slice(0, 4), v2, s2),
        (slice(4, 13), v3, s3),
        (slice(13, 29), v4, s4),
        (slice(29, 54), v5, s5),
    )
    for span, lhs, rhs in forward_checks:
        assert_np_equal(outcomponents.numpy()[span], 2 * (lhs.numpy() + rhs.numpy()).reshape(-1), tol=tol)

    # Gradients are only checked for floating-point types.
    if dtype not in np_float_types:
        return

    idx = 0
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for dim, in1, in2 in ((2, s2, v2), (3, s3, v3), (4, s4, v4), (5, s5, v5)):
        for i, j in np.ndindex(dim, dim):
            tape = wp.Tape()
            with tape:
                launch_forward()
                wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
            tape.backward(loss=out)

            # The selected output is 2 * (in1 + in2)[i, j], so both operands
            # receive a gradient of 2 at (i, j) and zero elsewhere.
            expected_grad = np.zeros((dim, dim), dtype=dtype)
            expected_grad[i, j] = 2
            assert_np_equal(tape.gradients[in2].numpy()[0], expected_grad, tol=10 * tol)
            assert_np_equal(tape.gradients[in1].numpy()[0], expected_grad, tol=10 * tol)
            tape.zero()

            idx += 1
2443
+
2444
+
2445
def test_ddot(test, device, dtype, register_kernels=False):
    """Check ``wp.ddot`` for 2x2, 3x3, 4x4 and 5x5 matrices.

    The kernel writes twice the double-dot product of each matrix pair into its
    own one-element output array.  Forward values are compared against the
    NumPy sum of elementwise products; for float dtypes the gradients with
    respect to both operands are checked too.

    Args:
        test: Test-case instance supplied by the registration helper (unused here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only build the kernel and return.
    """
    rng = np.random.default_rng(123)

    # Tolerance per floating-point type; any other dtype gets 0.
    tolerances = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    # Kernel source: kept in the same simple form as before.
    def check_mat_dot(
        s2: wp.array(dtype=mat22),
        s3: wp.array(dtype=mat33),
        s4: wp.array(dtype=mat44),
        s5: wp.array(dtype=mat55),
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        dot2: wp.array(dtype=wptype),
        dot3: wp.array(dtype=wptype),
        dot4: wp.array(dtype=wptype),
        dot5: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        dot2[0] = wptype(2) * wp.ddot(v2[0], s2[0])
        dot3[0] = wptype(2) * wp.ddot(v3[0], s3[0])
        dot4[0] = wptype(2) * wp.ddot(v4[0], s4[0])
        dot5[0] = wptype(2) * wp.ddot(v5[0], s5[0])

    kernel = getkernel(check_mat_dot, suffix=dtype.__name__)

    if register_kernels:
        return

    # Operands are created in a fixed order (all "s" arrays, then all "v"
    # arrays) so the seeded generator is consumed identically on every run.
    sized_types = ((2, mat22), (3, mat33), (4, mat44), (5, mat55))
    s2, s3, s4, s5 = (
        wp.array(randvals(rng, [1, n, n], dtype), dtype=mat_type, requires_grad=True, device=device)
        for n, mat_type in sized_types
    )
    v2, v3, v4, v5 = (
        wp.array(randvals(rng, [1, n, n], dtype), dtype=mat_type, requires_grad=True, device=device)
        for n, mat_type in sized_types
    )
    dot2, dot3, dot4, dot5 = (wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(4))

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
            outputs=[dot2, dot3, dot4, dot5],
            device=device,
        )

    # (output array, "s" operand, "v" operand, forward tolerance multiplier)
    cases = (
        (dot2, s2, v2, 10),
        (dot3, s3, v3, 10),
        (dot4, s4, v4, 50),
        (dot5, s5, v5, 200),
    )

    for result, s_arr, v_arr, scale in cases:
        assert_np_equal(result.numpy()[0], 2 * (v_arr.numpy() * s_arr.numpy()).sum(), tol=scale * tol)

    # Gradients are only checked for floating-point types.
    if dtype not in np_float_types:
        return

    for result, s_arr, v_arr, _ in cases:
        tape.backward(loss=result)

        # d(2 * ddot(v, s)) / ds == 2 * v
        s_grad = tape.gradients[s_arr].numpy()[0]
        assert_np_equal(s_grad, 2.0 * v_arr.numpy()[0], tol=10 * tol)

        # d(2 * ddot(v, s)) / dv == 2 * s
        v_grad = tape.gradients[v_arr].numpy()[0]
        assert_np_equal(v_grad, 2.0 * s_arr.numpy()[0], tol=10 * tol)

        tape.zero()
2566
+
2567
+
2568
def test_trace(test, device, dtype, register_kernels=False):
    """Check ``wp.trace`` for 2x2, 3x3, 4x4 and 5x5 matrices.

    The kernel writes twice the trace of each matrix into its own one-element
    output array.  Forward values are compared against ``np.trace``; for float
    dtypes the gradient with respect to each matrix is expected to be twice the
    identity.

    Args:
        test: Test-case instance supplied by the registration helper (unused here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only build the kernel and return.
    """
    rng = np.random.default_rng(123)

    # Tolerance per floating-point type; any other dtype gets 0.
    tol = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    def check_mat_trace(
        v2: wp.array(dtype=mat22),
        v3: wp.array(dtype=mat33),
        v4: wp.array(dtype=mat44),
        v5: wp.array(dtype=mat55),
        tr2: wp.array(dtype=wptype),
        tr3: wp.array(dtype=wptype),
        tr4: wp.array(dtype=wptype),
        tr5: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        tr2[0] = wptype(2) * wp.trace(v2[0])
        tr3[0] = wptype(2) * wp.trace(v3[0])
        tr4[0] = wptype(2) * wp.trace(v4[0])
        tr5[0] = wptype(2) * wp.trace(v5[0])

    kernel = getkernel(check_mat_trace, suffix=dtype.__name__)

    if register_kernels:
        return

    v2 = wp.array(randvals(rng, [1, 2, 2], dtype), dtype=mat22, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, [1, 3, 3], dtype), dtype=mat33, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, [1, 4, 4], dtype), dtype=mat44, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, [1, 5, 5], dtype), dtype=mat55, requires_grad=True, device=device)
    tr2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tr3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tr4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tr5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[v2, v3, v4, v5],
            outputs=[tr2, tr3, tr4, tr5],
            device=device,
        )

    assert_np_equal(tr2.numpy()[0], 2 * np.trace(v2.numpy()[0]), tol=10 * tol)
    assert_np_equal(tr3.numpy()[0], 2 * np.trace(v3.numpy()[0]), tol=10 * tol)
    assert_np_equal(tr4.numpy()[0], 2 * np.trace(v4.numpy()[0]), tol=200 * tol)
    # Previously this line repeated the 4x4 check, leaving the 5x5 forward
    # result unverified.
    assert_np_equal(tr5.numpy()[0], 2 * np.trace(v5.numpy()[0]), tol=200 * tol)

    if dtype in np_float_types:
        # d(2 * trace(v)) / dv is twice the identity matrix of matching size.
        for size, loss, mat in ((2, tr2, v2), (3, tr3, v3), (4, tr4, v4), (5, tr5, v5)):
            tape.backward(loss=loss)
            vgrads = tape.gradients[mat].numpy()[0]
            assert_np_equal(vgrads, 2.0 * np.eye(size), tol=10 * tol)
            tape.zero()
2658
+
2659
+
2660
def test_diag(test, device, dtype, register_kernels=False):
    """Check ``wp.diag`` on a length-5 vector.

    The kernel builds twice the diagonal matrix of the input vector and writes
    its 25 components into a flat output array.  The forward result is compared
    against ``np.diag``; for float dtypes the gradient of every output
    component with respect to the input vector is checked.

    Args:
        test: Test-case instance supplied by the registration helper (unused here).
        device: Device on which arrays are allocated and kernels are launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only build the kernel and return.
    """
    rng = np.random.default_rng(123)

    # Tolerance per floating-point type; any other dtype gets 0.
    tolerances = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }
    tol = tolerances.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec5 = wp.types.vector(length=5, dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    # Kernel source: kept in the same simple form as before.
    def check_mat_diag(
        s5: wp.array(dtype=vec5),
        outcomponents: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        m55result = wptype(2) * wp.diag(s5[0])

        idx = 0
        for i in range(5):
            for j in range(5):
                outcomponents[idx] = m55result[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_diag, suffix=dtype.__name__)

    if register_kernels:
        return

    s5 = wp.array(randvals(rng, [1, 5], dtype), dtype=vec5, requires_grad=True, device=device)
    outcomponents = wp.zeros(5 * 5, dtype=wptype, requires_grad=True, device=device)
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)

    expected_forward = 2 * np.diag(s5.numpy()[0])
    assert_np_equal(outcomponents.reshape((5, 5)).numpy(), expected_forward, tol=tol)

    # Gradients are only checked for floating-point types.
    if dtype not in np_float_types:
        return

    # Flat output index advances in step with the row-major (row, col) walk.
    for idx, (row, col) in enumerate(np.ndindex(5, 5)):
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[s5], outputs=[outcomponents], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
        tape.backward(loss=out)

        # Only diagonal outputs depend on the input: component (k, k) is 2 * s5[k].
        expected_grad = np.zeros(5, dtype=dtype)
        if row == col:
            expected_grad[row] = 2
        assert_np_equal(tape.gradients[s5].numpy()[0], expected_grad, tol=10 * tol)
        tape.zero()
2716
+
2717
+
2718
def test_equivalent_types(test, device, dtype, register_kernels=False):
    """Check that separately created matrix types with equal shape/dtype interoperate.

    A kernel is declared with one set of matrix types and launched with values
    built from a second, separately created set with the same shapes and scalar
    type.  Inside the kernel the arguments are compared against constants built
    from both sets.

    Args:
        test: Test-case instance supplied by the registration helper (unused here).
        device: Device on which the kernel is launched.
        dtype: NumPy scalar type under test.
        register_kernels: When true, only build the kernel and return.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Types used in the kernel signature.
    mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33 = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55 = wp.types.matrix(shape=(5, 5), dtype=wptype)

    # A second set of types created with identical parameters.
    mat22_equiv = wp.types.matrix(shape=(2, 2), dtype=wptype)
    mat33_equiv = wp.types.matrix(shape=(3, 3), dtype=wptype)
    mat44_equiv = wp.types.matrix(shape=(4, 4), dtype=wptype)
    mat55_equiv = wp.types.matrix(shape=(5, 5), dtype=wptype)

    # The kernel is declared against the first set of types.
    def check_equivalence(
        m2: mat22,
        m3: mat33,
        m4: mat44,
        m5: mat55,
    ):
        wp.expect_eq(m2, mat22(wptype(42)))
        wp.expect_eq(m3, mat33(wptype(43)))
        wp.expect_eq(m4, mat44(wptype(44)))
        wp.expect_eq(m5, mat55(wptype(45)))

        wp.expect_eq(m2, mat22_equiv(wptype(42)))
        wp.expect_eq(m3, mat33_equiv(wptype(43)))
        wp.expect_eq(m4, mat44_equiv(wptype(44)))
        wp.expect_eq(m5, mat55_equiv(wptype(45)))

    kernel = getkernel(check_equivalence, suffix=dtype.__name__)

    if register_kernels:
        return

    # Launch with values built from the second set of types.
    launch_args = [
        mat22_equiv(42),
        mat33_equiv(43),
        mat44_equiv(44),
        mat55_equiv(45),
    ]

    wp.launch(kernel, dim=1, inputs=launch_args, device=device)
2762
+
2763
+
2764
def test_conversions(test, device, dtype, register_kernels=False):
    """Check that ``wp.mat22`` values can be supplied from several container kinds.

    A reference ``wp.mat22(1, 2, 3, 4)`` is compared inside a kernel against
    matrices coming from nested/flat tuples, lists and NumPy arrays, first
    converted explicitly through the ``wp.mat22`` constructor and then passed
    as raw containers straight to ``wp.launch``.

    Args:
        test: Test-case instance supplied by the registration helper (unused here).
        device: Device on which the kernel is launched.
        dtype: NumPy scalar type used for the NumPy array inputs.
        register_kernels: When true, only build the kernel and return.
    """

    def check_matrices_equal(
        m0: wp.mat22,
        m1: wp.mat22,
        m2: wp.mat22,
        m3: wp.mat22,
        m4: wp.mat22,
        m5: wp.mat22,
        m6: wp.mat22,
    ):
        wp.expect_eq(m1, m0)
        wp.expect_eq(m2, m0)
        wp.expect_eq(m3, m0)
        wp.expect_eq(m4, m0)
        wp.expect_eq(m5, m0)
        wp.expect_eq(m6, m0)

    kernel = getkernel(check_matrices_equal, suffix=dtype.__name__)

    if register_kernels:
        return

    reference = wp.mat22(1, 2, 3, 4)

    # test explicit conversions - constructing matrices from different containers
    explicit = [
        wp.mat22(((1, 2), (3, 4))),  # nested tuples
        wp.mat22([[1, 2], [3, 4]]),  # nested lists
        wp.mat22(np.array([[1, 2], [3, 4]], dtype=dtype)),  # 2d array
        wp.mat22((1, 2, 3, 4)),  # flat tuple
        wp.mat22([1, 2, 3, 4]),  # flat list
        wp.mat22(np.array([1, 2, 3, 4], dtype=dtype)),  # 1d array
    ]

    wp.launch(kernel, dim=1, inputs=[reference, *explicit], device=device)

    # test implicit conversions - passing different containers as matrices to wp.launch()
    implicit = [
        ((1, 2), (3, 4)),  # nested tuples
        [[1, 2], [3, 4]],  # nested lists
        np.array([[1, 2], [3, 4]], dtype=dtype),  # 2d array
        (1, 2, 3, 4),  # flat tuple
        [1, 2, 3, 4],  # flat list
        np.array([1, 2, 3, 4], dtype=dtype),  # 1d array
    ]

    wp.launch(kernel, dim=1, inputs=[reference, *implicit], device=device)
2807
+
2808
+
2809
# Devices passed to the device-parameterized registrations below.
devices = get_test_devices()


class TestMatScalarOps(unittest.TestCase):
    """Empty test case; its test methods are attached by the registration loop in this module."""
2814
+
2815
+
2816
# (test-name stem, test function) pairs registered through
# add_function_test_register_kernel with devices=devices, in registration order.
# NOTE(review): "get_diag" is registered with test_diag, exactly as before; if a
# dedicated test_get_diag function exists or was intended, confirm and update.
_kernel_tests = (
    ("constructors", test_constructors),
    ("anon_type_instance", test_anon_type_instance),
    ("identity", test_identity),
    ("indexing", test_indexing),
    ("equality", test_equality),
    ("scalar_multiplication", test_scalar_multiplication),
    ("matvec_multiplication", test_matvec_multiplication),
    ("vecmat_multiplication", test_vecmat_multiplication),
    ("matmat_multiplication", test_matmat_multiplication),
    ("cw_multiplication", test_cw_multiplication),
    ("cw_division", test_cw_division),
    ("outer_product", test_outer_product),
    ("transpose", test_transpose),
    ("scalar_division", test_scalar_division),
    ("addition", test_addition),
    ("ddot", test_ddot),
    ("trace", test_trace),
    ("diag", test_diag),
    ("get_diag", test_diag),
    ("equivalent_types", test_equivalent_types),
    ("conversions", test_conversions),
)

# Register one test per (function, scalar dtype) pair on TestMatScalarOps.
for dtype in np_scalar_types:
    add_function_test(TestMatScalarOps, f"test_arrays_{dtype.__name__}", test_arrays, devices=devices, dtype=dtype)
    add_function_test(TestMatScalarOps, f"test_components_{dtype.__name__}", test_components, devices=None, dtype=dtype)

    for _stem, _func in _kernel_tests:
        add_function_test_register_kernel(
            TestMatScalarOps, f"test_{_stem}_{dtype.__name__}", _func, devices=devices, dtype=dtype
        )

    # Registered last and, unlike the entries above, with devices=None.
    add_function_test_register_kernel(
        TestMatScalarOps, f"test_constants_{dtype.__name__}", test_constants, devices=None, dtype=dtype
    )
2909
+
2910
+
2911
if __name__ == "__main__":
    # Clear Warp's kernel cache before the tests run.
    wp.clear_kernel_cache()
    # Verbose per-test output; stop at the first failure.
    unittest.main(failfast=True, verbosity=2)