warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2327 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# NumPy scalar types exercised by these tests, grouped by kind.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]
np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]
np_int_types = np_signed_int_types + np_unsigned_int_types
np_float_types = [np.float16, np.float32, np.float64]
np_scalar_types = np_int_types + np_float_types


def randvals(rng, shape, dtype):
    """Draw a random NumPy array with the given ``shape`` and ``dtype`` from ``rng``.

    Dtypes listed in ``np_float_types`` get standard-normal samples cast to
    ``dtype``. Integer dtypes get uniform integers starting at 1: values in
    ``[1, 2]`` for the 8-bit types listed below, and ``[1, 4]`` otherwise
    (``rng.integers`` treats ``high`` as exclusive).
    """
    if dtype in np_float_types:
        return rng.standard_normal(size=shape).astype(dtype)

    # 8-bit types get a narrower range -- presumably to keep test arithmetic
    # from overflowing such small integers; confirm with the test authors.
    small_int_types = (np.int8, np.uint8, np.byte, np.ubyte)
    high = 3 if dtype in small_int_types else 5
    return rng.integers(1, high=high, size=shape, dtype=dtype)
52
+
53
+
54
# Memo table for getkernel(): maps "<function name>_<suffix>" -> wp.Kernel.
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return a ``wp.Kernel`` wrapping ``func``, memoized in ``kernel_cache``.

    The cache key (also used as the kernel's ``key``) is
    ``func.__name__ + "_" + suffix``; ``suffix`` must be a ``str``. A second
    call with the same name and suffix returns the previously built object
    instead of constructing a new kernel.
    """
    cache_key = "_".join((func.__name__, suffix))
    if cache_key in kernel_cache:
        return kernel_cache[cache_key]

    kernel = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = kernel
    return kernel
62
+
63
+
64
def get_select_kernel(dtype):
    """Return a cached kernel that copies ``input[index]`` into ``out[0]``.

    ``input`` and ``out`` are 1-D warp arrays of ``dtype``. The kernel is
    memoized via ``getkernel`` under a key suffixed with ``dtype.__name__``.
    """

    # NOTE: the function name feeds the kernel/cache key, so keep it stable.
    def output_select_kernel_fn(
        input: wp.array(dtype=dtype),
        index: int,
        out: wp.array(dtype=dtype),
    ):
        out[0] = input[index]

    type_suffix = dtype.__name__
    return getkernel(output_select_kernel_fn, suffix=type_suffix)
73
+
74
+
75
def get_select_kernel2(dtype):
    """Return a cached kernel that copies ``input[index0, index1]`` into ``out[0]``.

    ``input`` is a 2-D warp array of ``dtype`` and ``out`` a 1-D one. The
    kernel is memoized via ``getkernel`` under a key suffixed with
    ``dtype.__name__``.
    """

    # NOTE: the function name feeds the kernel/cache key, so keep it stable.
    def output_select_kernel2_fn(
        input: wp.array(dtype=dtype, ndim=2),
        index0: int,
        index1: int,
        out: wp.array(dtype=dtype),
    ):
        out[0] = input[index0, index1]

    type_suffix = dtype.__name__
    return getkernel(output_select_kernel2_fn, suffix=type_suffix)
85
+
86
+
87
def test_arrays(test, device, dtype):
    """Round-trip random NumPy data through warp arrays of vector types.

    Ten random vectors of each length 2-5 are uploaded as warp arrays (with
    ``requires_grad=True``) and read back with ``.numpy()``; the result must
    match the source data within 1e-6. A second pass rebuilds the length 2-4
    vector types and arrays from the same data and repeats the check.
    """
    rng = np.random.default_rng(123)
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Draw samples in ascending length order so the RNG stream is unchanged.
    samples = {length: randvals(rng, (10, length), dtype) for length in (2, 3, 4, 5)}

    def check_round_trip(lengths):
        for length in lengths:
            vec_type = wp.types.vector(length=length, dtype=wptype)
            host_data = samples[length]
            device_array = wp.array(host_data, dtype=vec_type, requires_grad=True, device=device)
            assert_np_equal(device_array.numpy(), host_data, tol=1.0e-6)

    check_round_trip((2, 3, 4, 5))
    check_round_trip((2, 3, 4))
122
+
123
+
124
def test_components(test, device, dtype):
    """Check Python-side component access on a warp ``vec3`` of ``dtype``.

    Covers ``__getitem__`` and ``__setitem__`` with both integer indices and
    slices. This is especially important for float16, which requires special
    handling internally.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec3 = wp.types.vector(length=3, dtype=wptype)

    def expect_components(seq, values):
        # Compare position by position; only the listed positions are checked.
        for position, value in enumerate(values):
            test.assertEqual(seq[position], value)

    v = vec3(1, 2, 3)

    # __getitem__ with integer indices
    expect_components(v, (1, 2, 3))

    # __getitem__ with slices
    expect_components(v[:], (1, 2, 3))
    expect_components(v[1:], (2, 3))
    expect_components(v[:2], (1, 2))
    expect_components(v[::2], (1, 3))

    # __setitem__ with integer indices
    for position, value in enumerate((4, 5, 6)):
        v[position] = value
    expect_components(v, (4, 5, 6))

    # __setitem__ with slices; components outside the slice keep their value
    v[:] = [7, 8, 9]
    expect_components(v, (7, 8, 9))

    v[1:] = [10, 11]
    expect_components(v, (7, 10, 11))

    v[:2] = [12, 13]
    expect_components(v, (12, 13, 11))

    v[::2] = [14, 15]
    expect_components(v, (14, 13, 15))
184
+
185
+
186
def test_py_arithmetic_ops(test, device, dtype):
    """Exercise Python-side arithmetic operators on a warp ``vec3`` of ``dtype``.

    For integer warp types, each expected component is passed through the
    type's ctypes counterpart (``wptype._type_``) to simulate fixed-width
    wrapping, e.g. for negative values with unsigned types.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    is_int_type = wptype in wp.types.int_types

    def expected(*components):
        if not is_int_type:
            return components
        # Cast through the ctypes type to emulate integer wrap-around.
        return tuple(wptype._type_(c).value for c in components)

    def check(result, *components):
        test.assertSequenceEqual(result, expected(*components))

    vec_cls = wp.vec(3, wptype)

    # Unary, add/sub and modulo with a mixed-sign vector.
    v = vec_cls(1, -2, 3)
    check(+v, 1, -2, 3)
    check(-v, -1, 2, -3)
    check(v + vec_cls(5, 5, 5), 6, 3, 8)
    check(v - vec_cls(5, 5, 5), -4, -7, -2)
    check(v % vec_cls(2, 2, 2), 1, 0, 1)

    # Scalar mul/div in both operand orders, plus modulo.
    v = vec_cls(2, 4, 6)
    check(v * wptype(2), 4, 8, 12)
    check(wptype(2) * v, 4, 8, 12)
    check(v / wptype(2), 1, 2, 3)
    check(wptype(24) / v, 12, 6, 4)
    check(v % vec_cls(3, 3, 3), 2, 1, 0)
211
+
212
+
213
+ def test_constructors(test, device, dtype, register_kernels=False):
214
+ rng = np.random.default_rng(123)
215
+
216
+ tol = {
217
+ np.float16: 5.0e-3,
218
+ np.float32: 1.0e-6,
219
+ np.float64: 1.0e-8,
220
+ }.get(dtype, 0)
221
+
222
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
223
+ vec2 = wp.types.vector(length=2, dtype=wptype)
224
+ vec3 = wp.types.vector(length=3, dtype=wptype)
225
+ vec4 = wp.types.vector(length=4, dtype=wptype)
226
+ vec5 = wp.types.vector(length=5, dtype=wptype)
227
+
228
+ def check_scalar_constructor(
229
+ input: wp.array(dtype=wptype),
230
+ v2: wp.array(dtype=vec2),
231
+ v3: wp.array(dtype=vec3),
232
+ v4: wp.array(dtype=vec4),
233
+ v5: wp.array(dtype=vec5),
234
+ v20: wp.array(dtype=wptype),
235
+ v21: wp.array(dtype=wptype),
236
+ v30: wp.array(dtype=wptype),
237
+ v31: wp.array(dtype=wptype),
238
+ v32: wp.array(dtype=wptype),
239
+ v40: wp.array(dtype=wptype),
240
+ v41: wp.array(dtype=wptype),
241
+ v42: wp.array(dtype=wptype),
242
+ v43: wp.array(dtype=wptype),
243
+ v50: wp.array(dtype=wptype),
244
+ v51: wp.array(dtype=wptype),
245
+ v52: wp.array(dtype=wptype),
246
+ v53: wp.array(dtype=wptype),
247
+ v54: wp.array(dtype=wptype),
248
+ ):
249
+ v2result = vec2(input[0])
250
+ v3result = vec3(input[0])
251
+ v4result = vec4(input[0])
252
+ v5result = vec5(input[0])
253
+
254
+ v2[0] = v2result
255
+ v3[0] = v3result
256
+ v4[0] = v4result
257
+ v5[0] = v5result
258
+
259
+ # multiply outputs by 2 so we've got something to backpropagate
260
+ v20[0] = wptype(2) * v2result[0]
261
+ v21[0] = wptype(2) * v2result[1]
262
+
263
+ v30[0] = wptype(2) * v3result[0]
264
+ v31[0] = wptype(2) * v3result[1]
265
+ v32[0] = wptype(2) * v3result[2]
266
+
267
+ v40[0] = wptype(2) * v4result[0]
268
+ v41[0] = wptype(2) * v4result[1]
269
+ v42[0] = wptype(2) * v4result[2]
270
+ v43[0] = wptype(2) * v4result[3]
271
+
272
+ v50[0] = wptype(2) * v5result[0]
273
+ v51[0] = wptype(2) * v5result[1]
274
+ v52[0] = wptype(2) * v5result[2]
275
+ v53[0] = wptype(2) * v5result[3]
276
+ v54[0] = wptype(2) * v5result[4]
277
+
278
+ def check_vector_constructors(
279
+ input: wp.array(dtype=wptype),
280
+ v2: wp.array(dtype=vec2),
281
+ v3: wp.array(dtype=vec3),
282
+ v4: wp.array(dtype=vec4),
283
+ v5: wp.array(dtype=vec5),
284
+ v20: wp.array(dtype=wptype),
285
+ v21: wp.array(dtype=wptype),
286
+ v30: wp.array(dtype=wptype),
287
+ v31: wp.array(dtype=wptype),
288
+ v32: wp.array(dtype=wptype),
289
+ v40: wp.array(dtype=wptype),
290
+ v41: wp.array(dtype=wptype),
291
+ v42: wp.array(dtype=wptype),
292
+ v43: wp.array(dtype=wptype),
293
+ v50: wp.array(dtype=wptype),
294
+ v51: wp.array(dtype=wptype),
295
+ v52: wp.array(dtype=wptype),
296
+ v53: wp.array(dtype=wptype),
297
+ v54: wp.array(dtype=wptype),
298
+ ):
299
+ v2result = vec2(input[0], input[1])
300
+ v3result = vec3(input[2], input[3], input[4])
301
+ v4result = vec4(input[5], input[6], input[7], input[8])
302
+ v5result = vec5(input[9], input[10], input[11], input[12], input[13])
303
+
304
+ v2[0] = v2result
305
+ v3[0] = v3result
306
+ v4[0] = v4result
307
+ v5[0] = v5result
308
+
309
+ # multiply the output by 2 so we've got something to backpropagate:
310
+ v20[0] = wptype(2) * v2result[0]
311
+ v21[0] = wptype(2) * v2result[1]
312
+
313
+ v30[0] = wptype(2) * v3result[0]
314
+ v31[0] = wptype(2) * v3result[1]
315
+ v32[0] = wptype(2) * v3result[2]
316
+
317
+ v40[0] = wptype(2) * v4result[0]
318
+ v41[0] = wptype(2) * v4result[1]
319
+ v42[0] = wptype(2) * v4result[2]
320
+ v43[0] = wptype(2) * v4result[3]
321
+
322
+ v50[0] = wptype(2) * v5result[0]
323
+ v51[0] = wptype(2) * v5result[1]
324
+ v52[0] = wptype(2) * v5result[2]
325
+ v53[0] = wptype(2) * v5result[3]
326
+ v54[0] = wptype(2) * v5result[4]
327
+
328
+ vec_kernel = getkernel(check_vector_constructors, suffix=dtype.__name__)
329
+ kernel = getkernel(check_scalar_constructor, suffix=dtype.__name__)
330
+
331
+ if register_kernels:
332
+ return
333
+
334
+ input = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
335
+ v2 = wp.zeros(1, dtype=vec2, device=device)
336
+ v3 = wp.zeros(1, dtype=vec3, device=device)
337
+ v4 = wp.zeros(1, dtype=vec4, device=device)
338
+ v5 = wp.zeros(1, dtype=vec5, device=device)
339
+ v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
340
+ v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
341
+ v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
342
+ v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
343
+ v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
344
+ v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
345
+ v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
346
+ v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
347
+ v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
348
+ v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
349
+ v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
350
+ v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
351
+ v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
352
+ v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
353
+
354
+ tape = wp.Tape()
355
+ with tape:
356
+ wp.launch(
357
+ kernel,
358
+ dim=1,
359
+ inputs=[input],
360
+ outputs=[v2, v3, v4, v5, v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
361
+ device=device,
362
+ )
363
+
364
+ if dtype in np_float_types:
365
+ for l in [v20, v21]:
366
+ tape.backward(loss=l)
367
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
368
+ tape.zero()
369
+
370
+ for l in [v30, v31, v32]:
371
+ tape.backward(loss=l)
372
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
373
+ tape.zero()
374
+
375
+ for l in [v40, v41, v42, v43]:
376
+ tape.backward(loss=l)
377
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
378
+ tape.zero()
379
+
380
+ for l in [v50, v51, v52, v53, v54]:
381
+ tape.backward(loss=l)
382
+ test.assertEqual(tape.gradients[input].numpy()[0], 2.0)
383
+ tape.zero()
384
+
385
+ val = input.numpy()[0]
386
+ assert_np_equal(v2.numpy()[0], np.array([val, val]), tol=1.0e-6)
387
+ assert_np_equal(v3.numpy()[0], np.array([val, val, val]), tol=1.0e-6)
388
+ assert_np_equal(v4.numpy()[0], np.array([val, val, val, val]), tol=1.0e-6)
389
+ assert_np_equal(v5.numpy()[0], np.array([val, val, val, val, val]), tol=1.0e-6)
390
+
391
+ assert_np_equal(v20.numpy()[0], 2 * val, tol=1.0e-6)
392
+ assert_np_equal(v21.numpy()[0], 2 * val, tol=1.0e-6)
393
+ assert_np_equal(v30.numpy()[0], 2 * val, tol=1.0e-6)
394
+ assert_np_equal(v31.numpy()[0], 2 * val, tol=1.0e-6)
395
+ assert_np_equal(v32.numpy()[0], 2 * val, tol=1.0e-6)
396
+ assert_np_equal(v40.numpy()[0], 2 * val, tol=1.0e-6)
397
+ assert_np_equal(v41.numpy()[0], 2 * val, tol=1.0e-6)
398
+ assert_np_equal(v42.numpy()[0], 2 * val, tol=1.0e-6)
399
+ assert_np_equal(v43.numpy()[0], 2 * val, tol=1.0e-6)
400
+ assert_np_equal(v50.numpy()[0], 2 * val, tol=1.0e-6)
401
+ assert_np_equal(v51.numpy()[0], 2 * val, tol=1.0e-6)
402
+ assert_np_equal(v52.numpy()[0], 2 * val, tol=1.0e-6)
403
+ assert_np_equal(v53.numpy()[0], 2 * val, tol=1.0e-6)
404
+ assert_np_equal(v54.numpy()[0], 2 * val, tol=1.0e-6)
405
+
406
+ input = wp.array(randvals(rng, [14], dtype), requires_grad=True, device=device)
407
+ tape = wp.Tape()
408
+ with tape:
409
+ wp.launch(
410
+ vec_kernel,
411
+ dim=1,
412
+ inputs=[input],
413
+ outputs=[v2, v3, v4, v5, v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
414
+ device=device,
415
+ )
416
+
417
+ if dtype in np_float_types:
418
+ for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
419
+ tape.backward(loss=l)
420
+ grad = tape.gradients[input].numpy()
421
+ expected_grad = np.zeros_like(grad)
422
+ expected_grad[i] = 2
423
+ assert_np_equal(grad, expected_grad, tol=tol)
424
+ tape.zero()
425
+
426
+ assert_np_equal(v2.numpy()[0, 0], input.numpy()[0], tol=tol)
427
+ assert_np_equal(v2.numpy()[0, 1], input.numpy()[1], tol=tol)
428
+ assert_np_equal(v3.numpy()[0, 0], input.numpy()[2], tol=tol)
429
+ assert_np_equal(v3.numpy()[0, 1], input.numpy()[3], tol=tol)
430
+ assert_np_equal(v3.numpy()[0, 2], input.numpy()[4], tol=tol)
431
+ assert_np_equal(v4.numpy()[0, 0], input.numpy()[5], tol=tol)
432
+ assert_np_equal(v4.numpy()[0, 1], input.numpy()[6], tol=tol)
433
+ assert_np_equal(v4.numpy()[0, 2], input.numpy()[7], tol=tol)
434
+ assert_np_equal(v4.numpy()[0, 3], input.numpy()[8], tol=tol)
435
+ assert_np_equal(v5.numpy()[0, 0], input.numpy()[9], tol=tol)
436
+ assert_np_equal(v5.numpy()[0, 1], input.numpy()[10], tol=tol)
437
+ assert_np_equal(v5.numpy()[0, 2], input.numpy()[11], tol=tol)
438
+ assert_np_equal(v5.numpy()[0, 3], input.numpy()[12], tol=tol)
439
+ assert_np_equal(v5.numpy()[0, 4], input.numpy()[13], tol=tol)
440
+
441
+ assert_np_equal(v20.numpy()[0], 2 * input.numpy()[0], tol=tol)
442
+ assert_np_equal(v21.numpy()[0], 2 * input.numpy()[1], tol=tol)
443
+ assert_np_equal(v30.numpy()[0], 2 * input.numpy()[2], tol=tol)
444
+ assert_np_equal(v31.numpy()[0], 2 * input.numpy()[3], tol=tol)
445
+ assert_np_equal(v32.numpy()[0], 2 * input.numpy()[4], tol=tol)
446
+ assert_np_equal(v40.numpy()[0], 2 * input.numpy()[5], tol=tol)
447
+ assert_np_equal(v41.numpy()[0], 2 * input.numpy()[6], tol=tol)
448
+ assert_np_equal(v42.numpy()[0], 2 * input.numpy()[7], tol=tol)
449
+ assert_np_equal(v43.numpy()[0], 2 * input.numpy()[8], tol=tol)
450
+ assert_np_equal(v50.numpy()[0], 2 * input.numpy()[9], tol=tol)
451
+ assert_np_equal(v51.numpy()[0], 2 * input.numpy()[10], tol=tol)
452
+ assert_np_equal(v52.numpy()[0], 2 * input.numpy()[11], tol=tol)
453
+ assert_np_equal(v53.numpy()[0], 2 * input.numpy()[12], tol=tol)
454
+ assert_np_equal(v54.numpy()[0], 2 * input.numpy()[13], tol=tol)
455
+
456
+
457
def test_anon_type_instance(test, device, dtype, register_kernels=False):
    """Exercise anonymous ``wp.vector(...)`` construction inside kernels.

    Two constructor forms are tested for vector lengths 2, 3, 4 and 5:

    * ``wp.vector(x, length=n)`` -- every component is the scalar ``x``;
    * ``wp.vector(a, b, ...)`` -- one argument per component.

    Each kernel writes ``2 * component`` into a flat output array. Forward
    values are compared against NumPy. For float dtypes, the gradient of every
    output element is also checked, one element at a time.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    def check_scalar_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], length=2)
        v3result = wp.vector(input[1], length=3)
        v4result = wp.vector(input[2], length=4)
        v5result = wp.vector(input[3], length=5)

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    def check_component_init(
        input: wp.array(dtype=wptype),
        output: wp.array(dtype=wptype),
    ):
        v2result = wp.vector(input[0], input[1])
        v3result = wp.vector(input[2], input[3], input[4])
        v4result = wp.vector(input[5], input[6], input[7], input[8])
        v5result = wp.vector(input[9], input[10], input[11], input[12], input[13])

        idx = 0
        for i in range(2):
            output[idx] = wptype(2) * v2result[i]
            idx = idx + 1
        for i in range(3):
            output[idx] = wptype(2) * v3result[i]
            idx = idx + 1
        for i in range(4):
            output[idx] = wptype(2) * v4result[i]
            idx = idx + 1
        for i in range(5):
            output[idx] = wptype(2) * v5result[i]
            idx = idx + 1

    scalar_kernel = getkernel(check_scalar_init, suffix=dtype.__name__)
    component_kernel = getkernel(check_component_init, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel(wptype)

    # Kernel registration happens above; the registration pass stops here.
    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    n_slots = sum(lengths)
    # Output slot -> index of the scalar input broadcast into it by check_scalar_init.
    broadcast_owner = [src for src, n in enumerate(lengths) for _ in range(n)]

    def sweep_gradients(init_kernel, inp, outp, owner):
        """Backpropagate each element of ``outp`` in turn.

        Each pass asserts that d(outp[k]) / d(inp) is 2 at ``inp[owner[k]]``
        and 0 elsewhere.
        """
        picked = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
        for k in range(len(outp)):
            tape = wp.Tape()
            with tape:
                wp.launch(init_kernel, dim=1, inputs=[inp], outputs=[outp], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[outp, k], outputs=[picked], device=device)

            tape.backward(loss=picked)
            expected = np.zeros_like(inp.numpy())
            expected[owner[k]] = 2

            assert_np_equal(tape.gradients[inp].numpy(), expected, tol=tol)

            tape.reset()
            tape.zero()

    # --- scalar-broadcast constructor ---
    inp = wp.array(randvals(rng, [4], dtype), requires_grad=True, device=device)
    outp = wp.zeros(n_slots, dtype=wptype, requires_grad=True, device=device)

    wp.launch(scalar_kernel, dim=1, inputs=[inp], outputs=[outp], device=device)

    start = 0
    for src, n in enumerate(lengths):
        stop = start + n
        assert_np_equal(outp.numpy()[start:stop], 2 * np.array([inp.numpy()[src]] * n), tol=1.0e-6)
        start = stop

    if dtype in np_float_types:
        sweep_gradients(scalar_kernel, inp, outp, broadcast_owner)

    # --- per-component constructor ---
    inp = wp.array(randvals(rng, [n_slots], dtype), requires_grad=True, device=device)
    outp = wp.zeros(n_slots, dtype=wptype, requires_grad=True, device=device)

    wp.launch(component_kernel, dim=1, inputs=[inp], outputs=[outp], device=device)

    assert_np_equal(outp.numpy(), 2 * inp.numpy(), tol=1.0e-6)

    if dtype in np_float_types:
        # Slot k is fed directly by input k.
        sweep_gradients(component_kernel, inp, outp, list(range(n_slots)))
578
+
579
+
580
def test_indexing(test, device, dtype, register_kernels=False):
    """Exercise component indexing ``v[0][k]`` on vector arrays of length 2 to 5.

    Each component is doubled into its own scalar output array. For float
    dtypes, each output is backpropagated separately, and the gradient must be
    2 on exactly the indexed component. The doubled forward values are then
    compared against NumPy.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_indexing(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2[0][0]
        v21[0] = wptype(2) * v2[0][1]

        v30[0] = wptype(2) * v3[0][0]
        v31[0] = wptype(2) * v3[0][1]
        v32[0] = wptype(2) * v3[0][2]

        v40[0] = wptype(2) * v4[0][0]
        v41[0] = wptype(2) * v4[0][1]
        v42[0] = wptype(2) * v4[0][2]
        v43[0] = wptype(2) * v4[0][3]

        v50[0] = wptype(2) * v5[0][0]
        v51[0] = wptype(2) * v5[0][1]
        v52[0] = wptype(2) * v5[0][2]
        v53[0] = wptype(2) * v5[0][3]
        v54[0] = wptype(2) * v5[0][4]

    kernel = getkernel(check_indexing, suffix=dtype.__name__)

    # Kernel registration happens above; the registration pass stops here.
    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    # Draw the random vectors in the order vec2, vec3, vec4, vec5 (the rng sequence depends on it).
    vec_inputs = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
        for vt, n in zip((vec2, vec3, vec4, vec5), lengths)
    ]
    # One scalar output per vector component, in the kernel's parameter order.
    scalar_outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=vec_inputs,
            outputs=scalar_outputs,
            device=device,
        )

    if dtype in np_float_types:
        for slot, loss in enumerate(scalar_outputs):
            tape.backward(loss=loss)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in vec_inputs])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[slot] = 2
            assert_np_equal(allgrads, expected_grads, tol=tol)
            tape.zero()

    # (vector array, component index) for each scalar output, in the same order.
    sources = [(arr, k) for arr, n in zip(vec_inputs, lengths) for k in range(n)]
    for dst, (arr, k) in zip(scalar_outputs, sources):
        assert_np_equal(dst.numpy()[0], 2.0 * arr.numpy()[0, k], tol=tol)
691
+
692
+
693
def test_equality(test, device, dtype, register_kernels=False):
    """Exercise vector ``==`` / ``!=`` through ``wp.expect_eq`` / ``wp.expect_neq``.

    The first kernel compares vectors with only non-negative components, so it
    runs for every dtype. The second compares each base vector against copies
    with one component negated. It is skipped for unsigned integer dtypes,
    which cannot hold negative values.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_unsigned_equality(
        v20: wp.array(dtype=vec2),
        v21: wp.array(dtype=vec2),
        v22: wp.array(dtype=vec2),
        v30: wp.array(dtype=vec3),
        v40: wp.array(dtype=vec4),
        v50: wp.array(dtype=vec5),
    ):
        wp.expect_eq(v20[0], v20[0])
        wp.expect_neq(v21[0], v20[0])
        wp.expect_neq(v22[0], v20[0])
        wp.expect_eq(v30[0], v30[0])
        wp.expect_eq(v40[0], v40[0])
        wp.expect_eq(v50[0], v50[0])

    def check_signed_equality(
        v30: wp.array(dtype=vec3),
        v31: wp.array(dtype=vec3),
        v32: wp.array(dtype=vec3),
        v33: wp.array(dtype=vec3),
        v40: wp.array(dtype=vec4),
        v41: wp.array(dtype=vec4),
        v42: wp.array(dtype=vec4),
        v43: wp.array(dtype=vec4),
        v44: wp.array(dtype=vec4),
        v50: wp.array(dtype=vec5),
        v51: wp.array(dtype=vec5),
        v52: wp.array(dtype=vec5),
        v53: wp.array(dtype=vec5),
        v54: wp.array(dtype=vec5),
        v55: wp.array(dtype=vec5),
    ):
        wp.expect_neq(v31[0], v30[0])
        wp.expect_neq(v32[0], v30[0])
        wp.expect_neq(v33[0], v30[0])
        wp.expect_neq(v41[0], v40[0])
        wp.expect_neq(v42[0], v40[0])
        wp.expect_neq(v43[0], v40[0])
        wp.expect_neq(v44[0], v40[0])
        wp.expect_neq(v51[0], v50[0])
        wp.expect_neq(v52[0], v50[0])
        wp.expect_neq(v53[0], v50[0])
        wp.expect_neq(v54[0], v50[0])
        wp.expect_neq(v55[0], v50[0])

    unsigned_kernel = getkernel(check_unsigned_equality, suffix=dtype.__name__)
    signed_kernel = getkernel(check_signed_equality, suffix=dtype.__name__)

    # Kernel registration happens above; the registration pass stops here.
    if register_kernels:
        return

    def make(values, vec_type):
        """Single-element array of ``vec_type`` holding ``values``."""
        return wp.array(values, dtype=vec_type, requires_grad=True, device=device)

    base3 = [1.0, 2.0, 3.0]
    base4 = [1.0, 2.0, 3.0, 4.0]
    base5 = [1.0, 2.0, 3.0, 4.0, 5.0]

    v30 = make(base3, vec3)
    v40 = make(base4, vec4)
    v50 = make(base5, vec5)

    unsigned_inputs = [
        make([1.0, 2.0], vec2),
        make([1.0, 3.0], vec2),
        make([3.0, 2.0], vec2),
        v30,
        v40,
        v50,
    ]
    wp.launch(unsigned_kernel, dim=1, inputs=unsigned_inputs, outputs=[], device=device)

    # Negative components are not representable for unsigned types.
    if dtype in np_unsigned_int_types:
        return

    def sign_flipped(base, vec_type):
        """One array per component of ``base``, with only that component negated."""
        return [
            make([-x if j == k else x for j, x in enumerate(base)], vec_type)
            for k in range(len(base))
        ]

    signed_inputs = [
        v30,
        *sign_flipped(base3, vec3),
        v40,
        *sign_flipped(base4, vec4),
        v50,
        *sign_flipped(base5, vec5),
    ]
    wp.launch(signed_kernel, dim=1, inputs=signed_inputs, outputs=[], device=device)
813
+
814
+
815
def test_scalar_multiplication(test, device, dtype, register_kernels=False):
    """Exercise left scalar multiplication ``s * v`` for vector lengths 2 to 5.

    Each component of ``s[0] * vN[0]`` is doubled into its own scalar output.
    Forward values are compared against NumPy. For float dtypes, every output
    (including the length-5 ones) is backpropagated separately. The test checks
    both d/ds (= 2 * component) and d/dv (= 2 * s on the matching component
    only).
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_mul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = s[0] * v2[0]
        v3result = s[0] * v3[0]
        v4result = s[0] * v4[0]
        v5result = s[0] * v5[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_mul, suffix=dtype.__name__)

    # Kernel registration happens above; the registration pass stops here.
    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[
                s,
                v2,
                v3,
                v4,
                v5,
            ],
            outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
            device=device,
        )

    assert_np_equal(v20.numpy()[0], 2 * s.numpy()[0] * v2.numpy()[0, 0], tol=tol)
    assert_np_equal(v21.numpy()[0], 2 * s.numpy()[0] * v2.numpy()[0, 1], tol=tol)

    assert_np_equal(v30.numpy()[0], 2 * s.numpy()[0] * v3.numpy()[0, 0], tol=10 * tol)
    assert_np_equal(v31.numpy()[0], 2 * s.numpy()[0] * v3.numpy()[0, 1], tol=10 * tol)
    assert_np_equal(v32.numpy()[0], 2 * s.numpy()[0] * v3.numpy()[0, 2], tol=10 * tol)

    assert_np_equal(v40.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 0], tol=10 * tol)
    assert_np_equal(v41.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 1], tol=10 * tol)
    assert_np_equal(v42.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 2], tol=10 * tol)
    assert_np_equal(v43.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 3], tol=10 * tol)

    assert_np_equal(v50.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 0], tol=10 * tol)
    assert_np_equal(v51.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 1], tol=10 * tol)
    assert_np_equal(v52.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 2], tol=10 * tol)
    assert_np_equal(v53.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 3], tol=10 * tol)
    assert_np_equal(v54.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 4], tol=10 * tol)

    # Flattened input components; index i lines up with the i-th scalar output below.
    incmps = np.concatenate([v.numpy()[0] for v in [v2, v3, v4, v5]])

    if dtype in np_float_types:
        # Cover every output, including the length-5 vector. The input adjoints
        # are concatenated over all four vectors so position i matches incmps[i].
        outputs = [v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s.numpy()[0] * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
945
+
946
+
947
def test_scalar_multiplication_rightmul(test, device, dtype, register_kernels=False):
    """Exercise right scalar multiplication ``v * s`` for vector lengths 2 to 5.

    Mirrors ``test_scalar_multiplication`` with the operands swapped. Each
    component of ``vN[0] * s[0]`` is doubled into its own scalar output. Forward
    values are compared against NumPy. For float dtypes, every output
    (including the length-5 ones) is backpropagated separately. The test checks
    both d/ds (= 2 * component) and d/dv (= 2 * s on the matching component
    only).
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_rightmul(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] * s[0]
        v3result = v3[0] * s[0]
        v4result = v4[0] * s[0]
        v5result = v5[0] * s[0]

        # multiply outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_rightmul, suffix=dtype.__name__)

    # Kernel registration happens above; the registration pass stops here.
    if register_kernels:
        return

    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
    v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
    v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
    v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
    v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=[
                s,
                v2,
                v3,
                v4,
                v5,
            ],
            outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
            device=device,
        )

    assert_np_equal(v20.numpy()[0], 2 * s.numpy()[0] * v2.numpy()[0, 0], tol=tol)
    assert_np_equal(v21.numpy()[0], 2 * s.numpy()[0] * v2.numpy()[0, 1], tol=tol)

    assert_np_equal(v30.numpy()[0], 2 * s.numpy()[0] * v3.numpy()[0, 0], tol=10 * tol)
    assert_np_equal(v31.numpy()[0], 2 * s.numpy()[0] * v3.numpy()[0, 1], tol=10 * tol)
    assert_np_equal(v32.numpy()[0], 2 * s.numpy()[0] * v3.numpy()[0, 2], tol=10 * tol)

    assert_np_equal(v40.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 0], tol=10 * tol)
    assert_np_equal(v41.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 1], tol=10 * tol)
    assert_np_equal(v42.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 2], tol=10 * tol)
    assert_np_equal(v43.numpy()[0], 2 * s.numpy()[0] * v4.numpy()[0, 3], tol=10 * tol)

    assert_np_equal(v50.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 0], tol=10 * tol)
    assert_np_equal(v51.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 1], tol=10 * tol)
    assert_np_equal(v52.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 2], tol=10 * tol)
    assert_np_equal(v53.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 3], tol=10 * tol)
    assert_np_equal(v54.numpy()[0], 2 * s.numpy()[0] * v5.numpy()[0, 4], tol=10 * tol)

    # Flattened input components; index i lines up with the i-th scalar output below.
    incmps = np.concatenate([v.numpy()[0] for v in [v2, v3, v4, v5]])

    if dtype in np_float_types:
        # Cover every output, including the length-5 vector. The input adjoints
        # are concatenated over all four vectors so position i matches incmps[i].
        outputs = [v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]
        for i, loss in enumerate(outputs):
            tape.backward(loss=loss)
            sgrad = tape.gradients[s].numpy()[0]
            assert_np_equal(sgrad, 2 * incmps[i], tol=10 * tol)
            allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
            expected_grads = np.zeros_like(allgrads)
            expected_grads[i] = s.numpy()[0] * 2
            assert_np_equal(allgrads, expected_grads, tol=10 * tol)
            tape.zero()
1077
+
1078
+
1079
def test_cw_multiplication(test, device, dtype, register_kernels=False):
    """Check component-wise multiplication (``wp.cw_mul``) for vector lengths 2-5.

    Every component of ``2 * cw_mul(s, v)`` is written to its own scalar output
    array, so forward values and per-component gradients with respect to both
    operands can be verified independently.

    Args:
        test: Test-case instance (not referenced in this function).
        device: Warp device the kernel is launched on.
        dtype: NumPy scalar type used for all operands.
        register_kernels: When True, only build the kernel via ``getkernel`` and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_cw_mul(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_mul(s2[0], v2[0])
        v3result = wp.cw_mul(s3[0], v3[0])
        v4result = wp.cw_mul(s4[0], v4[0])
        v5result = wp.cw_mul(s5[0], v5[0])

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_mul, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # One single-element vector array per length, drawn in order 2, 3, 4, 5.
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in zip(lengths, vec_types)
        ]

    # All `s` operands are drawn before all `v` operands; the RNG order matters.
    s_vecs = random_vectors()
    v_vecs = random_vectors()

    # (vector slot, component) behind each scalar output, in kernel argument
    # order: v20, v21, v30, v31, v32, v40, ..., v54.
    slots = [(k, c) for k, n in enumerate(lengths) for c in range(n)]
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in slots]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_vecs, *v_vecs], outputs=outputs, device=device)

    for out, (k, c) in zip(outputs, slots):
        expected = 2 * s_vecs[k].numpy()[0, c] * v_vecs[k].numpy()[0, c]
        assert_np_equal(out.numpy()[0], expected, tol=10 * tol)

    v_flat = np.concatenate([arr.numpy()[0] for arr in v_vecs])
    s_flat = np.concatenate([arr.numpy()[0] for arr in s_vecs])

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        # d(2 * s * v)/ds = 2 * v, nonzero only at the component behind this loss.
        s_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in s_vecs])
        expected_grads = np.zeros_like(s_grads)
        expected_grads[i] = v_flat[i] * 2
        assert_np_equal(s_grads, expected_grads, tol=10 * tol)

        # d(2 * s * v)/dv = 2 * s, nonzero only at the same component.
        v_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in v_vecs])
        expected_grads = np.zeros_like(v_grads)
        expected_grads[i] = s_flat[i] * 2
        assert_np_equal(v_grads, expected_grads, tol=10 * tol)

        tape.zero()
1222
+
1223
+
1224
def test_scalar_division(test, device, dtype, register_kernels=False):
    """Check vector-by-scalar division (``v / s``) for vector lengths 2-5.

    Every component of ``2 * (v / s)`` is written to its own scalar output.
    Integer dtypes are compared against floor division, and float dtypes
    against true division. For float dtypes, gradients with respect to the
    scalar and the vectors are also checked.

    Args:
        test: Test-case instance (not referenced in this function).
        device: Warp device the kernel is launched on.
        dtype: NumPy scalar type used for all operands.
        register_kernels: When True, only build the kernel via ``getkernel`` and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_div(
        s: wp.array(dtype=wptype),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] / s[0]
        v3result = v3[0] / s[0]
        v4result = v4[0] / s[0]
        v5result = v5[0] / s[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_div, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    # The scalar divisor is drawn first, then one vector per length (2, 3, 4, 5).
    s = wp.array(randvals(rng, [1], dtype), requires_grad=True, device=device)
    v_vecs = [
        wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
        for n, vt in zip(lengths, vec_types)
    ]

    # (vector slot, component) behind each scalar output, in kernel argument
    # order: v20, v21, v30, v31, v32, v40, ..., v54.
    slots = [(k, c) for k, n in enumerate(lengths) for c in range(n)]
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in slots]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[s, *v_vecs], outputs=outputs, device=device)

    s_val = s.numpy()[0]
    is_int = dtype in np_int_types
    for out, (k, c) in zip(outputs, slots):
        v_val = v_vecs[k].numpy()[0, c]
        if is_int:
            expected = 2 * (v_val // s_val)
        else:
            expected = 2 * v_val / s_val
        # Length-2 outputs use the base tolerance; longer vectors a looser one.
        comp_tol = tol if k == 0 else 10 * tol
        assert_np_equal(out.numpy()[0], expected, tol=comp_tol)

    v_flat = np.concatenate([arr.numpy()[0] for arr in v_vecs])

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        # d/ds (2 * v / s) = -2 * v / s^2
        s_grad = tape.gradients[s].numpy()[0]
        assert_np_equal(s_grad, -2 * v_flat[i] / (s_val * s_val), tol=10 * tol)

        # d/dv (2 * v / s) = 2 / s, nonzero only at the component behind this loss.
        v_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in v_vecs])
        expected_grads = np.zeros_like(v_grads)
        expected_grads[i] = 2 / s_val
        assert_np_equal(v_grads, expected_grads, tol=tol)

        tape.zero()
1378
+
1379
+
1380
def test_cw_division(test, device, dtype, register_kernels=False):
    """Check component-wise division (``wp.cw_div``) for vector lengths 2-5.

    Every component of ``2 * cw_div(v, s)`` is written to its own scalar output.
    Integer dtypes are compared against floor division, and float dtypes
    against true division. For float dtypes, per-component gradients with
    respect to both operands are also checked.

    Args:
        test: Test-case instance (not referenced in this function).
        device: Warp device the kernel is launched on.
        dtype: NumPy scalar type used for all operands.
        register_kernels: When True, only build the kernel via ``getkernel`` and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_cw_div(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = wp.cw_div(v2[0], s2[0])
        v3result = wp.cw_div(v3[0], s3[0])
        v4result = wp.cw_div(v4[0], s4[0])
        v5result = wp.cw_div(v5[0], s5[0])

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_cw_div, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # One single-element vector array per length, drawn in order 2, 3, 4, 5.
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in zip(lengths, vec_types)
        ]

    # Divisors (`s`) are drawn before numerators (`v`); the RNG order matters.
    s_vecs = random_vectors()
    v_vecs = random_vectors()

    # (vector slot, component) behind each scalar output, in kernel argument
    # order: v20, v21, v30, v31, v32, v40, ..., v54.
    slots = [(k, c) for k, n in enumerate(lengths) for c in range(n)]
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in slots]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_vecs, *v_vecs], outputs=outputs, device=device)

    is_int = dtype in np_int_types
    for out, (k, c) in zip(outputs, slots):
        numer = v_vecs[k].numpy()[0, c]
        denom = s_vecs[k].numpy()[0, c]
        expected = 2 * (numer // denom) if is_int else 2 * numer / denom
        assert_np_equal(out.numpy()[0], expected, tol=tol)

    if dtype not in np_float_types:
        return

    v_flat = np.concatenate([arr.numpy()[0] for arr in v_vecs])
    s_flat = np.concatenate([arr.numpy()[0] for arr in s_vecs])

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        # d/ds (2 * v / s) = -2 * v / s^2, nonzero only at this loss's component.
        s_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in s_vecs])
        expected_grads = np.zeros_like(s_grads)
        expected_grads[i] = -v_flat[i] * 2 / (s_flat[i] * s_flat[i])
        assert_np_equal(s_grads, expected_grads, tol=20 * tol)

        # d/dv (2 * v / s) = 2 / s
        v_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in v_vecs])
        expected_grads = np.zeros_like(v_grads)
        expected_grads[i] = 2 / s_flat[i]
        assert_np_equal(v_grads, expected_grads, tol=tol)

        tape.zero()
1546
+
1547
+
1548
def test_addition(test, device, dtype, register_kernels=False):
    """Check vector addition (``v + s``) for vector lengths 2-5.

    Every component of ``2 * (v + s)`` is written to its own scalar output.
    For float dtypes, the gradient of each output with respect to both
    operands is checked; it is 2 at the matching component and 0 elsewhere.

    Args:
        test: Test-case instance (not referenced in this function).
        device: Warp device the kernel is launched on.
        dtype: NumPy scalar type used for all operands.
        register_kernels: When True, only build the kernel via ``getkernel`` and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_add(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = v2[0] + s2[0]
        v3result = v3[0] + s3[0]
        v4result = v4[0] + s4[0]
        v5result = v5[0] + s5[0]

        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_add, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # One single-element vector array per length, drawn in order 2, 3, 4, 5.
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in zip(lengths, vec_types)
        ]

    # All `s` operands are drawn before all `v` operands; the RNG order matters.
    s_vecs = random_vectors()
    v_vecs = random_vectors()

    # (vector slot, component) behind each scalar output, in kernel argument
    # order: v20, v21, v30, v31, v32, v40, ..., v54.
    slots = [(k, c) for k, n in enumerate(lengths) for c in range(n)]
    outputs = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in slots]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_vecs, *v_vecs], outputs=outputs, device=device)

    last_slot = slots[-1]
    for out, (k, c) in zip(outputs, slots):
        expected = 2 * (v_vecs[k].numpy()[0, c] + s_vecs[k].numpy()[0, c])
        # The final output (v54) is checked with a slightly looser tolerance.
        comp_tol = 2 * tol if (k, c) == last_slot else tol
        assert_np_equal(out.numpy()[0], expected, tol=comp_tol)

    if dtype not in np_float_types:
        return

    for i, loss in enumerate(outputs):
        tape.backward(loss=loss)

        # d/ds and d/dv of 2 * (v + s) are both 2, only at this loss's component.
        s_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in s_vecs])
        expected_grads = np.zeros_like(s_grads)
        expected_grads[i] = 2
        assert_np_equal(s_grads, expected_grads, tol=10 * tol)

        v_grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in v_vecs])
        assert_np_equal(v_grads, expected_grads, tol=tol)

        tape.zero()
1687
+
1688
+
1689
def test_dotproduct(test, device, dtype, register_kernels=False):
    """Check ``wp.dot`` for vector lengths 2-5, including gradients for float dtypes.

    Each kernel output holds ``2 * dot(v, s)`` for one vector length. Its
    gradient is ``2 * v`` with respect to ``s`` and ``2 * s`` with respect to ``v``.

    Args:
        test: Test-case instance (not referenced in this function).
        device: Warp device the kernel is launched on.
        dtype: NumPy scalar type used for all operands.
        register_kernels: When True, only build the kernel via ``getkernel`` and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_dot(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        dot2: wp.array(dtype=wptype),
        dot3: wp.array(dtype=wptype),
        dot4: wp.array(dtype=wptype),
        dot5: wp.array(dtype=wptype),
    ):
        dot2[0] = wptype(2) * wp.dot(v2[0], s2[0])
        dot3[0] = wptype(2) * wp.dot(v3[0], s3[0])
        dot4[0] = wptype(2) * wp.dot(v4[0], s4[0])
        dot5[0] = wptype(2) * wp.dot(v5[0], s5[0])

    kernel = getkernel(check_dot, suffix=dtype.__name__)

    if register_kernels:
        return

    lengths = (2, 3, 4, 5)
    vec_types = (vec2, vec3, vec4, vec5)

    def random_vectors():
        # One single-element vector array per length, drawn in order 2, 3, 4, 5.
        return [
            wp.array(randvals(rng, (1, n), dtype), dtype=vt, requires_grad=True, device=device)
            for n, vt in zip(lengths, vec_types)
        ]

    # All `s` operands are drawn before all `v` operands; the RNG order matters.
    s_vecs = random_vectors()
    v_vecs = random_vectors()
    dots = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in lengths]

    tape = wp.Tape()
    with tape:
        wp.launch(kernel, dim=1, inputs=[*s_vecs, *v_vecs], outputs=dots, device=device)

    for dot, s_arr, v_arr in zip(dots, s_vecs, v_vecs):
        assert_np_equal(dot.numpy()[0], 2.0 * (v_arr.numpy() * s_arr.numpy()).sum(), tol=10 * tol)

    if dtype not in np_float_types:
        return

    last = len(dots) - 1
    for k, (dot, s_arr, v_arr) in enumerate(zip(dots, s_vecs, v_vecs)):
        tape.backward(loss=dot)

        # d/ds (2 * dot(v, s)) = 2 * v
        s_grads = tape.gradients[s_arr].numpy()[0]
        assert_np_equal(s_grads, 2.0 * v_arr.numpy()[0], tol=10 * tol)

        # d/dv (2 * dot(v, s)) = 2 * s; the length-5 case uses a looser tolerance.
        v_grads = tape.gradients[v_arr].numpy()[0]
        v_tol = 10 * tol if k == last else tol
        assert_np_equal(v_grads, 2.0 * s_arr.numpy()[0], tol=v_tol)

        tape.zero()
1808
+
1809
+
1810
def test_modulo(test, device, dtype, register_kernels=False):
    """Check component-wise ``wp.mod`` on vectors of length 2-5 against ``np.fmod``.

    The kernel computes ``2 * wp.mod(v, s)`` for one vec2/vec3/vec4/vec5 pair each
    and scatters every component into its own single-element output array. The
    host then compares each component with ``2 * np.fmod(v, s)``.

    Args:
        test: The ``unittest.TestCase`` instance (unused here).
        device: Warp device to run on.
        dtype: NumPy scalar type used for all vector components.
        register_kernels: When True, only build/register the kernel and return.
    """
    rng = np.random.default_rng(123)

    tol = {
        np.float16: 1.0e-2,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    def check_mod(
        s2: wp.array(dtype=vec2),
        s3: wp.array(dtype=vec3),
        s4: wp.array(dtype=vec4),
        s5: wp.array(dtype=vec5),
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        # Evaluate each scaled modulo once, then scatter its components, rather
        # than re-evaluating the same expression for every component.
        m2 = wptype(2) * wp.mod(v2[0], s2[0])
        v20[0] = m2[0]
        v21[0] = m2[1]

        m3 = wptype(2) * wp.mod(v3[0], s3[0])
        v30[0] = m3[0]
        v31[0] = m3[1]
        v32[0] = m3[2]

        m4 = wptype(2) * wp.mod(v4[0], s4[0])
        v40[0] = m4[0]
        v41[0] = m4[1]
        v42[0] = m4[2]
        v43[0] = m4[3]

        m5 = wptype(2) * wp.mod(v5[0], s5[0])
        v50[0] = m5[0]
        v51[0] = m5[1]
        v52[0] = m5[2]
        v53[0] = m5[3]
        v54[0] = m5[4]

    kernel = getkernel(check_mod, suffix=dtype.__name__)

    if register_kernels:
        return

    def random_vec_array(vec_type, length):
        # A single random vector of the given length, drawn from the seeded rng.
        return wp.array(randvals(rng, (1, length), dtype), dtype=vec_type, requires_grad=True, device=device)

    # Keep the draw order (all divisors s*, then all dividends v*) so the
    # seeded rng produces the same inputs as before.
    s2 = random_vec_array(vec2, 2)
    s3 = random_vec_array(vec3, 3)
    s4 = random_vec_array(vec4, 4)
    s5 = random_vec_array(vec5, 5)
    v2 = random_vec_array(vec2, 2)
    v3 = random_vec_array(vec3, 3)
    v4 = random_vec_array(vec4, 4)
    v5 = random_vec_array(vec5, 5)

    # One single-element output per vector component, in kernel-argument order
    # (v20, v21, v30, ..., v54).
    out2, out3, out4, out5 = (
        [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(length)]
        for length in (2, 3, 4, 5)
    )

    wp.launch(
        kernel,
        dim=1,
        inputs=[s2, s3, s4, s5, v2, v3, v4, v5],
        outputs=[*out2, *out3, *out4, *out5],
        device=device,
    )

    for dividend, divisor, component_outputs in (
        (v2, s2, out2),
        (v3, s3, out3),
        (v4, s4, out4),
        (v5, s5, out5),
    ):
        # NumPy reference computed once per vector instead of once per component.
        expected = 2.0 * np.fmod(dividend.numpy(), divisor.numpy())[0]
        for k, out in enumerate(component_outputs):
            assert_np_equal(out.numpy()[0], expected[k], tol=10 * tol)
1942
+
1943
+
1944
def test_equivalent_types(test, device, dtype, register_kernels=False):
    """Verify that separately constructed but identical vector types are interchangeable.

    The kernel is declared against one family of ``wp.types.vector`` types and
    launched with values built from a second family that has the same lengths
    and scalar type. Inside the kernel, values of both families must compare
    equal.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    # Original vector types, used in the kernel signature.
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    # A second, independently created family with the same layout.
    vec2_equiv, vec3_equiv, vec4_equiv, vec5_equiv = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_equivalence(
        v2: vec2,
        v3: vec3,
        v4: vec4,
        v5: vec5,
    ):
        # Compare against the declared types...
        wp.expect_eq(v2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

        # ...and against the equivalent types.
        wp.expect_eq(v2, vec2_equiv(wptype(1), wptype(2)))
        wp.expect_eq(v3, vec3_equiv(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(v4, vec4_equiv(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(v5, vec5_equiv(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_equivalence, suffix=dtype.__name__)

    if register_kernels:
        return

    # Launch with arguments built from the *equivalent* types, not the declared ones.
    wp.launch(
        kernel,
        dim=1,
        inputs=[
            vec2_equiv(1, 2),
            vec3_equiv(1, 2, 3),
            vec4_equiv(1, 2, 3, 4),
            vec5_equiv(1, 2, 3, 4, 5),
        ],
        device=device,
    )
1988
+
1989
+
1990
def test_conversions(test, device, dtype, register_kernels=False):
    """Check that ``wp.vec3`` values can come from tuples, lists and NumPy arrays.

    The conversion is exercised twice: explicitly, by constructing ``wp.vec3``
    from each container, and implicitly, by passing the raw containers to
    ``wp.launch``. The kernel asserts that every argument equals ``(1, 2, 3)``.
    """

    def check_vectors_equal(
        v0: wp.vec3,
        v1: wp.vec3,
        v2: wp.vec3,
        v3: wp.vec3,
    ):
        wp.expect_eq(v1, v0)
        wp.expect_eq(v2, v0)
        wp.expect_eq(v3, v0)

    kernel = getkernel(check_vectors_equal, suffix=dtype.__name__)

    if register_kernels:
        return

    reference = wp.vec3(1, 2, 3)

    def make_containers():
        # Fresh tuple / list / ndarray holding the reference components.
        return ((1, 2, 3), [1, 2, 3], np.array([1, 2, 3], dtype=dtype))

    # Explicit conversion: build wp.vec3 from each container type.
    explicit_args = [wp.vec3(container) for container in make_containers()]
    wp.launch(kernel, dim=1, inputs=[reference, *explicit_args], device=device)

    # Implicit conversion: let wp.launch() convert the raw containers.
    wp.launch(kernel, dim=1, inputs=[reference, *make_containers()], device=device)
2021
+
2022
+
2023
def test_constants(test, device, dtype, register_kernels=False):
    """Check that vector values wrapped in ``wp.constant`` can be read inside a kernel."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    # Constants captured by the kernel closure below.
    cv2 = wp.constant(vec2(1, 2))
    cv3 = wp.constant(vec3(1, 2, 3))
    cv4 = wp.constant(vec4(1, 2, 3, 4))
    cv5 = wp.constant(vec5(1, 2, 3, 4, 5))

    def check_vector_constants():
        wp.expect_eq(cv2, vec2(wptype(1), wptype(2)))
        wp.expect_eq(cv3, vec3(wptype(1), wptype(2), wptype(3)))
        wp.expect_eq(cv4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))
        wp.expect_eq(cv5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_vector_constants, suffix=dtype.__name__)

    if not register_kernels:
        # The kernel takes no arguments; all checks run against the captured constants.
        wp.launch(kernel, dim=1, inputs=[], device=device)
2047
+
2048
+
2049
def test_abs(test, device, dtype, register_kernels=False):
    """Check that ``wp.abs`` on vectors of length 2-5 returns the component-wise absolute value."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_vector_abs():
        # Inputs mix negative and positive components; every expected component is positive.
        abs2 = wp.abs(vec2(wptype(-1), wptype(2)))
        wp.expect_eq(abs2, vec2(wptype(1), wptype(2)))

        abs3 = wp.abs(vec3(wptype(1), wptype(-2), wptype(3)))
        wp.expect_eq(abs3, vec3(wptype(1), wptype(2), wptype(3)))

        abs4 = wp.abs(vec4(wptype(-1), wptype(2), wptype(3), wptype(-4)))
        wp.expect_eq(abs4, vec4(wptype(1), wptype(2), wptype(3), wptype(4)))

        abs5 = wp.abs(vec5(wptype(-1), wptype(2), wptype(-3), wptype(4), wptype(-5)))
        wp.expect_eq(abs5, vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(5)))

    kernel = getkernel(check_vector_abs, suffix=dtype.__name__)

    if not register_kernels:
        wp.launch(kernel, dim=1, inputs=[], device=device)
2075
+
2076
+
2077
def test_sign(test, device, dtype, register_kernels=False):
    """Check that ``wp.sign`` on vectors of length 2-5 maps each component to -1 or 1."""
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_vector_sign():
        # Negative components map to -1, positive ones to 1.
        sign2 = wp.sign(vec2(wptype(-1), wptype(2)))
        wp.expect_eq(sign2, vec2(wptype(-1), wptype(1)))

        sign3 = wp.sign(vec3(wptype(1), wptype(-2), wptype(3)))
        wp.expect_eq(sign3, vec3(wptype(1), wptype(-1), wptype(1)))

        sign4 = wp.sign(vec4(wptype(-1), wptype(2), wptype(3), wptype(-4)))
        wp.expect_eq(sign4, vec4(wptype(-1), wptype(1), wptype(1), wptype(-1)))

        sign5 = wp.sign(vec5(wptype(-1), wptype(2), wptype(-3), wptype(4), wptype(-5)))
        wp.expect_eq(sign5, vec5(wptype(-1), wptype(1), wptype(-1), wptype(1), wptype(-1)))

    kernel = getkernel(check_vector_sign, suffix=dtype.__name__)

    if not register_kernels:
        wp.launch(kernel, dim=1, inputs=[], device=device)
2103
+
2104
+
2105
def test_minmax(test, device, dtype, register_kernels=False):
    """Check component-wise ``wp.min``/``wp.max`` on vectors of length 2-5 and their gradients.

    Each of the 10 rows of the (10, 14) inputs ``a`` and ``b`` is split into a
    vec2, vec3, vec4 and vec5 (2 + 3 + 4 + 5 = 14 columns). The kernel writes
    ``2 * min`` and ``2 * max`` into the matching columns of ``mins`` and
    ``maxs``. For float dtypes, every output element is back-propagated on its
    own, and the resulting gradients of ``a`` and ``b`` are checked.
    """
    rng = np.random.default_rng(123)

    # TODO: not quite sure why, but the numbers are off for 16-bit float on the
    # CPU (but not CUDA). This is probably just the sketchy float16 arithmetic
    # implemented to get all this working; hopefully it can be fixed once that
    # is done correctly.
    tol = {
        np.float16: 1.0e-2,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2 = wp.types.vector(length=2, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec5 = wp.types.vector(length=5, dtype=wptype)

    # TODO: also not quite sure why, but this kernel compiles incredibly slowly.
    def check_vec_min_max(
        a: wp.array(dtype=wptype, ndim=2),
        b: wp.array(dtype=wptype, ndim=2),
        mins: wp.array(dtype=wptype, ndim=2),
        maxs: wp.array(dtype=wptype, ndim=2),
    ):
        for i in range(10):
            # multiplying by 2 so we've got something to backpropagate:
            a2read = vec2(a[i, 0], a[i, 1])
            b2read = vec2(b[i, 0], b[i, 1])
            c2 = wptype(2) * wp.min(a2read, b2read)
            d2 = wptype(2) * wp.max(a2read, b2read)

            a3read = vec3(a[i, 2], a[i, 3], a[i, 4])
            b3read = vec3(b[i, 2], b[i, 3], b[i, 4])
            c3 = wptype(2) * wp.min(a3read, b3read)
            d3 = wptype(2) * wp.max(a3read, b3read)

            a4read = vec4(a[i, 5], a[i, 6], a[i, 7], a[i, 8])
            b4read = vec4(b[i, 5], b[i, 6], b[i, 7], b[i, 8])
            c4 = wptype(2) * wp.min(a4read, b4read)
            d4 = wptype(2) * wp.max(a4read, b4read)

            a5read = vec5(a[i, 9], a[i, 10], a[i, 11], a[i, 12], a[i, 13])
            b5read = vec5(b[i, 9], b[i, 10], b[i, 11], b[i, 12], b[i, 13])
            c5 = wptype(2) * wp.min(a5read, b5read)
            d5 = wptype(2) * wp.max(a5read, b5read)

            mins[i, 0] = c2[0]
            mins[i, 1] = c2[1]

            mins[i, 2] = c3[0]
            mins[i, 3] = c3[1]
            mins[i, 4] = c3[2]

            mins[i, 5] = c4[0]
            mins[i, 6] = c4[1]
            mins[i, 7] = c4[2]
            mins[i, 8] = c4[3]

            mins[i, 9] = c5[0]
            mins[i, 10] = c5[1]
            mins[i, 11] = c5[2]
            mins[i, 12] = c5[3]
            mins[i, 13] = c5[4]

            maxs[i, 0] = d2[0]
            maxs[i, 1] = d2[1]

            maxs[i, 2] = d3[0]
            maxs[i, 3] = d3[1]
            maxs[i, 4] = d3[2]

            maxs[i, 5] = d4[0]
            maxs[i, 6] = d4[1]
            maxs[i, 7] = d4[2]
            maxs[i, 8] = d4[3]

            maxs[i, 9] = d5[0]
            maxs[i, 10] = d5[1]
            maxs[i, 11] = d5[2]
            maxs[i, 12] = d5[3]
            maxs[i, 13] = d5[4]

    kernel = getkernel(check_vec_min_max, suffix=dtype.__name__)
    output_select_kernel = get_select_kernel2(wptype)

    if register_kernels:
        return

    a = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)
    b = wp.array(randvals(rng, (10, 14), dtype), dtype=wptype, requires_grad=True, device=device)

    mins = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)
    maxs = wp.zeros((10, 14), dtype=wptype, requires_grad=True, device=device)

    # Forward check only; no gradients are needed here, so no tape is recorded.
    wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)

    # a and b are only passed as kernel inputs below, so read their host values
    # once instead of copying them back on every loop iteration.
    a_np = a.numpy()
    b_np = b.numpy()

    assert_np_equal(mins.numpy(), 2 * np.minimum(a_np, b_np), tol=tol)
    assert_np_equal(maxs.numpy(), 2 * np.maximum(a_np, b_np), tol=tol)

    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)

    def check_selected_grads(selected, i, j, grad_a, grad_b):
        # Back-propagate selected[i, j] alone. The gradients of a and b must be
        # zero everywhere except at (i, j), where they equal grad_a / grad_b.
        tape = wp.Tape()
        with tape:
            wp.launch(kernel, dim=1, inputs=[a, b], outputs=[mins, maxs], device=device)
            wp.launch(output_select_kernel, dim=1, inputs=[selected, i, j], outputs=[out], device=device)

        tape.backward(loss=out)
        expected = np.zeros_like(a_np)
        expected[i, j] = grad_a
        assert_np_equal(tape.gradients[a].numpy(), expected, tol=tol)
        expected[i, j] = grad_b
        assert_np_equal(tape.gradients[b].numpy(), expected, tol=tol)
        tape.zero()

    if dtype in np_float_types:
        for i in range(10):
            for j in range(14):
                a_ij = a_np[i, j]
                b_ij = b_np[i, j]
                # d(2 * min)/d(input) is 2 for the strictly smaller input, else 0.
                check_selected_grads(mins, i, j, 2 if a_ij < b_ij else 0, 2 if b_ij < a_ij else 0)
                # d(2 * max)/d(input) is 2 for the strictly larger input, else 0.
                check_selected_grads(maxs, i, j, 2 if a_ij > b_ij else 0, 2 if b_ij > a_ij else 0)
2236
+
2237
+
2238
devices = get_test_devices()


class TestVecScalarOps(unittest.TestCase):
    """Test case whose per-dtype test methods are attached by the registration loop that follows."""
2243
+
2244
+
2245
# Tests attached with add_function_test_register_kernel for every scalar dtype,
# listed in registration order. Each test is named "<function name>_<dtype name>".
_REGISTER_KERNEL_TESTS = (
    test_constructors,
    test_anon_type_instance,
    test_indexing,
    test_equality,
    test_scalar_multiplication,
    test_scalar_multiplication_rightmul,
    test_cw_multiplication,
    test_scalar_division,
    test_cw_division,
    test_addition,
    test_modulo,
    test_dotproduct,
    test_equivalent_types,
    test_conversions,
    test_constants,
)

# These construct vectors with negative components, so they are only
# registered for dtypes that are not unsigned integers.
_SIGNED_ONLY_TESTS = (
    test_abs,
    test_sign,
)

for dtype in np_scalar_types:
    _suffix = dtype.__name__

    add_function_test(TestVecScalarOps, f"test_arrays_{_suffix}", test_arrays, devices=devices, dtype=dtype)
    add_function_test(TestVecScalarOps, f"test_components_{_suffix}", test_components, devices=None, dtype=dtype)
    add_function_test(
        TestVecScalarOps, f"test_py_arithmetic_ops_{_suffix}", test_py_arithmetic_ops, devices=None, dtype=dtype
    )

    for _fn in _REGISTER_KERNEL_TESTS:
        add_function_test_register_kernel(
            TestVecScalarOps, f"{_fn.__name__}_{_suffix}", _fn, devices=devices, dtype=dtype
        )

    if dtype not in np_unsigned_int_types:
        for _fn in _SIGNED_ONLY_TESTS:
            add_function_test_register_kernel(
                TestVecScalarOps, f"{_fn.__name__}_{_suffix}", _fn, devices=devices, dtype=dtype
            )

    # The kernels in test_minmax compile incredibly slowly, so it stays disabled:
    # add_function_test_register_kernel(TestVecScalarOps, f"test_minmax_{dtype.__name__}", test_minmax, devices=devices, dtype=dtype)


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)