warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,893 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+
25
# Kernel under test: y[0] = x[0] ** 2, so the expected adjoint is dy/dx = 2 * x[0].
@wp.kernel
def scalar_grad(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # only the first element participates (the test launches with dim=1)
    base = x[0]
    y[0] = base ** 2.0
28
+
29
+
30
def test_scalar_grad(test, device):
    """Check the adjoint of ``x ** 2`` for a single scalar input (x = 3 -> dy/dx = 6)."""
    inp = wp.array([3.0], dtype=float, device=device, requires_grad=True)
    out = wp.zeros_like(inp)

    tape = wp.Tape()
    with tape:
        wp.launch(scalar_grad, dim=1, inputs=[inp, out], device=device)
    tape.backward(out)

    expected = np.array(6.0)
    assert_np_equal(tape.gradients[inp].numpy(), expected)
41
+
42
+
43
# Accumulates 2 * x[k] over the first n entries of x and stores the total in s[0];
# exercises adjoint generation for a for-loop with a runtime trip count.
@wp.kernel
def for_loop_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    total = float(0.0)

    for k in range(n):
        total = total + x[k] * 2.0

    s[0] = total
51
+
52
+
53
def test_for_loop_grad(test, device):
    """Check the forward output and gradient of ``for_loop_grad`` (d s / d x_k = 2)."""
    n = 32
    ones = np.ones(n, dtype=np.float32)

    x = wp.array(ones, device=device, requires_grad=True)
    loss = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_grad, dim=1, inputs=[n, x, loss], device=device)

    def check_forward():
        assert_np_equal(loss.numpy(), 2.0 * np.sum(x.numpy()))

    # the forward result must be correct before the backward pass...
    check_forward()

    tape.backward(loss=loss)

    # ...and must still be intact after it
    check_forward()
    assert_np_equal(tape.gradients[x].numpy(), 2.0 * ones)
73
+
74
+
75
def test_for_loop_graph_grad(test, device):
    """Capture the forward and backward passes of ``for_loop_grad`` in a graph, then replay it."""
    # Load the module before capturing: capture_begin below is called with
    # force_module_load=False.
    wp.load_module(device=device)

    n = 32
    ones = np.ones(n, dtype=np.float32)

    x = wp.array(ones, device=device, requires_grad=True)
    loss = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    wp.capture_begin(device, force_module_load=False)
    try:
        tape = wp.Tape()
        with tape:
            wp.launch(for_loop_grad, dim=1, inputs=[n, x, loss], device=device)
        tape.backward(loss=loss)
    finally:
        # always close the capture, even if recording raised
        graph = wp.capture_end(device)

    wp.capture_launch(graph)
    wp.synchronize_device(device)

    # forward output persists and gradients are correct after one replay
    assert_np_equal(loss.numpy(), 2.0 * np.sum(x.numpy()))
    assert_np_equal(x.grad.numpy(), 2.0 * ones)

    # launch the captured graph a second time
    wp.capture_launch(graph)
    wp.synchronize_device(device)
104
+
105
+
106
# Weighted sum over x where the weight comes from nested branches:
# 2 for k < 8, 4 for 8 <= k < 16, 6 for 16 <= k < 24, 8 otherwise.
# Exercises adjoints of nested if/else inside a for-loop.
@wp.kernel
def for_loop_nested_if_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    total = float(0.0)

    for k in range(n):
        if k < 16:
            if k < 8:
                total = total + x[k] * 2.0
            else:
                total = total + x[k] * 4.0
        elif k < 24:
            total = total + x[k] * 6.0
        else:
            total = total + x[k] * 8.0

    s[0] = total
123
+
124
+
125
def test_for_loop_nested_if_grad(test, device):
    """Check the forward output and per-element gradients of ``for_loop_nested_if_grad``."""
    n = 32
    ones = np.ones(n, dtype=np.float32)

    # With x == 1 everywhere, each element contributes exactly its branch weight,
    # so the per-element values and the gradients are the same array:
    # eight entries each of 2, 4, 6 and 8.
    weights = np.repeat([2.0, 4.0, 6.0, 8.0], 8)
    expected_total = np.sum(weights)

    x = wp.array(ones, device=device, requires_grad=True)
    loss = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_nested_if_grad, dim=1, inputs=[n, x, loss], device=device)

    assert_np_equal(loss.numpy(), expected_total)

    tape.backward(loss=loss)

    assert_np_equal(loss.numpy(), expected_total)
    assert_np_equal(tape.gradients[x].numpy(), weights)
156
+
157
+
158
# Doubly nested loop over the first n * n entries of x, indexed as row * n + col.
# Each iteration adds x[idx] * idx + 1, so d s / d x[idx] = idx.
@wp.kernel
def for_loop_grad_nested(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    total = float(0.0)

    for row in range(n):
        for col in range(n):
            idx = row * n + col
            total = total + x[idx] * float(idx) + 1.0

    s[0] = total
167
+
168
+
169
def test_for_loop_nested_for_grad(test, device):
    """Check ``for_loop_grad_nested`` on a 3 x 3 problem with x initialised to zero."""
    n = 3
    x = wp.zeros(n * n, dtype=float, device=device, requires_grad=True)
    s = wp.zeros(1, dtype=float, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_grad_nested, dim=1, inputs=[n, x, s], device=device)
    tape.backward(s)

    # x is zero, so only the constant +1.0 per iteration survives: 3 * 3 = 9
    assert_np_equal(s.numpy(), np.array([9.0]))
    # the gradient w.r.t. each flattened entry equals its index
    assert_np_equal(tape.gradients[x].numpy(), np.arange(0.0, 9.0, 1.0))
181
+
182
+
183
+ # differentiating through most while loops is not supported
184
+ # since doing things like i = i + 1 breaks adjointing
185
+
186
+ # @wp.kernel
187
+ # def while_loop_grad(n: int,
188
+ # x: wp.array(dtype=float),
189
+ # c: wp.array(dtype=int),
190
+ # s: wp.array(dtype=float)):
191
+
192
+ # tid = wp.tid()
193
+
194
+ # while i < n:
195
+ # s[0] = s[0] + x[i]*2.0
196
+ # i = i + 1
197
+
198
+
199
+ # def test_while_loop_grad(test, device):
200
+
201
+ # n = 32
202
+ # x = wp.array(np.ones(n, dtype=np.float32), device=device, requires_grad=True)
203
+ # c = wp.zeros(1, dtype=int, device=device)
204
+ # sum = wp.zeros(1, dtype=wp.float32, device=device)
205
+
206
+ # tape = wp.Tape()
207
+ # with tape:
208
+ # wp.launch(while_loop_grad, dim=1, inputs=[n, x, c, sum], device=device)
209
+
210
+ # tape.backward(loss=sum)
211
+
212
+ # assert_np_equal(sum.numpy(), 2.0*np.sum(x.numpy()))
213
+ # assert_np_equal(tape.gradients[x].numpy(), 2.0*np.ones_like(x.numpy()))
214
+
215
+
216
# Writes x into one plain store (c) and two atomic accumulations (s1, s2), so the
# test can verify that none of these outputs is altered by the backward pass.
@wp.kernel
def preserve_outputs(
    n: int, x: wp.array(dtype=float), c: wp.array(dtype=float), s1: wp.array(dtype=float), s2: wp.array(dtype=float)
):
    # n is not read here; it is part of the launch signature used by the test
    tid = wp.tid()
    xi = x[tid]

    # plain store
    c[tid] = xi * 2.0

    # atomic stores
    wp.atomic_add(s1, 0, xi * 3.0)
    wp.atomic_sub(s2, 0, xi * 2.0)
228
+
229
+
230
+ # tests that outputs from the forward pass are
231
+ # preserved by the backward pass, i.e.: stores
232
+ # are omitted during the forward replay
233
def test_preserve_outputs_grad(test, device):
    """Check that ``preserve_outputs`` results survive two successive backward passes."""
    n = 32
    ones = np.ones(n, dtype=np.float32)

    x = wp.array(ones, device=device, requires_grad=True)
    c = wp.zeros_like(x)

    s1 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    s2 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(preserve_outputs, dim=n, inputs=[n, x, c, s1, s2], device=device)

    def check_outputs():
        # the input, the plain copy and both atomic sums must equal the forward result
        assert_np_equal(x.numpy(), ones)
        assert_np_equal(c.numpy(), ones * 2.0)
        assert_np_equal(s1.numpy(), np.array(3.0 * n))
        assert_np_equal(s2.numpy(), np.array(-2.0 * n))

    # forward pass results are correct
    check_outputs()

    # backward on the first loss: outputs unchanged, d s1 / d x = 3
    tape.backward(loss=s1)
    check_outputs()
    assert_np_equal(tape.gradients[x].numpy(), 3.0 * ones)

    # reset gradients, then backward on the second loss: d s2 / d x = -2
    tape.zero()
    tape.backward(loss=s2)
    check_outputs()
    assert_np_equal(tape.gradients[x].numpy(), -2.0 * ones)
277
+
278
+
279
def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
    """
    Checks that the gradient of the Warp kernel is correct by comparing it to the
    numerical gradient computed using finite differences.

    Args:
        func: Python function compiled into a kernel with ``wp.Kernel``. It is
            launched with ``inputs`` followed by a single-element float32 output.
        func_name: Key given to the compiled kernel.
        inputs: Warp arrays passed as kernel inputs. As a side effect their
            ``requires_grad`` flag is set to ``True``.
        device: Device to launch on.
        eps: Central-difference step size.
        tol: Tolerance forwarded to ``assert_np_equal`` when comparing gradients.
    """
    kernel = wp.Kernel(func=func, key=func_name)

    def f(xs):
        # call the kernel without taping for finite differences
        wp_xs = [wp.array(x, ndim=1, dtype=arr.dtype, device=device) for x, arr in zip(xs, inputs)]
        output = wp.zeros(1, dtype=wp.float32, device=device)
        wp.launch(kernel, dim=1, inputs=wp_xs, outputs=[output], device=device)
        return output.numpy()[0]

    # flattened host copies of the inputs, perturbed in place below
    np_xs = [arr.numpy().flatten().copy() for arr in inputs]
    numerical_grad = [np.zeros_like(x) for x in np_xs]
    for arr in inputs:
        arr.requires_grad = True

    # compute numerical gradient via central differences
    for xs, num_grad in zip(np_xs, numerical_grad):
        for j in range(len(xs)):
            # Perturb relative to a saved copy and restore it exactly afterwards:
            # the in-place sequence += eps, -= 2*eps, += eps does not round-trip in
            # float32, so it would leave drifted values for later evaluations.
            orig = xs[j]
            xs[j] = orig + eps
            y1 = f(np_xs)
            xs[j] = orig - eps
            y2 = f(np_xs)
            xs[j] = orig
            num_grad[j] = (y1 - y2) / (2 * eps)

    # compute analytical gradient
    tape = wp.Tape()
    output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    with tape:
        wp.launch(kernel, dim=1, inputs=inputs, outputs=[output], device=device)

    tape.backward(loss=output)

    # compare gradients
    for arr, num_grad in zip(inputs, numerical_grad):
        assert_np_equal(tape.gradients[arr].numpy(), num_grad, tol=tol)

    tape.zero()
325
+
326
+
327
def test_vector_math_grad(test, device):
    """Gradient-check length, length_sq and normalize on vector and quaternion types."""
    rng = np.random.default_rng(123)

    # unary operations: one set of kernels per vector-like type
    for dim, vec_type in [(2, wp.vec2), (3, wp.vec3), (4, wp.vec4), (4, wp.quat)]:
        type_name = vec_type.__name__

        def check_length(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length(vs[0])

        def check_length_sq(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length_sq(vs[0])

        def check_normalize(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length_sq(wp.normalize(vs[0]))  # compress to scalar output

        checks = (
            (check_length, f"check_length_{type_name}"),
            (check_length_sq, f"check_length_sq_{type_name}"),
            (check_normalize, f"check_normalize_{type_name}"),
        )

        # five random inputs per type, each checked against every kernel
        for _ in range(5):
            x = wp.array(rng.random(size=(1, dim), dtype=np.float32), dtype=vec_type, device=device)
            for func, key in checks:
                gradcheck(func, key, [x], device)
348
+
349
+
350
def test_matrix_math_grad(test, device):
    """Gradient-check determinant and trace on square matrix types."""
    rng = np.random.default_rng(123)

    # unary operations: one set of kernels per matrix type
    for dim, mat_type in [(2, wp.mat22), (3, wp.mat33), (4, wp.mat44)]:

        def check_determinant(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
            out[0] = wp.determinant(vs[0])

        def check_trace(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
            out[0] = wp.trace(vs[0])

        # run the tests with 5 different random inputs
        for _ in range(5):
            x = wp.array(rng.random(size=(1, dim, dim), dtype=np.float32), ndim=1, dtype=mat_type, device=device)
            # kernel keys name the operation actually being checked, so a failure
            # or a compiled kernel is attributed to determinant/trace
            gradcheck(check_determinant, f"check_determinant_{mat_type.__name__}", [x], device)
            gradcheck(check_trace, f"check_trace_{mat_type.__name__}", [x], device)
367
+
368
+
369
def test_3d_math_grad(test, device):
    """Finite-difference check of binary 3D ops reduced to a scalar.

    Covers cross, dot, a mat33 determinant, a diagonal-matrix trace, and quaternion
    rotations built from RPY angles, axis-angle and a raw (normalized) quaternion.
    Every kernel reads the two vec3 inputs vs[0] and vs[1] and writes out[0].
    """
    rng = np.random.default_rng(123)

    # test binary operations
    def check_cross(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        out[0] = wp.length(wp.cross(vs[0], vs[1]))

    def check_dot(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        out[0] = wp.dot(vs[0], vs[1])

    def check_mat33(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        a = vs[0]
        b = vs[1]
        c = wp.cross(a, b)
        m = wp.mat33(a[0], b[0], c[0], a[1], b[1], c[1], a[2], b[2], c[2])
        out[0] = wp.determinant(m)

    def check_trace_diagonal(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        a = vs[0]
        b = vs[1]
        c = wp.cross(a, b)
        m = wp.mat33(
            1.0 / (a[0] + 10.0),
            0.0,
            0.0,
            0.0,
            1.0 / (b[1] + 10.0),
            0.0,
            0.0,
            0.0,
            1.0 / (c[2] + 10.0),
        )
        out[0] = wp.trace(m)

    def check_rot_rpy(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = vs[0]
        q = wp.quat_rpy(v[0], v[1], v[2])
        out[0] = wp.length(wp.quat_rotate(q, vs[1]))

    def check_rot_axis_angle(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = wp.normalize(vs[0])
        q = wp.quat_from_axis_angle(v, 0.5)
        out[0] = wp.length(wp.quat_rotate(q, vs[1]))

    def check_rot_quat_inv(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = vs[0]
        q = wp.normalize(wp.quat(v[0], v[1], v[2], 1.0))
        out[0] = wp.length(wp.quat_rotate_inv(q, vs[1]))

    # (kernel, kernel name, extra gradcheck kwargs); check_mat33 is run with a larger
    # eps (presumably the finite-difference step — see gradcheck)
    binary_checks = (
        (check_cross, "check_cross_3d", {}),
        (check_dot, "check_dot_3d", {}),
        (check_mat33, "check_mat33_3d", {"eps": 2e-2}),
        (check_trace_diagonal, "check_trace_diagonal_3d", {}),
        (check_rot_rpy, "check_rot_rpy_3d", {}),
        (check_rot_axis_angle, "check_rot_axis_angle_3d", {}),
        (check_rot_quat_inv, "check_rot_quat_inv_3d", {}),
    )

    # five random input pairs; within a trial every check sees the same pair
    for _trial in range(5):
        samples = rng.standard_normal(size=(2, 3), dtype=np.float32)
        x = wp.array(samples, dtype=wp.vec3, device=device, requires_grad=True)
        for kernel_func, kernel_name, extra_kwargs in binary_checks:
            gradcheck(kernel_func, kernel_name, [x], device, **extra_kwargs)
430
+
431
+
432
def test_multi_valued_function_grad(test, device):
    """Check that gradients flow through a wp.func returning a tuple of values."""
    rng = np.random.default_rng(123)

    @wp.func
    def multi_valued(x: float, y: float, z: float):
        return wp.sin(x), wp.cos(y) * z, wp.sqrt(wp.abs(z)) / wp.abs(x)

    # test multi-valued functions
    def check_multi_valued(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        tid = wp.tid()
        v = vs[tid]
        a, b, c = multi_valued(v[0], v[1], v[2])
        out[tid] = a + b + c

    num_trials = 5
    sample_shape = (2, 3)  # two vec3 inputs per trial

    for _trial in range(num_trials):
        samples = rng.standard_normal(size=sample_shape, dtype=np.float32)
        x = wp.array(samples, dtype=wp.vec3, device=device, requires_grad=True)
        gradcheck(check_multi_valued, "check_multi_valued_3d", [x], device)
452
+
453
+
454
def test_mesh_grad(test, device):
    """Compare autodiff gradients of a mesh's total triangle area with central differences.

    The mesh is a tetrahedron (4 vertices, 4 triangles). The kernel reads vertex
    positions through ``wp.mesh_get``, so the adjoint must reach ``mesh.points.grad``.
    """
    pos = wp.array(
        [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=wp.vec3,
        device=device,
        requires_grad=True,
    )
    indices = wp.array(
        [0, 1, 2, 0, 2, 3, 0, 3, 1, 1, 3, 2],
        dtype=wp.int32,
        device=device,
    )

    mesh = wp.Mesh(points=pos, indices=indices)

    @wp.func
    def compute_triangle_area(mesh_id: wp.uint64, tri_id: int):
        mesh = wp.mesh_get(mesh_id)
        i, j, k = mesh.indices[tri_id * 3 + 0], mesh.indices[tri_id * 3 + 1], mesh.indices[tri_id * 3 + 2]
        a = mesh.points[i]
        b = mesh.points[j]
        c = mesh.points[k]
        return wp.length(wp.cross(b - a, c - a)) * 0.5

    @wp.kernel
    def compute_area(mesh_id: wp.uint64, out: wp.array(dtype=wp.float32)):
        wp.atomic_add(out, 0, compute_triangle_area(mesh_id, wp.tid()))

    # three vertex indices per triangle; integer division keeps this exact
    num_tris = len(indices) // 3

    # compute analytical gradient
    tape = wp.Tape()
    output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    with tape:
        wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)

    tape.backward(loss=output)

    ad_grad = mesh.points.grad.numpy()

    def total_area(points_np):
        # Evaluate the total area for a (perturbed) copy of the vertex positions.
        # A fresh mesh is built each time; it stays referenced until the result is read back.
        fd_points = wp.array(points_np, dtype=wp.vec3, device=device)
        fd_mesh = wp.Mesh(points=fd_points, indices=indices)
        output.zero_()
        wp.launch(compute_area, dim=num_tris, inputs=[fd_mesh.id], outputs=[output], device=device)
        return output.numpy()[0]

    # compute central finite differences without rebinding `pos` / `mesh`
    eps = 1e-3
    pos_np = pos.numpy()
    fd_grad = np.zeros_like(ad_grad)

    num_points, num_coords = pos_np.shape
    for i in range(num_points):
        for j in range(num_coords):
            pos_np[i, j] += eps
            f_plus = total_area(pos_np)
            pos_np[i, j] -= 2 * eps
            f_minus = total_area(pos_np)
            pos_np[i, j] += eps  # restore the coordinate
            fd_grad[i, j] = (f_plus - f_minus) / (2 * eps)

    # same tolerances as np.allclose's defaults plus atol, but with a diagnostic message
    # and not stripped under `python -O`
    np.testing.assert_allclose(ad_grad, fd_grad, rtol=1e-5, atol=1e-3)
522
+
523
+
524
# Forward function for the custom-gradient name-clash test: the plain sum of its inputs.
# Its adjoint is replaced by adj_name_clash through @wp.func_grad(name_clash).
@wp.func
def name_clash(a: float, b: float) -> float:
    return a + b
527
+
528
+
529
# Custom adjoint for name_clash. It intentionally differs from the true gradient of
# a + b (1 for both inputs): `a` receives adj_ret only when a < 0, and `b` only when
# b > 0. Presumably this lets test_name_clash confirm the custom adjoint was used
# (its expected gradients are [0, 1, 0] and [1, 1, 0]).
@wp.func_grad(name_clash)
def adj_name_clash(a: float, b: float, adj_ret: float):
    # names `adj_a` and `adj_b` must not clash with function args of generated function
    adj_a = 0.0
    adj_b = 0.0
    if a < 0.0:
        adj_a = adj_ret
    if b > 0.0:
        adj_b = adj_ret

    # accumulate into the adjoints of the original inputs
    wp.adjoint[a] += adj_a
    wp.adjoint[b] += adj_b
541
+
542
+
543
# One thread per element: output[i] = name_clash(input_a[i], input_b[i]).
# Differentiating this kernel routes through the custom adjoint adj_name_clash.
@wp.kernel
def name_clash_kernel(
    input_a: wp.array(dtype=float),
    input_b: wp.array(dtype=float),
    output: wp.array(dtype=float),
):
    tid = wp.tid()
    output[tid] = name_clash(input_a[tid], input_b[tid])
551
+
552
+
553
def test_name_clash(test, device):
    """Custom gradient code may use locals like `adj_a` without clashing with generated names.

    With the custom adjoint, a[i] gets a gradient only where a[i] < 0 and b[i] only
    where b[i] > 0.
    """
    with wp.ScopedDevice(device):
        input_a = wp.array([1.0, -2.0, 3.0], dtype=wp.float32, requires_grad=True)
        input_b = wp.array([4.0, 5.0, -6.0], dtype=wp.float32, requires_grad=True)
        num_elems = len(input_a)
        output = wp.zeros(3, dtype=wp.float32, requires_grad=True)

        tape = wp.Tape()
        with tape:
            wp.launch(name_clash_kernel, dim=num_elems, inputs=[input_a, input_b], outputs=[output])

        # seed d(loss)/d(output) with ones
        output_seed = wp.array(np.ones(num_elems, dtype=np.float32))
        tape.backward(grads={output: output_seed})

        expected_grad_a = np.array([0.0, 1.0, 0.0])
        expected_grad_b = np.array([1.0, 1.0, 0.0])
        assert_np_equal(input_a.grad.numpy(), expected_grad_a)
        assert_np_equal(input_b.grad.numpy(), expected_grad_b)
568
+
569
+
570
# Inner struct (a single vec2 field) nested inside ParentStruct for the
# struct-attribute gradient test.
@wp.struct
class NestedStruct:
    v: wp.vec2
573
+
574
+
575
# Outer struct: a scalar field plus a nested struct, so gradients can be checked
# through both direct (p.a) and nested (p.n.v) attribute access.
@wp.struct
class ParentStruct:
    a: float
    n: NestedStruct
579
+
580
+
581
# Generic function that ignores its argument. Passing struct attributes to it in
# test_struct_attribute_gradient_kernel checks that merely reading an attribute
# does not drop its gradient.
@wp.func
def noop(a: Any):
    pass
584
+
585
+
586
# Sum of the two components of a vec2.
@wp.func
def sum2(v: wp.vec2):
    return v[0] + v[1]
589
+
590
+
591
# Builds p with p.a = x and p.n.v = (2x, 2x) from x = src[tid], then writes
# res[tid] = p.a + sum2(p.n.v) = 5x, so d(res)/d(src) = 5.
@wp.kernel
def test_struct_attribute_gradient_kernel(src: wp.array(dtype=float), res: wp.array(dtype=float)):
    tid = wp.tid()

    p = ParentStruct(src[tid], NestedStruct(wp.vec2(2.0 * src[tid])))

    # test that we are not losing gradients when accessing attributes
    noop(p.a)
    noop(p.n)
    noop(p.n.v)

    res[tid] = p.a + sum2(p.n.v)
603
+
604
+
605
def test_struct_attribute_gradient(test, device):
    """Gradients must survive reads of (nested) struct attributes inside a kernel.

    The kernel computes res = p.a + sum2(p.n.v) with p.a = x and p.n.v = (2x, 2x),
    so d(res)/dx = 1 + 2 + 2 = 5.
    """
    with wp.ScopedDevice(device):
        src = wp.array([1], dtype=float, requires_grad=True)
        res = wp.empty_like(src)

        tape = wp.Tape()
        with tape:
            wp.launch(test_struct_attribute_gradient_kernel, dim=1, inputs=[src, res])

        # seed d(loss)/d(res) = 1, then replay the recorded launch backwards
        res.grad.fill_(1.0)
        tape.backward()

        src_grad = src.grad.numpy()[0]
        test.assertEqual(src_grad, 5.0)
618
+
619
+
620
# Copies a[tid] into b[tid] through two locals. The intermediate `bi = ai` assignment
# looks deliberate — presumably it checks that the adjoint still carries b's gradient
# back to a across a local-to-local copy (test_copy expects a.grad == b.grad).
@wp.kernel
def copy_kernel(a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    ai = a[tid]
    bi = ai
    b[tid] = bi
626
+
627
+
628
def test_copy(test, device):
    """Check the adjoint of a kernel that copies values through an intermediate local.

    The forward and adjoint launches both cover every element, and the forward
    result is verified before the adjoint runs. The gradient of a copy is the identity,
    so a.grad must equal the seeded b.grad.
    """
    with wp.ScopedDevice(device):
        a = wp.array([-1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=wp.float32, requires_grad=True)
        num_elems = a.shape[0]

        # forward pass over the same range the adjoint uses below
        wp.launch(copy_kernel, num_elems, inputs=[a, b])
        assert_np_equal(b.numpy(), a.numpy())

        b.grad = wp.array([1.0, 1.0, 1.0], dtype=wp.float32)
        wp.launch(copy_kernel, num_elems, inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        assert_np_equal(a.grad.numpy(), np.array([1.0, 1.0, 1.0]))
639
+
640
+
641
# `y` starts as an alias of `x` and is reassigned in both branches. The adjoint must
# differentiate the branch actually taken: d/dx = 2x for x > 0, 3x**2 otherwise
# (matching the expected gradients in test_aliasing).
@wp.kernel
def aliasing_kernel(a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    x = a[tid]

    y = x
    if y > 0.0:
        y = x * x
    else:
        y = x * x * x

    b[tid] = y
653
+
654
+
655
def test_aliasing(test, device):
    """Check the adjoint of a kernel where a local first aliases its input, then is reassigned.

    The forward and adjoint launches both cover every element, and the forward
    result is verified. Expected gradients: 3x**2 = 3 for x = -1, and 2x = 4 and 6
    for x = 2 and 3.
    """
    with wp.ScopedDevice(device):
        a = wp.array([-1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=wp.float32, requires_grad=True)
        num_elems = a.shape[0]

        # forward pass over the same range the adjoint uses below:
        # x**3 for x <= 0, x**2 otherwise
        wp.launch(aliasing_kernel, num_elems, inputs=[a, b])
        assert_np_equal(b.numpy(), np.array([-1.0, 4.0, 9.0]))

        b.grad = wp.array([1.0, 1.0, 1.0], dtype=wp.float32)
        wp.launch(aliasing_kernel, num_elems, inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        assert_np_equal(a.grad.numpy(), np.array([3.0, 4.0, 6.0]))
666
+
667
+
668
# Element-wise square: y[i] = x[i] ** 2 (so d(y[i])/d(x[i]) = 2 * x[i]).
@wp.kernel
def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = x[tid] ** 2.0
672
+
673
+
674
# Squares row `row_idx` of x into the same row of y, indexing through row slices
# taken inside the kernel (x[row_idx] / y[row_idx]); one thread per column.
@wp.kernel
def square_slice_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float), row_idx: int):
    tid = wp.tid()
    x_slice = x[row_idx]
    y_slice = y[row_idx]
    y_slice[tid] = x_slice[tid] ** 2.0
680
+
681
+
682
# Squares the 2D slab x[slice_idx] into y[slice_idx] via slices taken inside the
# kernel (one index into a 3D array); launched over a 2D grid (i, j).
@wp.kernel
def square_slice_3d_1d_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), slice_idx: int):
    i, j = wp.tid()
    x_slice = x[slice_idx]
    y_slice = y[slice_idx]
    y_slice[i, j] = x_slice[i, j] ** 2.0
688
+
689
+
690
# Squares the 1D row x[slice_i, slice_j] into y[slice_i, slice_j] via slices taken
# inside the kernel (two indices into a 3D array); one thread per element of that row.
@wp.kernel
def square_slice_3d_2d_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), slice_i: int, slice_j: int):
    tid = wp.tid()
    x_slice = x[slice_i, slice_j]
    y_slice = y[slice_i, slice_j]
    y_slice[tid] = x_slice[tid] ** 2.0
696
+
697
+
698
def test_gradient_internal(test, device):
    """An adjoint launch with adj_inputs entries of None uses the arrays' own .grad buffers."""
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)
        launch_dim = a.size

        wp.launch(square_kernel, dim=launch_dim, inputs=[a, b])

        # seed the output gradient on b itself; no external gradient buffers are passed
        b.grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, dim=launch_dim, inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        # d(x**2)/dx = 2x
        expected = 2.0 * np.array([1.0, 2.0, 3.0])
        assert_np_equal(a.grad.numpy(), expected)
710
+
711
+
712
def test_gradient_external(test, device):
    """An adjoint launch writes into gradient buffers passed explicitly through adj_inputs."""
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=False)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=False)
        launch_dim = a.size

        wp.launch(square_kernel, dim=launch_dim, inputs=[a, b])

        # both arrays are created with requires_grad=False, so both adjoints are supplied externally
        a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
        b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, dim=launch_dim, inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])

        # d(x**2)/dx = 2x
        expected = 2.0 * np.array([1.0, 2.0, 3.0])
        assert_np_equal(a_grad.numpy(), expected)
725
+
726
+
727
def test_gradient_precedence(test, device):
    """Explicit adj_inputs take precedence over the arrays' internal .grad buffers.

    The user passed them explicitly, so the adjoint must write there and leave
    a.grad untouched.
    """
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)
        launch_dim = a.size

        wp.launch(square_kernel, dim=launch_dim, inputs=[a, b])

        a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
        b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, dim=launch_dim, inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])

        used_expected = np.array([2.0, 4.0, 6.0])
        unused_expected = np.array([0.0, 0.0, 0.0])
        assert_np_equal(a_grad.numpy(), used_expected)  # external buffer received the gradient
        assert_np_equal(a.grad.numpy(), unused_expected)  # internal buffer left untouched
742
+
743
+
744
def test_gradient_slice_2d(test, device):
    """Gradients flow through a row slice taken inside the kernel.

    Only row 1 is read and written, so only that row of a.grad becomes 2 * a;
    every other row stays zero.
    """
    row_idx = 1
    with wp.ScopedDevice(device):
        a = wp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=float, requires_grad=True)
        b = wp.zeros_like(a, requires_grad=False)
        b.grad = wp.ones_like(a, requires_grad=False)

        num_cols = a.shape[1]
        wp.launch(square_slice_2d_kernel, dim=num_cols, inputs=[a, b, row_idx])

        # use internal gradients (.grad), adj_inputs are None
        wp.launch(
            square_slice_2d_kernel,
            dim=num_cols,
            inputs=[a, b, row_idx],
            adjoint=True,
            adj_inputs=[None, None, row_idx],
        )

        expected = np.array([[0.0, 0.0], [6.0, 8.0], [0.0, 0.0]])
        assert_np_equal(a.grad.numpy(), expected)
756
+
757
+
758
def test_gradient_slice_3d_1d(test, device):
    """Gradients flow through a 2D slab x[1] sliced out of a 3D array inside the kernel.

    Only slab 1 is touched, so a.grad equals 2 * a on slab 1 and is zero elsewhere.
    """
    slab = 1
    with wp.ScopedDevice(device):
        # data[k][r][c] = 10*k + 3*r + c + 1, i.e. [[1..9], [11..19], [21..29]] as 3x3 slabs
        data = [[[10 * k + 3 * r + c + 1 for c in range(3)] for r in range(3)] for k in range(3)]
        a = wp.array(data, dtype=float, requires_grad=True)
        b = wp.zeros_like(a, requires_grad=False)
        b.grad = wp.ones_like(a, requires_grad=False)

        grid = a.shape[1:]
        wp.launch(square_slice_3d_1d_kernel, dim=grid, inputs=[a, b, slab])

        # use internal gradients (.grad), adj_inputs are None
        wp.launch(
            square_slice_3d_1d_kernel,
            dim=grid,
            inputs=[a, b, slab],
            adjoint=True,
            adj_inputs=[None, None, slab],
        )

        data_np = np.array(data)
        expected_grad = np.zeros_like(data_np)
        expected_grad[slab] = 2 * data_np[slab]
        assert_np_equal(a.grad.numpy(), expected_grad)
806
+
807
+
808
def test_gradient_slice_3d_2d(test, device):
    """Gradients flow through a 1D row x[1, 1] sliced out of a 3D array inside the kernel.

    Only that row is touched, so a.grad equals 2 * a there and is zero elsewhere.
    """
    slice_i, slice_j = 1, 1
    with wp.ScopedDevice(device):
        # data[k][r][c] = 10*k + 3*r + c + 1, i.e. [[1..9], [11..19], [21..29]] as 3x3 slabs
        data = [[[10 * k + 3 * r + c + 1 for c in range(3)] for r in range(3)] for k in range(3)]
        a = wp.array(data, dtype=float, requires_grad=True)
        b = wp.zeros_like(a, requires_grad=False)
        b.grad = wp.ones_like(a, requires_grad=False)

        row_len = a.shape[2]
        wp.launch(square_slice_3d_2d_kernel, dim=row_len, inputs=[a, b, slice_i, slice_j])

        # use internal gradients (.grad), adj_inputs are None
        wp.launch(
            square_slice_3d_2d_kernel,
            dim=row_len,
            inputs=[a, b, slice_i, slice_j],
            adjoint=True,
            adj_inputs=[None, None, slice_i, slice_j],
        )

        data_np = np.array(data)
        expected_grad = np.zeros_like(data_np)
        expected_grad[slice_i, slice_j] = 2 * data_np[slice_i, slice_j]
        assert_np_equal(a.grad.numpy(), expected_grad)
856
+
857
+
858
# Devices the tests below are registered for (test_for_loop_graph_grad instead uses
# get_selected_cuda_test_devices()).
devices = get_test_devices()
859
+
860
+
861
class TestGrad(unittest.TestCase):
    """Container TestCase; per-device test methods are attached via add_function_test."""
863
+
864
+
865
# Attach each module-level test function to TestGrad, once per device.
# NOTE(review): test_while_loop_grad is disabled here; the reason is not recorded
# in this chunk — confirm before re-enabling.
# add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
add_function_test(TestGrad, "test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices=devices)
add_function_test(TestGrad, "test_scalar_grad", test_scalar_grad, devices=devices)
add_function_test(TestGrad, "test_for_loop_grad", test_for_loop_grad, devices=devices)
# registered only on the selected CUDA devices
add_function_test(
    TestGrad, "test_for_loop_graph_grad", test_for_loop_graph_grad, devices=get_selected_cuda_test_devices()
)
add_function_test(TestGrad, "test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices=devices)
add_function_test(TestGrad, "test_preserve_outputs_grad", test_preserve_outputs_grad, devices=devices)
add_function_test(TestGrad, "test_vector_math_grad", test_vector_math_grad, devices=devices)
add_function_test(TestGrad, "test_matrix_math_grad", test_matrix_math_grad, devices=devices)
add_function_test(TestGrad, "test_3d_math_grad", test_3d_math_grad, devices=devices)
add_function_test(TestGrad, "test_multi_valued_function_grad", test_multi_valued_function_grad, devices=devices)
add_function_test(TestGrad, "test_mesh_grad", test_mesh_grad, devices=devices)
add_function_test(TestGrad, "test_name_clash", test_name_clash, devices=devices)
add_function_test(TestGrad, "test_struct_attribute_gradient", test_struct_attribute_gradient, devices=devices)
add_function_test(TestGrad, "test_copy", test_copy, devices=devices)
add_function_test(TestGrad, "test_aliasing", test_aliasing, devices=devices)
add_function_test(TestGrad, "test_gradient_internal", test_gradient_internal, devices=devices)
add_function_test(TestGrad, "test_gradient_external", test_gradient_external, devices=devices)
add_function_test(TestGrad, "test_gradient_precedence", test_gradient_precedence, devices=devices)
add_function_test(TestGrad, "test_gradient_slice_2d", test_gradient_slice_2d, devices=devices)
add_function_test(TestGrad, "test_gradient_slice_3d_1d", test_gradient_slice_3d_1d, devices=devices)
add_function_test(TestGrad, "test_gradient_slice_3d_2d", test_gradient_slice_3d_2d, devices=devices)
889
+
890
+
891
if __name__ == "__main__":
    # presumably forces kernels to be rebuilt rather than loaded from a stale cache
    wp.clear_kernel_cache()
    unittest.main(failfast=False, verbosity=2)