warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/tests/test_vec.py ADDED
@@ -0,0 +1,1487 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
# NumPy scalar types the vector tests are parameterized over.
# Note: ``np.byte`` is an alias of ``np.int8`` and ``np.ubyte`` an alias of
# ``np.uint8`` on common platforms, so each integer list names its 8-bit type twice.
np_signed_int_types = [np.int8, np.int16, np.int32, np.int64, np.byte]

np_unsigned_int_types = [np.uint8, np.uint16, np.uint32, np.uint64, np.ubyte]

np_float_types = [np.float16, np.float32, np.float64]
41
+
42
+
43
def randvals(rng, shape, dtype):
    """Draw a random NumPy array of the given ``shape`` and ``dtype`` from ``rng``.

    Floating-point types get standard-normal samples cast to ``dtype``.
    Integer types get values drawn with ``rng.integers`` starting at 1: the
    exclusive upper bound is 3 for 8-bit types and 5 for wider integers
    (presumably to keep arithmetic in the tests within range — TODO confirm).
    """
    if dtype not in np_float_types:
        small_int_types = [np.int8, np.uint8, np.byte, np.ubyte]
        upper = 3 if dtype in small_int_types else 5
        return rng.integers(1, high=upper, size=shape, dtype=dtype)
    return rng.standard_normal(size=shape).astype(dtype)
49
+
50
+
51
# Memoized kernels, keyed by "<function name>_<suffix>".
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return the ``wp.Kernel`` for ``func``, creating and caching it on first request.

    Args:
        func: Python function to wrap as a Warp kernel.
        suffix: String appended (after an underscore) to ``func.__name__`` to
            form both the cache key and the kernel key.

    Returns:
        The ``wp.Kernel`` stored in ``kernel_cache`` under that key.
    """
    cache_key = "_".join((func.__name__, suffix))
    if cache_key in kernel_cache:
        return kernel_cache[cache_key]
    built = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = built
    return built
59
+
60
+
61
def test_anon_constructor_error_length_mismatch(test, device):
    """Launching a kernel that copy-constructs a length-3 ``wp.vector`` from a length-2 one must raise RuntimeError."""

    @wp.kernel
    def kernel():
        wp.vector(
            wp.vector(length=2, dtype=float),
            length=3,
            dtype=float,
        )

    # NOTE(review): the expected text names the lengths the other way round from the
    # call (length-2 given, length-3 built); kept verbatim since it matches Warp's error.
    pattern = r"incompatible vector of length 3 given when copy constructing a vector of length 2$"
    with test.assertRaisesRegex(RuntimeError, pattern):
        wp.launch(kernel, dim=1, inputs=[], device=device)
80
+
81
+
82
def test_anon_constructor_error_numeric_arg_missing(test, device):
    """Launching a kernel that builds a length-12345 ``wp.vector`` from only two values must raise RuntimeError."""

    @wp.kernel
    def kernel():
        wp.vector(1.0, 2.0, length=12345)

    pattern = r"incompatible number of values given \(2\) when constructing a vector of length 12345$"
    with test.assertRaisesRegex(RuntimeError, pattern):
        wp.launch(kernel, dim=1, inputs=[], device=device)
97
+
98
+
99
def test_anon_constructor_error_length_arg_missing(test, device):
    """Launching a kernel that calls ``wp.vector()`` with no arguments and no ``length`` must raise RuntimeError."""

    @wp.kernel
    def kernel():
        wp.vector()

    pattern = r"the `length` argument must be specified when zero-initializing a vector$"
    with test.assertRaisesRegex(RuntimeError, pattern):
        wp.launch(kernel, dim=1, inputs=[], device=device)
114
+
115
+
116
def test_anon_constructor_error_numeric_args_mismatch(test, device):
    """Launching a kernel that mixes a float and an int in ``wp.vector(...)`` must raise RuntimeError."""

    @wp.kernel
    def kernel():
        wp.vector(1.0, 2)

    pattern = r"all values given when constructing a vector must have the same type$"
    with test.assertRaisesRegex(RuntimeError, pattern):
        wp.launch(kernel, dim=1, inputs=[], device=device)
131
+
132
+
133
def test_tpl_constructor_error_incompatible_sizes(test, device):
    """Launching a kernel that builds a ``wp.vec3`` from a ``wp.vec2`` must raise RuntimeError."""

    @wp.kernel
    def kernel():
        wp.vec3(wp.vec2(1.0, 2.0))

    # NOTE(review): the expected text names the lengths the other way round from the
    # call (vec2 given, vec3 built); kept verbatim since it matches Warp's error.
    pattern = "incompatible vector of length 3 given when copy constructing a vector of length 2"
    with test.assertRaisesRegex(RuntimeError, pattern):
        wp.launch(kernel, dim=1, inputs=[], device=device)
147
+
148
+
149
def test_tpl_constructor_error_numeric_args_mismatch(test, device):
    """Launching a kernel that passes a float and an int to ``wp.vec2`` must raise RuntimeError."""

    @wp.kernel
    def kernel():
        wp.vec2(1.0, 2)

    pattern = r"all values given when constructing a vector must have the same type$"
    with test.assertRaisesRegex(RuntimeError, pattern):
        wp.launch(kernel, dim=1, inputs=[], device=device)
164
+
165
+
166
def test_negation(test, device, dtype, register_kernels=False):
    """Check unary negation of 2- to 5-component vectors of ``dtype`` and its gradients.

    The kernel negates one vector of each length and writes the results out.
    It also writes each negated component, scaled by 2, into its own scalar
    array. For floating-point ``dtype``, each scalar output is backpropagated
    in turn. The gradient over the concatenated inputs must then be -2 at that
    component and 0 everywhere else.

    When ``register_kernels`` is True, only the kernel is created (through
    ``getkernel``) and the function returns before allocating any arrays.
    """
    rng = np.random.default_rng(123)

    # Per-precision comparison tolerance; non-float types compare exactly.
    tol = {
        np.float16: 5.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_negation(
        v2: wp.array(dtype=vec2),
        v3: wp.array(dtype=vec3),
        v4: wp.array(dtype=vec4),
        v5: wp.array(dtype=vec5),
        v2out: wp.array(dtype=vec2),
        v3out: wp.array(dtype=vec3),
        v4out: wp.array(dtype=vec4),
        v5out: wp.array(dtype=vec5),
        v20: wp.array(dtype=wptype),
        v21: wp.array(dtype=wptype),
        v30: wp.array(dtype=wptype),
        v31: wp.array(dtype=wptype),
        v32: wp.array(dtype=wptype),
        v40: wp.array(dtype=wptype),
        v41: wp.array(dtype=wptype),
        v42: wp.array(dtype=wptype),
        v43: wp.array(dtype=wptype),
        v50: wp.array(dtype=wptype),
        v51: wp.array(dtype=wptype),
        v52: wp.array(dtype=wptype),
        v53: wp.array(dtype=wptype),
        v54: wp.array(dtype=wptype),
    ):
        v2result = -v2[0]
        v3result = -v3[0]
        v4result = -v4[0]
        v5result = -v5[0]

        v2out[0] = v2result
        v3out[0] = v3result
        v4out[0] = v4result
        v5out[0] = v5result

        # multiply these outputs by 2 so we've got something to backpropagate:
        v20[0] = wptype(2) * v2result[0]
        v21[0] = wptype(2) * v2result[1]

        v30[0] = wptype(2) * v3result[0]
        v31[0] = wptype(2) * v3result[1]
        v32[0] = wptype(2) * v3result[2]

        v40[0] = wptype(2) * v4result[0]
        v41[0] = wptype(2) * v4result[1]
        v42[0] = wptype(2) * v4result[2]
        v43[0] = wptype(2) * v4result[3]

        v50[0] = wptype(2) * v5result[0]
        v51[0] = wptype(2) * v5result[1]
        v52[0] = wptype(2) * v5result[2]
        v53[0] = wptype(2) * v5result[3]
        v54[0] = wptype(2) * v5result[4]

    kernel = getkernel(check_negation, suffix=dtype.__name__)

    if register_kernels:
        return

    vec_types = (vec2, vec3, vec4, vec5)
    lengths = (2, 3, 4, 5)

    # Inputs are drawn from rng in length order 2, 3, 4, 5.
    inputs = [
        wp.array(randvals(rng, (1, length), dtype), dtype=vec_type, requires_grad=True, device=device)
        for vec_type, length in zip(vec_types, lengths)
    ]
    negated = [wp.zeros(1, dtype=vec_type, device=device) for vec_type in vec_types]
    # One scalar output per component (2 + 3 + 4 + 5 = 14), in the kernel's v20..v54 order.
    components = [wp.zeros(1, dtype=wptype, requires_grad=True, device=device) for _ in range(sum(lengths))]

    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel,
            dim=1,
            inputs=inputs,
            outputs=negated + components,
            device=device,
        )

    if dtype in np_float_types:
        for idx, loss in enumerate(components):
            tape.backward(loss=loss)
            grads = np.concatenate([tape.gradients[arr].numpy()[0] for arr in inputs])
            expected = np.zeros_like(grads)
            # Output idx is 2 * (-x_idx), so only input component idx has a (-2) gradient.
            expected[idx] = -2
            assert_np_equal(grads, expected, tol=tol)
            tape.zero()

    for out, arr in zip(negated, inputs):
        assert_np_equal(out.numpy()[0], -arr.numpy()[0], tol=tol)
287
+
288
+
289
def test_subtraction_unsigned(test, device, dtype, register_kernels=False):
    """Check component-wise subtraction of 2- to 5-component vectors of ``dtype``.

    The operands are chosen so every difference is non-negative. The results
    are checked in the kernel itself with ``wp.expect_eq``, in a one-thread
    (``dim=1``) launch. When ``register_kernels`` is True, the kernel is only
    created and nothing is launched.
    """
    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    vec2, vec3, vec4, vec5 = (wp.types.vector(length=n, dtype=wptype) for n in (2, 3, 4, 5))

    def check_subtraction_unsigned():
        wp.expect_eq(vec2(wptype(3), wptype(4)) - vec2(wptype(1), wptype(2)), vec2(wptype(2), wptype(2)))
        wp.expect_eq(
            vec3(
                wptype(3),
                wptype(4),
                wptype(4),
            )
            - vec3(wptype(1), wptype(2), wptype(3)),
            vec3(wptype(2), wptype(2), wptype(1)),
        )
        wp.expect_eq(
            vec4(
                wptype(3),
                wptype(4),
                wptype(4),
                wptype(5),
            )
            - vec4(wptype(1), wptype(2), wptype(3), wptype(4)),
            vec4(wptype(2), wptype(2), wptype(1), wptype(1)),
        )
        wp.expect_eq(
            vec5(
                wptype(3),
                wptype(4),
                wptype(4),
                wptype(5),
                wptype(4),
            )
            - vec5(wptype(1), wptype(2), wptype(3), wptype(4), wptype(4)),
            vec5(wptype(2), wptype(2), wptype(1), wptype(1), wptype(0)),
        )

    kernel = getkernel(check_subtraction_unsigned, suffix=dtype.__name__)

    if not register_kernels:
        wp.launch(kernel, dim=1, inputs=[], outputs=[], device=device)
335
+
336
+
337
+ def test_subtraction(test, device, dtype, register_kernels=False):
338
+ rng = np.random.default_rng(123)
339
+
340
+ tol = {
341
+ np.float16: 5.0e-3,
342
+ np.float32: 1.0e-6,
343
+ np.float64: 1.0e-8,
344
+ }.get(dtype, 0)
345
+
346
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
347
+ vec2 = wp.types.vector(length=2, dtype=wptype)
348
+ vec3 = wp.types.vector(length=3, dtype=wptype)
349
+ vec4 = wp.types.vector(length=4, dtype=wptype)
350
+ vec5 = wp.types.vector(length=5, dtype=wptype)
351
+
352
+ def check_subtraction(
353
+ s2: wp.array(dtype=vec2),
354
+ s3: wp.array(dtype=vec3),
355
+ s4: wp.array(dtype=vec4),
356
+ s5: wp.array(dtype=vec5),
357
+ v2: wp.array(dtype=vec2),
358
+ v3: wp.array(dtype=vec3),
359
+ v4: wp.array(dtype=vec4),
360
+ v5: wp.array(dtype=vec5),
361
+ v20: wp.array(dtype=wptype),
362
+ v21: wp.array(dtype=wptype),
363
+ v30: wp.array(dtype=wptype),
364
+ v31: wp.array(dtype=wptype),
365
+ v32: wp.array(dtype=wptype),
366
+ v40: wp.array(dtype=wptype),
367
+ v41: wp.array(dtype=wptype),
368
+ v42: wp.array(dtype=wptype),
369
+ v43: wp.array(dtype=wptype),
370
+ v50: wp.array(dtype=wptype),
371
+ v51: wp.array(dtype=wptype),
372
+ v52: wp.array(dtype=wptype),
373
+ v53: wp.array(dtype=wptype),
374
+ v54: wp.array(dtype=wptype),
375
+ ):
376
+ v2result = v2[0] - s2[0]
377
+ v3result = v3[0] - s3[0]
378
+ v4result = v4[0] - s4[0]
379
+ v5result = v5[0] - s5[0]
380
+
381
+ # multiply outputs by 2 so there's something to backpropagate:
382
+ v20[0] = wptype(2) * v2result[0]
383
+ v21[0] = wptype(2) * v2result[1]
384
+
385
+ v30[0] = wptype(2) * v3result[0]
386
+ v31[0] = wptype(2) * v3result[1]
387
+ v32[0] = wptype(2) * v3result[2]
388
+
389
+ v40[0] = wptype(2) * v4result[0]
390
+ v41[0] = wptype(2) * v4result[1]
391
+ v42[0] = wptype(2) * v4result[2]
392
+ v43[0] = wptype(2) * v4result[3]
393
+
394
+ v50[0] = wptype(2) * v5result[0]
395
+ v51[0] = wptype(2) * v5result[1]
396
+ v52[0] = wptype(2) * v5result[2]
397
+ v53[0] = wptype(2) * v5result[3]
398
+ v54[0] = wptype(2) * v5result[4]
399
+
400
+ kernel = getkernel(check_subtraction, suffix=dtype.__name__)
401
+
402
+ if register_kernels:
403
+ return
404
+
405
+ s2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
406
+ s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
407
+ s4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
408
+ s5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
409
+ v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
410
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
411
+ v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
412
+ v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
413
+ v20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
414
+ v21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
415
+ v30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
416
+ v31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
417
+ v32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
418
+ v40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
419
+ v41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
420
+ v42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
421
+ v43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
422
+ v50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
423
+ v51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
424
+ v52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
425
+ v53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
426
+ v54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
427
+ tape = wp.Tape()
428
+ with tape:
429
+ wp.launch(
430
+ kernel,
431
+ dim=1,
432
+ inputs=[
433
+ s2,
434
+ s3,
435
+ s4,
436
+ s5,
437
+ v2,
438
+ v3,
439
+ v4,
440
+ v5,
441
+ ],
442
+ outputs=[v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54],
443
+ device=device,
444
+ )
445
+
446
+ assert_np_equal(v20.numpy()[0], 2 * (v2.numpy()[0, 0] - s2.numpy()[0, 0]), tol=tol)
447
+ assert_np_equal(v21.numpy()[0], 2 * (v2.numpy()[0, 1] - s2.numpy()[0, 1]), tol=tol)
448
+
449
+ assert_np_equal(v30.numpy()[0], 2 * (v3.numpy()[0, 0] - s3.numpy()[0, 0]), tol=tol)
450
+ assert_np_equal(v31.numpy()[0], 2 * (v3.numpy()[0, 1] - s3.numpy()[0, 1]), tol=tol)
451
+ assert_np_equal(v32.numpy()[0], 2 * (v3.numpy()[0, 2] - s3.numpy()[0, 2]), tol=tol)
452
+
453
+ assert_np_equal(v40.numpy()[0], 2 * (v4.numpy()[0, 0] - s4.numpy()[0, 0]), tol=2 * tol)
454
+ assert_np_equal(v41.numpy()[0], 2 * (v4.numpy()[0, 1] - s4.numpy()[0, 1]), tol=2 * tol)
455
+ assert_np_equal(v42.numpy()[0], 2 * (v4.numpy()[0, 2] - s4.numpy()[0, 2]), tol=2 * tol)
456
+ assert_np_equal(v43.numpy()[0], 2 * (v4.numpy()[0, 3] - s4.numpy()[0, 3]), tol=2 * tol)
457
+
458
+ assert_np_equal(v50.numpy()[0], 2 * (v5.numpy()[0, 0] - s5.numpy()[0, 0]), tol=tol)
459
+ assert_np_equal(v51.numpy()[0], 2 * (v5.numpy()[0, 1] - s5.numpy()[0, 1]), tol=tol)
460
+ assert_np_equal(v52.numpy()[0], 2 * (v5.numpy()[0, 2] - s5.numpy()[0, 2]), tol=tol)
461
+ assert_np_equal(v53.numpy()[0], 2 * (v5.numpy()[0, 3] - s5.numpy()[0, 3]), tol=tol)
462
+ assert_np_equal(v54.numpy()[0], 2 * (v5.numpy()[0, 4] - s5.numpy()[0, 4]), tol=tol)
463
+
464
+ if dtype in np_float_types:
465
+ for i, l in enumerate([v20, v21, v30, v31, v32, v40, v41, v42, v43, v50, v51, v52, v53, v54]):
466
+ tape.backward(loss=l)
467
+ sgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [s2, s3, s4, s5]])
468
+ expected_grads = np.zeros_like(sgrads)
469
+
470
+ expected_grads[i] = -2
471
+ assert_np_equal(sgrads, expected_grads, tol=10 * tol)
472
+
473
+ allgrads = np.concatenate([tape.gradients[v].numpy()[0] for v in [v2, v3, v4, v5]])
474
+ expected_grads = np.zeros_like(allgrads)
475
+
476
+ # d/dv v/s = 1/s
477
+ expected_grads[i] = 2
478
+ assert_np_equal(allgrads, expected_grads, tol=tol)
479
+
480
+ tape.zero()
481
+
482
+
483
+ def test_length(test, device, dtype, register_kernels=False):
484
+ rng = np.random.default_rng(123)
485
+
486
+ tol = {
487
+ np.float16: 5.0e-3,
488
+ np.float32: 1.0e-6,
489
+ np.float64: 1.0e-7,
490
+ }.get(dtype, 0)
491
+
492
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
493
+ vec2 = wp.types.vector(length=2, dtype=wptype)
494
+ vec3 = wp.types.vector(length=3, dtype=wptype)
495
+ vec4 = wp.types.vector(length=4, dtype=wptype)
496
+ vec5 = wp.types.vector(length=5, dtype=wptype)
497
+
498
+ def check_length(
499
+ v2: wp.array(dtype=vec2),
500
+ v3: wp.array(dtype=vec3),
501
+ v4: wp.array(dtype=vec4),
502
+ v5: wp.array(dtype=vec5),
503
+ l2: wp.array(dtype=wptype),
504
+ l3: wp.array(dtype=wptype),
505
+ l4: wp.array(dtype=wptype),
506
+ l5: wp.array(dtype=wptype),
507
+ l22: wp.array(dtype=wptype),
508
+ l23: wp.array(dtype=wptype),
509
+ l24: wp.array(dtype=wptype),
510
+ l25: wp.array(dtype=wptype),
511
+ ):
512
+ l2[0] = wptype(2) * wp.length(v2[0])
513
+ l3[0] = wptype(2) * wp.length(v3[0])
514
+ l4[0] = wptype(2) * wp.length(v4[0])
515
+ l5[0] = wptype(2) * wp.length(v5[0])
516
+
517
+ l22[0] = wptype(2) * wp.length_sq(v2[0])
518
+ l23[0] = wptype(2) * wp.length_sq(v3[0])
519
+ l24[0] = wptype(2) * wp.length_sq(v4[0])
520
+ l25[0] = wptype(2) * wp.length_sq(v5[0])
521
+
522
+ kernel = getkernel(check_length, suffix=dtype.__name__)
523
+
524
+ if register_kernels:
525
+ return
526
+
527
+ v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
528
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
529
+ v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
530
+ v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
531
+
532
+ l2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
533
+ l3 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
534
+ l4 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
535
+ l5 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
536
+
537
+ l22 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
538
+ l23 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
539
+ l24 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
540
+ l25 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
541
+
542
+ tape = wp.Tape()
543
+ with tape:
544
+ wp.launch(
545
+ kernel,
546
+ dim=1,
547
+ inputs=[
548
+ v2,
549
+ v3,
550
+ v4,
551
+ v5,
552
+ ],
553
+ outputs=[l2, l3, l4, l5, l22, l23, l24, l25],
554
+ device=device,
555
+ )
556
+
557
+ assert_np_equal(l2.numpy()[0], 2 * np.linalg.norm(v2.numpy()), tol=10 * tol)
558
+ assert_np_equal(l3.numpy()[0], 2 * np.linalg.norm(v3.numpy()), tol=10 * tol)
559
+ assert_np_equal(l4.numpy()[0], 2 * np.linalg.norm(v4.numpy()), tol=10 * tol)
560
+ assert_np_equal(l5.numpy()[0], 2 * np.linalg.norm(v5.numpy()), tol=10 * tol)
561
+
562
+ assert_np_equal(l22.numpy()[0], 2 * np.linalg.norm(v2.numpy()) ** 2, tol=10 * tol)
563
+ assert_np_equal(l23.numpy()[0], 2 * np.linalg.norm(v3.numpy()) ** 2, tol=10 * tol)
564
+ assert_np_equal(l24.numpy()[0], 2 * np.linalg.norm(v4.numpy()) ** 2, tol=10 * tol)
565
+ assert_np_equal(l25.numpy()[0], 2 * np.linalg.norm(v5.numpy()) ** 2, tol=10 * tol)
566
+
567
+ tape.backward(loss=l2)
568
+ grad = tape.gradients[v2].numpy()[0]
569
+ expected_grad = 2 * v2.numpy()[0] / np.linalg.norm(v2.numpy())
570
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
571
+ tape.zero()
572
+
573
+ tape.backward(loss=l3)
574
+ grad = tape.gradients[v3].numpy()[0]
575
+ expected_grad = 2 * v3.numpy()[0] / np.linalg.norm(v3.numpy())
576
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
577
+ tape.zero()
578
+
579
+ tape.backward(loss=l4)
580
+ grad = tape.gradients[v4].numpy()[0]
581
+ expected_grad = 2 * v4.numpy()[0] / np.linalg.norm(v4.numpy())
582
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
583
+ tape.zero()
584
+
585
+ tape.backward(loss=l5)
586
+ grad = tape.gradients[v5].numpy()[0]
587
+ expected_grad = 2 * v5.numpy()[0] / np.linalg.norm(v5.numpy())
588
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
589
+ tape.zero()
590
+
591
+ tape.backward(loss=l22)
592
+ grad = tape.gradients[v2].numpy()[0]
593
+ expected_grad = 4 * v2.numpy()[0]
594
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
595
+ tape.zero()
596
+
597
+ tape.backward(loss=l23)
598
+ grad = tape.gradients[v3].numpy()[0]
599
+ expected_grad = 4 * v3.numpy()[0]
600
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
601
+ tape.zero()
602
+
603
+ tape.backward(loss=l24)
604
+ grad = tape.gradients[v4].numpy()[0]
605
+ expected_grad = 4 * v4.numpy()[0]
606
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
607
+ tape.zero()
608
+
609
+ tape.backward(loss=l25)
610
+ grad = tape.gradients[v5].numpy()[0]
611
+ expected_grad = 4 * v5.numpy()[0]
612
+ assert_np_equal(grad, expected_grad, tol=10 * tol)
613
+ tape.zero()
614
+
615
+
616
+ def test_normalize(test, device, dtype, register_kernels=False):
617
+ rng = np.random.default_rng(123)
618
+
619
+ tol = {
620
+ np.float16: 5.0e-3,
621
+ np.float32: 1.0e-6,
622
+ np.float64: 1.0e-8,
623
+ }.get(dtype, 0)
624
+
625
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
626
+ vec2 = wp.types.vector(length=2, dtype=wptype)
627
+ vec3 = wp.types.vector(length=3, dtype=wptype)
628
+ vec4 = wp.types.vector(length=4, dtype=wptype)
629
+ vec5 = wp.types.vector(length=5, dtype=wptype)
630
+
631
+ def check_normalize(
632
+ v2: wp.array(dtype=vec2),
633
+ v3: wp.array(dtype=vec3),
634
+ v4: wp.array(dtype=vec4),
635
+ v5: wp.array(dtype=vec5),
636
+ n20: wp.array(dtype=wptype),
637
+ n21: wp.array(dtype=wptype),
638
+ n30: wp.array(dtype=wptype),
639
+ n31: wp.array(dtype=wptype),
640
+ n32: wp.array(dtype=wptype),
641
+ n40: wp.array(dtype=wptype),
642
+ n41: wp.array(dtype=wptype),
643
+ n42: wp.array(dtype=wptype),
644
+ n43: wp.array(dtype=wptype),
645
+ n50: wp.array(dtype=wptype),
646
+ n51: wp.array(dtype=wptype),
647
+ n52: wp.array(dtype=wptype),
648
+ n53: wp.array(dtype=wptype),
649
+ n54: wp.array(dtype=wptype),
650
+ ):
651
+ n2 = wptype(2) * wp.normalize(v2[0])
652
+ n3 = wptype(2) * wp.normalize(v3[0])
653
+ n4 = wptype(2) * wp.normalize(v4[0])
654
+ n5 = wptype(2) * wp.normalize(v5[0])
655
+
656
+ n20[0] = n2[0]
657
+ n21[0] = n2[1]
658
+
659
+ n30[0] = n3[0]
660
+ n31[0] = n3[1]
661
+ n32[0] = n3[2]
662
+
663
+ n40[0] = n4[0]
664
+ n41[0] = n4[1]
665
+ n42[0] = n4[2]
666
+ n43[0] = n4[3]
667
+
668
+ n50[0] = n5[0]
669
+ n51[0] = n5[1]
670
+ n52[0] = n5[2]
671
+ n53[0] = n5[3]
672
+ n54[0] = n5[4]
673
+
674
+ def check_normalize_alt(
675
+ v2: wp.array(dtype=vec2),
676
+ v3: wp.array(dtype=vec3),
677
+ v4: wp.array(dtype=vec4),
678
+ v5: wp.array(dtype=vec5),
679
+ n20: wp.array(dtype=wptype),
680
+ n21: wp.array(dtype=wptype),
681
+ n30: wp.array(dtype=wptype),
682
+ n31: wp.array(dtype=wptype),
683
+ n32: wp.array(dtype=wptype),
684
+ n40: wp.array(dtype=wptype),
685
+ n41: wp.array(dtype=wptype),
686
+ n42: wp.array(dtype=wptype),
687
+ n43: wp.array(dtype=wptype),
688
+ n50: wp.array(dtype=wptype),
689
+ n51: wp.array(dtype=wptype),
690
+ n52: wp.array(dtype=wptype),
691
+ n53: wp.array(dtype=wptype),
692
+ n54: wp.array(dtype=wptype),
693
+ ):
694
+ n2 = wptype(2) * v2[0] / wp.length(v2[0])
695
+ n3 = wptype(2) * v3[0] / wp.length(v3[0])
696
+ n4 = wptype(2) * v4[0] / wp.length(v4[0])
697
+ n5 = wptype(2) * v5[0] / wp.length(v5[0])
698
+
699
+ n20[0] = n2[0]
700
+ n21[0] = n2[1]
701
+
702
+ n30[0] = n3[0]
703
+ n31[0] = n3[1]
704
+ n32[0] = n3[2]
705
+
706
+ n40[0] = n4[0]
707
+ n41[0] = n4[1]
708
+ n42[0] = n4[2]
709
+ n43[0] = n4[3]
710
+
711
+ n50[0] = n5[0]
712
+ n51[0] = n5[1]
713
+ n52[0] = n5[2]
714
+ n53[0] = n5[3]
715
+ n54[0] = n5[4]
716
+
717
+ normalize_kernel = getkernel(check_normalize, suffix=dtype.__name__)
718
+ normalize_alt_kernel = getkernel(check_normalize_alt, suffix=dtype.__name__)
719
+
720
+ if register_kernels:
721
+ return
722
+
723
+ # I've already tested the things I'm using in check_normalize_alt, so I'll just
724
+ # make sure the two are giving the same results/gradients
725
+ v2 = wp.array(randvals(rng, (1, 2), dtype), dtype=vec2, requires_grad=True, device=device)
726
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
727
+ v4 = wp.array(randvals(rng, (1, 4), dtype), dtype=vec4, requires_grad=True, device=device)
728
+ v5 = wp.array(randvals(rng, (1, 5), dtype), dtype=vec5, requires_grad=True, device=device)
729
+
730
+ n20 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
731
+ n21 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
732
+ n30 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
733
+ n31 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
734
+ n32 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
735
+ n40 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
736
+ n41 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
737
+ n42 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
738
+ n43 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
739
+ n50 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
740
+ n51 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
741
+ n52 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
742
+ n53 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
743
+ n54 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
744
+
745
+ n20_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
746
+ n21_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
747
+ n30_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
748
+ n31_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
749
+ n32_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
750
+ n40_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
751
+ n41_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
752
+ n42_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
753
+ n43_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
754
+ n50_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
755
+ n51_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
756
+ n52_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
757
+ n53_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
758
+ n54_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
759
+
760
+ outputs0 = [
761
+ n20,
762
+ n21,
763
+ n30,
764
+ n31,
765
+ n32,
766
+ n40,
767
+ n41,
768
+ n42,
769
+ n43,
770
+ n50,
771
+ n51,
772
+ n52,
773
+ n53,
774
+ n54,
775
+ ]
776
+ tape0 = wp.Tape()
777
+ with tape0:
778
+ wp.launch(
779
+ normalize_kernel,
780
+ dim=1,
781
+ inputs=[
782
+ v2,
783
+ v3,
784
+ v4,
785
+ v5,
786
+ ],
787
+ outputs=outputs0,
788
+ device=device,
789
+ )
790
+
791
+ outputs1 = [
792
+ n20_alt,
793
+ n21_alt,
794
+ n30_alt,
795
+ n31_alt,
796
+ n32_alt,
797
+ n40_alt,
798
+ n41_alt,
799
+ n42_alt,
800
+ n43_alt,
801
+ n50_alt,
802
+ n51_alt,
803
+ n52_alt,
804
+ n53_alt,
805
+ n54_alt,
806
+ ]
807
+ tape1 = wp.Tape()
808
+ with tape1:
809
+ wp.launch(
810
+ normalize_alt_kernel,
811
+ dim=1,
812
+ inputs=[
813
+ v2,
814
+ v3,
815
+ v4,
816
+ v5,
817
+ ],
818
+ outputs=outputs1,
819
+ device=device,
820
+ )
821
+
822
+ for ncmp, ncmpalt in zip(outputs0, outputs1):
823
+ assert_np_equal(ncmp.numpy()[0], ncmpalt.numpy()[0], tol=10 * tol)
824
+
825
+ invecs = [
826
+ v2,
827
+ v2,
828
+ v3,
829
+ v3,
830
+ v3,
831
+ v4,
832
+ v4,
833
+ v4,
834
+ v4,
835
+ v5,
836
+ v5,
837
+ v5,
838
+ v5,
839
+ v5,
840
+ ]
841
+ for ncmp, ncmpalt, v in zip(outputs0, outputs1, invecs):
842
+ tape0.backward(loss=ncmp)
843
+ tape1.backward(loss=ncmpalt)
844
+ assert_np_equal(tape0.gradients[v].numpy()[0], tape1.gradients[v].numpy()[0], tol=10 * tol)
845
+ tape0.zero()
846
+ tape1.zero()
847
+
848
+
849
+ def test_crossproduct(test, device, dtype, register_kernels=False):
850
+ rng = np.random.default_rng(123)
851
+
852
+ tol = {
853
+ np.float16: 5.0e-3,
854
+ np.float32: 1.0e-6,
855
+ np.float64: 1.0e-8,
856
+ }.get(dtype, 0)
857
+
858
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
859
+ vec3 = wp.types.vector(length=3, dtype=wptype)
860
+
861
+ def check_cross(
862
+ s3: wp.array(dtype=vec3),
863
+ v3: wp.array(dtype=vec3),
864
+ c0: wp.array(dtype=wptype),
865
+ c1: wp.array(dtype=wptype),
866
+ c2: wp.array(dtype=wptype),
867
+ ):
868
+ c = wp.cross(s3[0], v3[0])
869
+
870
+ # multiply outputs by 2 so we've got something to backpropagate:
871
+ c0[0] = wptype(2) * c[0]
872
+ c1[0] = wptype(2) * c[1]
873
+ c2[0] = wptype(2) * c[2]
874
+
875
+ kernel = getkernel(check_cross, suffix=dtype.__name__)
876
+
877
+ if register_kernels:
878
+ return
879
+
880
+ s3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
881
+ v3 = wp.array(randvals(rng, (1, 3), dtype), dtype=vec3, requires_grad=True, device=device)
882
+ c0 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
883
+ c1 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
884
+ c2 = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
885
+ tape = wp.Tape()
886
+ with tape:
887
+ wp.launch(
888
+ kernel,
889
+ dim=1,
890
+ inputs=[
891
+ s3,
892
+ v3,
893
+ ],
894
+ outputs=[c0, c1, c2],
895
+ device=device,
896
+ )
897
+
898
+ result = 2 * np.cross(s3.numpy(), v3.numpy())[0]
899
+ assert_np_equal(c0.numpy()[0], result[0], tol=10 * tol)
900
+ assert_np_equal(c1.numpy()[0], result[1], tol=10 * tol)
901
+ assert_np_equal(c2.numpy()[0], result[2], tol=10 * tol)
902
+
903
+ if dtype in np_float_types:
904
+ # c.x = sy vz - sz vy
905
+ # c.y = sz vx - sx vz
906
+ # c.z = sx vy - sy vx
907
+
908
+ # ( d/dsx d/dsy d/dsz )c.x = ( 0 vz -vy )
909
+ # ( d/dsx d/dsy d/dsz )c.y = ( -vz 0 vx )
910
+ # ( d/dsx d/dsy d/dsz )c.z = ( vy -vx 0 )
911
+
912
+ # ( d/dvx d/dvy d/dvz )c.x = (0 -sz sy)
913
+ # ( d/dvx d/dvy d/dvz )c.y = (sz 0 -sx)
914
+ # ( d/dvx d/dvy d/dvz )c.z = (-sy sx 0)
915
+
916
+ tape.backward(loss=c0)
917
+ assert_np_equal(
918
+ tape.gradients[s3].numpy(), 2.0 * np.array([0, v3.numpy()[0, 2], -v3.numpy()[0, 1]]), tol=10 * tol
919
+ )
920
+ assert_np_equal(
921
+ tape.gradients[v3].numpy(), 2.0 * np.array([0, -s3.numpy()[0, 2], s3.numpy()[0, 1]]), tol=10 * tol
922
+ )
923
+ tape.zero()
924
+
925
+ tape.backward(loss=c1)
926
+ assert_np_equal(
927
+ tape.gradients[s3].numpy(), 2.0 * np.array([-v3.numpy()[0, 2], 0, v3.numpy()[0, 0]]), tol=10 * tol
928
+ )
929
+ assert_np_equal(
930
+ tape.gradients[v3].numpy(), 2.0 * np.array([s3.numpy()[0, 2], 0, -s3.numpy()[0, 0]]), tol=10 * tol
931
+ )
932
+ tape.zero()
933
+
934
+ tape.backward(loss=c2)
935
+ assert_np_equal(
936
+ tape.gradients[s3].numpy(), 2.0 * np.array([v3.numpy()[0, 1], -v3.numpy()[0, 0], 0]), tol=10 * tol
937
+ )
938
+ assert_np_equal(
939
+ tape.gradients[v3].numpy(), 2.0 * np.array([-s3.numpy()[0, 1], s3.numpy()[0, 0], 0]), tol=10 * tol
940
+ )
941
+ tape.zero()
942
+
943
+
944
+ def test_casting_constructors(test, device, dtype, register_kernels=False):
945
+ np_type = np.dtype(dtype)
946
+ wp_type = wp.types.np_dtype_to_warp_type[np_type]
947
+ vec3 = wp.types.vector(length=3, dtype=wp_type)
948
+
949
+ np16 = np.dtype(np.float16)
950
+ wp16 = wp.types.np_dtype_to_warp_type[np16]
951
+
952
+ np32 = np.dtype(np.float32)
953
+ wp32 = wp.types.np_dtype_to_warp_type[np32]
954
+
955
+ np64 = np.dtype(np.float64)
956
+ wp64 = wp.types.np_dtype_to_warp_type[np64]
957
+
958
+ def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
959
+ tid = wp.tid()
960
+
961
+ v1 = vec3(a[tid, 0], a[tid, 1], a[tid, 2])
962
+ v2 = wp.vector(v1, dtype=wp16)
963
+
964
+ b[tid, 0] = v2[0]
965
+ b[tid, 1] = v2[1]
966
+ b[tid, 2] = v2[2]
967
+
968
+ def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
969
+ tid = wp.tid()
970
+
971
+ v1 = vec3(a[tid, 0], a[tid, 1], a[tid, 2])
972
+ v2 = wp.vector(v1, dtype=wp32)
973
+
974
+ b[tid, 0] = v2[0]
975
+ b[tid, 1] = v2[1]
976
+ b[tid, 2] = v2[2]
977
+
978
+ def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
979
+ tid = wp.tid()
980
+
981
+ v1 = vec3(a[tid, 0], a[tid, 1], a[tid, 2])
982
+ v2 = wp.vector(v1, dtype=wp64)
983
+
984
+ b[tid, 0] = v2[0]
985
+ b[tid, 1] = v2[1]
986
+ b[tid, 2] = v2[2]
987
+
988
+ kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
989
+ kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
990
+ kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)
991
+
992
+ if register_kernels:
993
+ return
994
+
995
+ # check casting to float 16
996
+ a = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
997
+ b = wp.array(np.zeros((1, 3), dtype=np16), dtype=wp16, requires_grad=True, device=device)
998
+ b_result = np.ones((1, 3), dtype=np16)
999
+ b_grad = wp.array(np.ones((1, 3), dtype=np16), dtype=wp16, device=device)
1000
+ a_grad = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, device=device)
1001
+
1002
+ tape = wp.Tape()
1003
+ with tape:
1004
+ wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)
1005
+
1006
+ tape.backward(grads={b: b_grad})
1007
+ out = tape.gradients[a].numpy()
1008
+
1009
+ assert_np_equal(b.numpy(), b_result)
1010
+ assert_np_equal(out, a_grad.numpy())
1011
+
1012
+ # check casting to float 32
1013
+ a = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
1014
+ b = wp.array(np.zeros((1, 3), dtype=np32), dtype=wp32, requires_grad=True, device=device)
1015
+ b_result = np.ones((1, 3), dtype=np32)
1016
+ b_grad = wp.array(np.ones((1, 3), dtype=np32), dtype=wp32, device=device)
1017
+ a_grad = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, device=device)
1018
+
1019
+ tape = wp.Tape()
1020
+ with tape:
1021
+ wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)
1022
+
1023
+ tape.backward(grads={b: b_grad})
1024
+ out = tape.gradients[a].numpy()
1025
+
1026
+ assert_np_equal(b.numpy(), b_result)
1027
+ assert_np_equal(out, a_grad.numpy())
1028
+
1029
+ # check casting to float 64
1030
+ a = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
1031
+ b = wp.array(np.zeros((1, 3), dtype=np64), dtype=wp64, requires_grad=True, device=device)
1032
+ b_result = np.ones((1, 3), dtype=np64)
1033
+ b_grad = wp.array(np.ones((1, 3), dtype=np64), dtype=wp64, device=device)
1034
+ a_grad = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, device=device)
1035
+
1036
+ tape = wp.Tape()
1037
+ with tape:
1038
+ wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)
1039
+
1040
+ tape.backward(grads={b: b_grad})
1041
+ out = tape.gradients[a].numpy()
1042
+
1043
+ assert_np_equal(b.numpy(), b_result)
1044
+ assert_np_equal(out, a_grad.numpy())
1045
+
1046
+
1047
+ def test_vector_assign_inplace(test, device, dtype, register_kernels=False):
1048
+ np_type = np.dtype(dtype)
1049
+ wp_type = wp.types.np_dtype_to_warp_type[np_type]
1050
+
1051
+ vec2 = wp.types.vector(length=2, dtype=wp_type)
1052
+ vec3 = wp.types.vector(length=3, dtype=wp_type)
1053
+ vec4 = wp.types.vector(length=4, dtype=wp_type)
1054
+
1055
+ def vectest_read_write_store(
1056
+ x: wp.array(dtype=wp_type), a: wp.array(dtype=vec2), b: wp.array(dtype=vec3), c: wp.array(dtype=vec4)
1057
+ ):
1058
+ tid = wp.tid()
1059
+
1060
+ t = a[tid]
1061
+ t[0] = x[tid]
1062
+ a[tid] = t
1063
+
1064
+ u = b[tid]
1065
+ u[1] = x[tid]
1066
+ b[tid] = u
1067
+
1068
+ v = c[tid]
1069
+ v[2] = x[tid]
1070
+ c[tid] = v
1071
+
1072
+ def vectest_in_register(
1073
+ x: wp.array(dtype=wp_type), y: wp.array(dtype=vec3), a: wp.array(dtype=vec2), b: wp.array(dtype=vec3)
1074
+ ):
1075
+ tid = wp.tid()
1076
+
1077
+ f = vec3(wp_type(0.0))
1078
+ b_vec = b[tid]
1079
+ f[0] = b_vec[1]
1080
+ f[2] = b_vec[0] * b_vec[1]
1081
+ y[tid] = f
1082
+
1083
+ g = wp_type(0.0)
1084
+ a_vec = a[tid]
1085
+ g = a_vec[0] + a_vec[1]
1086
+ x[tid] = g
1087
+
1088
+ def vectest_component(x: wp.array(dtype=vec3), y: wp.array(dtype=wp_type)):
1089
+ i = wp.tid()
1090
+
1091
+ a = vec3(wp_type(0.0))
1092
+ a.x = wp_type(1.0) * y[i]
1093
+ a.y = wp_type(2.0) * y[i]
1094
+ a.z = wp_type(3.0) * y[i]
1095
+ x[i] = a
1096
+
1097
+ kernel_read_write_store = getkernel(vectest_read_write_store, suffix=dtype.__name__)
1098
+ kernel_in_register = getkernel(vectest_in_register, suffix=dtype.__name__)
1099
+ kernel_component = getkernel(vectest_component, suffix=dtype.__name__)
1100
+
1101
+ if register_kernels:
1102
+ return
1103
+
1104
+ a = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
1105
+ b = wp.ones(1, dtype=vec3, device=device, requires_grad=True)
1106
+ c = wp.ones(1, dtype=vec4, device=device, requires_grad=True)
1107
+ x = wp.full(1, value=2.0, dtype=wp_type, device=device, requires_grad=True)
1108
+
1109
+ tape = wp.Tape()
1110
+ with tape:
1111
+ wp.launch(kernel_read_write_store, dim=1, inputs=[x, a, b, c], device=device)
1112
+
1113
+ tape.backward(
1114
+ grads={
1115
+ a: wp.ones_like(a, requires_grad=False),
1116
+ b: wp.ones_like(b, requires_grad=False),
1117
+ c: wp.ones_like(c, requires_grad=False),
1118
+ }
1119
+ )
1120
+
1121
+ assert_np_equal(a.numpy(), np.array([[2.0, 1.0]], dtype=np_type))
1122
+ assert_np_equal(b.numpy(), np.array([[1.0, 2.0, 1.0]], dtype=np_type))
1123
+ assert_np_equal(c.numpy(), np.array([[1.0, 1.0, 2.0, 1.0]], dtype=np_type))
1124
+ assert_np_equal(x.grad.numpy(), np.array([3.0], dtype=np_type))
1125
+
1126
+ tape.reset()
1127
+
1128
+ a = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
1129
+ b = wp.ones(1, dtype=vec3, device=device, requires_grad=True)
1130
+ x = wp.zeros(1, dtype=wp_type, device=device, requires_grad=True)
1131
+ y = wp.zeros(1, dtype=vec3, device=device, requires_grad=True)
1132
+
1133
+ with tape:
1134
+ wp.launch(kernel_in_register, dim=1, inputs=[x, y, a, b], device=device)
1135
+
1136
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False), y: wp.ones_like(y, requires_grad=False)})
1137
+
1138
+ assert_np_equal(x.numpy(), np.array([2.0], dtype=np_type))
1139
+ assert_np_equal(y.numpy(), np.array([[1.0, 0.0, 1.0]], dtype=np_type))
1140
+ assert_np_equal(a.grad.numpy(), np.array([[1.0, 1.0]], dtype=np_type))
1141
+ assert_np_equal(b.grad.numpy(), np.array([[1.0, 2.0, 0.0]], dtype=np_type))
1142
+
1143
+ tape.reset()
1144
+
1145
+ x = wp.zeros(1, dtype=vec3, device=device, requires_grad=True)
1146
+ y = wp.ones(1, dtype=wp_type, device=device, requires_grad=True)
1147
+
1148
+ with tape:
1149
+ wp.launch(kernel_component, dim=1, inputs=[x, y], device=device)
1150
+
1151
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1152
+
1153
+ assert_np_equal(x.numpy(), np.array([[1.0, 2.0, 3.0]], dtype=np_type))
1154
+ assert_np_equal(y.grad.numpy(), np.array([6.0], dtype=np_type))
1155
+
1156
+
1157
+ @wp.kernel
1158
+ def test_vector_constructor_value_func():
1159
+ a = wp.vec2()
1160
+ b = wp.vector(a, dtype=wp.float16)
1161
+ c = wp.vector(a)
1162
+ d = wp.vector(a, length=2)
1163
+
1164
+
1165
+ # Test matrix constructors using explicit type (float16)
1166
+ # note that these tests are specifically not using generics / closure
1167
+ # args to create kernels dynamically (like the rest of this file)
1168
+ # as those use different code paths to resolve arg types which
1169
+ # has lead to regressions.
1170
+ @wp.kernel
1171
+ def test_constructors_explicit_precision():
1172
+ # construction for custom matrix types
1173
+ ones = wp.vector(wp.float16(1.0), length=2)
1174
+ zeros = wp.vector(length=2, dtype=wp.float16)
1175
+ custom = wp.vector(wp.float16(0.0), wp.float16(1.0))
1176
+
1177
+ for i in range(2):
1178
+ wp.expect_eq(ones[i], wp.float16(1.0))
1179
+ wp.expect_eq(zeros[i], wp.float16(0.0))
1180
+ wp.expect_eq(custom[i], wp.float16(i))
1181
+
1182
+
1183
+ # Same as above but with a default (float/int) type
1184
+ # which tests some different code paths that
1185
+ # need to ensure types are correctly canonicalized
1186
+ # during codegen
1187
+ @wp.kernel
1188
+ def test_constructors_default_precision():
1189
+ # construction for custom matrix types
1190
+ ones = wp.vector(1.0, length=2)
1191
+ zeros = wp.vector(length=2, dtype=float)
1192
+ custom = wp.vector(0.0, 1.0)
1193
+
1194
+ for i in range(2):
1195
+ wp.expect_eq(ones[i], 1.0)
1196
+ wp.expect_eq(zeros[i], 0.0)
1197
+ wp.expect_eq(custom[i], float(i))
1198
+
1199
+
1200
+ @wp.kernel
1201
+ def test_vector_mutation(expected: wp.types.vector(length=10, dtype=float)):
1202
+ v = wp.vector(length=10, dtype=float)
1203
+
1204
+ # test element indexing
1205
+ v[0] = 1.0
1206
+
1207
+ for i in range(1, 10):
1208
+ v[i] = float(i) + 1.0
1209
+
1210
+ wp.expect_eq(v, expected)
1211
+
1212
+
1213
+ CONSTANT_LENGTH = wp.constant(10)
1214
+
1215
+
1216
+ # tests that we can use global constants in length keyword argument
1217
+ # for vector constructor
1218
+ @wp.kernel
1219
+ def test_constructors_constant_length():
1220
+ v = wp.vector(length=(CONSTANT_LENGTH), dtype=float)
1221
+
1222
+ for i in range(CONSTANT_LENGTH):
1223
+ v[i] = float(i)
1224
+
1225
+
1226
+ Vec123 = wp.vec(123, dtype=wp.float16)
1227
+
1228
+
1229
+ @wp.kernel
1230
+ def vector_len_kernel(
1231
+ v1: wp.vec2,
1232
+ v2: wp.vec(3, float),
1233
+ v3: wp.vec(Any, float),
1234
+ v4: Vec123,
1235
+ out: wp.array(dtype=int),
1236
+ ):
1237
+ length = wp.static(len(v1))
1238
+ wp.expect_eq(len(v1), 2)
1239
+ out[0] = len(v1)
1240
+
1241
+ length = len(v2)
1242
+ wp.expect_eq(wp.static(len(v2)), 3)
1243
+ out[1] = len(v2)
1244
+
1245
+ length = len(v3)
1246
+ wp.expect_eq(len(v3), 4)
1247
+ out[2] = wp.static(len(v3))
1248
+
1249
+ length = wp.static(len(v4))
1250
+ wp.expect_eq(wp.static(len(v4)), 123)
1251
+ out[3] = wp.static(len(v4))
1252
+
1253
+ foo = wp.vec2()
1254
+ length = len(foo)
1255
+ wp.expect_eq(len(foo), 2)
1256
+ out[4] = len(foo)
1257
+
1258
+
1259
+ def test_vector_len(test, device):
1260
+ v1 = wp.vec2()
1261
+ v2 = wp.vec3()
1262
+ v3 = wp.vec4()
1263
+ v4 = Vec123()
1264
+ out = wp.empty(5, dtype=int, device=device)
1265
+ wp.launch(vector_len_kernel, dim=(1,), inputs=(v1, v2, v3, v4), outputs=(out,), device=device)
1266
+
1267
+ test.assertEqual(out.numpy()[0], 2)
1268
+ test.assertEqual(out.numpy()[1], 3)
1269
+ test.assertEqual(out.numpy()[2], 4)
1270
+ test.assertEqual(out.numpy()[3], 123)
1271
+ test.assertEqual(out.numpy()[4], 2)
1272
+
1273
+
1274
@wp.kernel
def vector_augassign_kernel(
    a: wp.array(dtype=wp.vec3), b: wp.array(dtype=wp.vec3), c: wp.array(dtype=wp.vec3), d: wp.array(dtype=wp.vec3)
):
    # Exercises augmented assignment (+= / -=) on individual components of a
    # local vec3. Per thread i:
    #   a[i] = vec3() + b[i]   (accumulated component by component)
    #   c[i] = vec3() - d[i]   (accumulated component by component)
    i = wp.tid()

    # Accumulate b[i] into a default-constructed vec3 one component at a time.
    v1 = wp.vec3()
    v2 = b[i]

    v1[0] += v2[0]
    v1[1] += v2[1]
    v1[2] += v2[2]

    a[i] = v1

    # Same pattern with subtraction, reading from d and writing to c.
    v3 = wp.vec3()
    v4 = d[i]

    v3[0] -= v4[0]
    v3[1] -= v4[1]
    v3[2] -= v4[2]

    c[i] = v3
1297
+
1298
+
1299
def test_vector_augassign(test, device):
    """Check forward values and gradients of per-component ``+=`` / ``-=`` on vec3 locals.

    ``vector_augassign_kernel`` computes ``a = 0 + b`` and ``c = 0 - d``
    component-wise. With ``b = d = 1`` and unit output gradients, the expected
    results are: ``a = 1``, ``b.grad = 1``, ``c = -1`` and ``d.grad = -1``.
    """
    N = 3

    a = wp.zeros(N, dtype=wp.vec3, requires_grad=True, device=device)
    b = wp.ones(N, dtype=wp.vec3, requires_grad=True, device=device)

    c = wp.zeros(N, dtype=wp.vec3, requires_grad=True, device=device)
    d = wp.ones(N, dtype=wp.vec3, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch(vector_augassign_kernel, N, inputs=[a, b, c, d], device=device)

    tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})

    # One host-side reference array of ones. a, b, c and d all share the same
    # shape and dtype, so a single reference serves every comparison.
    ones = wp.ones_like(a).numpy()

    # Addition path: a = b, so d(a)/d(b) = +1.
    assert_np_equal(a.numpy(), ones)
    assert_np_equal(a.grad.numpy(), ones)
    assert_np_equal(b.grad.numpy(), ones)

    # Subtraction path: c = -d, so d(c)/d(d) = -1.
    assert_np_equal(c.numpy(), -ones)
    assert_np_equal(c.grad.numpy(), ones)
    assert_np_equal(d.grad.numpy(), -ones)
1321
+
1322
+
1323
def test_vector_assign_copy(test, device):
    """Check overwriting one component of a vector that was copied from an array element.

    Runs with ``wp.config.enable_vector_component_overwrites`` temporarily enabled.
    The forward result must carry the overwritten component (3.0). The gradient
    flowing back to ``a`` must be zero for that component and one for the others.
    """
    saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
    try:
        wp.config.enable_vector_component_overwrites = True

        @wp.kernel
        def vec_in_register_overwrite(x: wp.array(dtype=wp.vec3), a: wp.array(dtype=wp.vec3)):
            tid = wp.tid()

            # Initialize, replace with a copy of a[tid], then overwrite component 1.
            # The statement order here is what the test exercises.
            f = wp.vec3(0.0)
            a_vec = a[tid]
            f = a_vec
            f[1] = 3.0

            x[tid] = f

        x = wp.zeros(1, dtype=wp.vec3, device=device, requires_grad=True)
        a = wp.ones(1, dtype=wp.vec3, device=device, requires_grad=True)

        tape = wp.Tape()
        with tape:
            wp.launch(vec_in_register_overwrite, dim=1, inputs=[x, a], device=device)

        tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})

        # Component 1 was overwritten with a constant, so it gets no gradient.
        assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0]], dtype=float))
        assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0]], dtype=float))

    finally:
        # Restore the global config flag so later tests see the original setting.
        wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
1353
+
1354
+
1355
devices = get_test_devices()


class TestVec(unittest.TestCase):
    """Vector-type test case; device/kernel tests are attached to it at module level."""

    def test_tpl_ops_with_anon(self):
        # In-place arithmetic must work when mixing the named wp.vec3i type with an
        # equivalent anonymous vector type, with either one on the left-hand side.
        anon_vec3i = wp.vec(3, dtype=int)

        for lhs_type, rhs_type in ((wp.vec3i, anon_vec3i), (anon_vec3i, wp.vec3i)):
            v = lhs_type(1, 2, 3)
            v += rhs_type(2, 3, 4)
            v -= rhs_type(3, 4, 5)
            self.assertSequenceEqual(v, (0, 1, 2))
1371
+
1372
+
1373
# Kernel tests registered without inputs, each launched with a single thread.
for _kernel in (
    test_vector_constructor_value_func,
    test_constructors_explicit_precision,
    test_constructors_default_precision,
    test_constructors_constant_length,
):
    add_kernel_test(TestVec, _kernel, dim=1, devices=devices)

# Mutation test driven by a single 10-component vector holding 1.0 .. 10.0.
vec10 = wp.types.vector(length=10, dtype=float)
add_kernel_test(
    TestVec,
    test_vector_mutation,
    dim=1,
    inputs=[vec10(*[float(k) for k in range(1, 11)])],
    devices=devices,
)

# Per-dtype tests: each registered name is "<prefix>_<numpy dtype name>".
for dtype in np_unsigned_int_types:
    add_function_test_register_kernel(
        TestVec,
        f"test_subtraction_unsigned_{dtype.__name__}",
        test_subtraction_unsigned,
        devices=devices,
        dtype=dtype,
    )

for dtype in np_signed_int_types + np_float_types:
    for _prefix, _test_func in (
        ("test_negation", test_negation),
        ("test_subtraction", test_subtraction),
    ):
        add_function_test_register_kernel(
            TestVec, f"{_prefix}_{dtype.__name__}", _test_func, devices=devices, dtype=dtype
        )

for dtype in np_float_types:
    for _prefix, _test_func in (
        ("test_crossproduct", test_crossproduct),
        ("test_length", test_length),
        ("test_normalize", test_normalize),
        ("test_casting_constructors", test_casting_constructors),
        ("test_vector_assign_inplace", test_vector_assign_inplace),
    ):
        add_function_test_register_kernel(
            TestVec, f"{_prefix}_{dtype.__name__}", _test_func, devices=devices, dtype=dtype
        )

# Dtype-independent function tests, registered under explicit names.
for _name, _test_func in (
    ("test_anon_constructor_error_length_mismatch", test_anon_constructor_error_length_mismatch),
    ("test_anon_constructor_error_numeric_arg_missing", test_anon_constructor_error_numeric_arg_missing),
    ("test_anon_constructor_error_length_arg_missing", test_anon_constructor_error_length_arg_missing),
    ("test_anon_constructor_error_numeric_args_mismatch", test_anon_constructor_error_numeric_args_mismatch),
    ("test_tpl_constructor_error_incompatible_sizes", test_tpl_constructor_error_incompatible_sizes),
    ("test_tpl_constructor_error_numeric_args_mismatch", test_tpl_constructor_error_numeric_args_mismatch),
    ("test_vector_len", test_vector_len),
    ("test_vector_augassign", test_vector_augassign),
    ("test_vector_assign_copy", test_vector_assign_copy),
):
    add_function_test(TestVec, _name, _test_func, devices=devices)


if __name__ == "__main__":
    # Start from an empty kernel cache so kernels are rebuilt for this run.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)