warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468) hide show
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,893 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+
@wp.kernel
def scalar_grad(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # Square the first element of ``x`` and store the result in ``y[0]``.
    base = x[0]
    y[0] = base ** 2.0
28
+
29
+
def test_scalar_grad(test, device):
    """Differentiate ``scalar_grad`` at x = 3 and expect a gradient of 6.

    Args:
        test: Test-case object supplied by the caller; not used in the body.
        device: Device on which the arrays are created and the kernel is launched.
    """
    source = wp.array([3.0], dtype=float, device=device, requires_grad=True)
    squared = wp.zeros_like(source)

    recorder = wp.Tape()
    with recorder:
        wp.launch(scalar_grad, dim=1, inputs=[source, squared], device=device)

    # Backpropagate using the kernel output itself as the loss.
    recorder.backward(squared)

    source_grad = recorder.gradients[source].numpy()
    assert_np_equal(source_grad, np.array(6.0))
41
+
42
+
@wp.kernel
def for_loop_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    # Accumulate 2 * x[i] over the first ``n`` entries and write the total to s[0].
    total = float(0.0)

    for i in range(n):
        doubled = x[i] * 2.0
        total = total + doubled

    s[0] = total
51
+
52
+
def test_for_loop_grad(test, device):
    """Check forward output and input gradient of the ``for_loop_grad`` kernel.

    The forward result is verified both before and after the backward pass so
    that a backward pass overwriting the output would be detected.

    Args:
        test: Test-case object supplied by the caller; not used in the body.
        device: Device on which the arrays are created and the kernel is launched.
    """
    count = 32
    ones = np.ones(count, dtype=np.float32)

    x = wp.array(ones, device=device, requires_grad=True)
    total = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    def check_forward_output():
        # The kernel sums 2 * x[i], so the output must equal twice the input sum.
        # Both arrays are re-read on every call on purpose.
        assert_np_equal(total.numpy(), 2.0 * np.sum(x.numpy()))

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_grad, dim=1, inputs=[count, x, total], device=device)

    # forward pass output is correct
    check_forward_output()

    tape.backward(loss=total)

    # forward pass output persists after the backward pass
    check_forward_output()
    # every input contributes with a factor of two
    assert_np_equal(tape.gradients[x].numpy(), 2.0 * ones)
73
+
74
+
def test_for_loop_graph_grad(test, device):
    """Record forward and backward passes of ``for_loop_grad`` into a captured graph.

    The graph is launched once and the results are checked, then it is launched
    a second time without further assertions.

    Args:
        test: Test-case object supplied by the caller; not used in the body.
        device: Device used for capture, launch and synchronization.
    """
    # Load up front because the capture below is started with
    # force_module_load=False.
    wp.load_module(device=device)

    count = 32
    ones = np.ones(count, dtype=np.float32)

    x = wp.array(ones, device=device, requires_grad=True)
    total = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    wp.capture_begin(device, force_module_load=False)
    try:
        tape = wp.Tape()
        with tape:
            wp.launch(for_loop_grad, dim=1, inputs=[count, x, total], device=device)

        tape.backward(loss=total)
    finally:
        # Always end the capture, even if recording raised.
        graph = wp.capture_end(device)

    # first replay
    wp.capture_launch(graph)
    wp.synchronize_device(device)

    # forward pass output persists
    expected_total = 2.0 * np.sum(x.numpy())
    assert_np_equal(total.numpy(), expected_total)
    # gradients are correct
    assert_np_equal(x.grad.numpy(), 2.0 * ones)

    # second replay
    wp.capture_launch(graph)
    wp.synchronize_device(device)
104
+
105
+
@wp.kernel
def for_loop_nested_if_grad(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    # Weighted sum of x[0..n): the weight is 2, 4, 6 or 8 depending on which
    # block of eight indices ``i`` falls into. The nested if/else layout is
    # kept as-is rather than flattened into an elif chain.
    acc = float(0.0)

    for i in range(n):
        if i < 16:
            # indices 0..15
            if i < 8:
                acc = acc + x[i] * 2.0
            else:
                acc = acc + x[i] * 4.0
        else:
            # indices 16 and above
            if i < 24:
                acc = acc + x[i] * 6.0
            else:
                acc = acc + x[i] * 8.0

    s[0] = acc
123
+
124
+
def test_for_loop_nested_if_grad(test, device):
    """Check forward output and gradients of ``for_loop_nested_if_grad``.

    With an all-ones input of length 32 the kernel output equals the sum of the
    per-index weights, and the gradient with respect to each input equals its
    weight.

    Args:
        test: Test-case object supplied by the caller; not used in the body.
        device: Device on which the arrays are created and the kernel is launched.
    """
    count = 32
    ones = np.ones(count, dtype=np.float32)

    # Eight entries each of 2.0, 4.0, 6.0 and 8.0, in that order.
    expected_val = np.repeat([2.0, 4.0, 6.0, 8.0], 8)
    expected_grad = expected_val.copy()

    x = wp.array(ones, device=device, requires_grad=True)
    total = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_nested_if_grad, dim=1, inputs=[count, x, total], device=device)

    # forward output before the backward pass
    assert_np_equal(total.numpy(), np.sum(expected_val))

    tape.backward(loss=total)

    # forward output is unchanged by the backward pass
    assert_np_equal(total.numpy(), np.sum(expected_val))
    assert_np_equal(tape.gradients[x].numpy(), expected_grad)
156
+
157
+
@wp.kernel
def for_loop_grad_nested(n: int, x: wp.array(dtype=float), s: wp.array(dtype=float)):
    # Walk an n-by-n index space; each flattened index k adds x[k] * k + 1
    # to the running total, which is finally written to s[0].
    total = float(0.0)

    for i in range(n):
        for j in range(n):
            k = i * n + j
            total = total + x[k] * float(k) + 1.0

    s[0] = total
167
+
168
+
def test_for_loop_nested_for_grad(test, device):
    """Check output and gradients of ``for_loop_grad_nested`` for n = 3.

    The input is all zeros, so only the constant ``+ 1.0`` of each of the nine
    iterations reaches the output; the gradient of entry k is k.

    Args:
        test: Test-case object supplied by the caller; not used in the body.
        device: Device on which the arrays are created and the kernel is launched.
    """
    values = wp.zeros(9, dtype=float, device=device, requires_grad=True)
    result = wp.zeros(1, dtype=float, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(for_loop_grad_nested, dim=1, inputs=[3, values, result], device=device)

    tape.backward(result)

    assert_np_equal(result.numpy(), np.array([9.0]))

    expected_grad = np.arange(0.0, 9.0, 1.0)
    assert_np_equal(tape.gradients[values].numpy(), expected_grad)
181
+
182
+
183
+ # differentiating through most while loops is not supported
184
+ # since doing things like i = i + 1 breaks adjointing
185
+
186
+ # @wp.kernel
187
+ # def while_loop_grad(n: int,
188
+ # x: wp.array(dtype=float),
189
+ # c: wp.array(dtype=int),
190
+ # s: wp.array(dtype=float)):
191
+
192
+ # tid = wp.tid()
193
+
194
+ # while i < n:
195
+ # s[0] = s[0] + x[i]*2.0
196
+ # i = i + 1
197
+
198
+
199
+ # def test_while_loop_grad(test, device):
200
+
201
+ # n = 32
202
+ # x = wp.array(np.ones(n, dtype=np.float32), device=device, requires_grad=True)
203
+ # c = wp.zeros(1, dtype=int, device=device)
204
+ # sum = wp.zeros(1, dtype=wp.float32, device=device)
205
+
206
+ # tape = wp.Tape()
207
+ # with tape:
208
+ # wp.launch(while_loop_grad, dim=1, inputs=[n, x, c, sum], device=device)
209
+
210
+ # tape.backward(loss=sum)
211
+
212
+ # assert_np_equal(sum.numpy(), 2.0*np.sum(x.numpy()))
213
+ # assert_np_equal(tape.gradients[x].numpy(), 2.0*np.ones_like(x.numpy()))
214
+
215
+
@wp.kernel
def preserve_outputs(
    n: int, x: wp.array(dtype=float), c: wp.array(dtype=float), s1: wp.array(dtype=float), s2: wp.array(dtype=float)
):
    # Each thread writes one plain store and two atomic updates derived from
    # its own element of ``x``. ``n`` is accepted but not read in the body.
    tid = wp.tid()
    xi = x[tid]

    # plain store
    c[tid] = xi * 2.0

    # atomic stores
    wp.atomic_add(s1, 0, xi * 3.0)
    wp.atomic_sub(s2, 0, xi * 2.0)
228
+
229
+
230
+ # tests that outputs from the forward pass are
231
+ # preserved by the backward pass, i.e.: stores
232
+ # are omitted during the forward reply
233
+ def test_preserve_outputs_grad(test, device):
234
+ n = 32
235
+
236
+ val = np.ones(n, dtype=np.float32)
237
+
238
+ x = wp.array(val, device=device, requires_grad=True)
239
+ c = wp.zeros_like(x)
240
+
241
+ s1 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
242
+ s2 = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
243
+
244
+ tape = wp.Tape()
245
+ with tape:
246
+ wp.launch(preserve_outputs, dim=n, inputs=[n, x, c, s1, s2], device=device)
247
+
248
+ # ensure forward pass results are correct
249
+ assert_np_equal(x.numpy(), val)
250
+ assert_np_equal(c.numpy(), val * 2.0)
251
+ assert_np_equal(s1.numpy(), np.array(3.0 * n))
252
+ assert_np_equal(s2.numpy(), np.array(-2.0 * n))
253
+
254
+ # run backward on first loss
255
+ tape.backward(loss=s1)
256
+
257
+ # ensure inputs, copy and sum are unchanged by backwards pass
258
+ assert_np_equal(x.numpy(), val)
259
+ assert_np_equal(c.numpy(), val * 2.0)
260
+ assert_np_equal(s1.numpy(), np.array(3.0 * n))
261
+ assert_np_equal(s2.numpy(), np.array(-2.0 * n))
262
+
263
+ # ensure gradients are correct
264
+ assert_np_equal(tape.gradients[x].numpy(), 3.0 * val)
265
+
266
+ # run backward on second loss
267
+ tape.zero()
268
+ tape.backward(loss=s2)
269
+
270
+ assert_np_equal(x.numpy(), val)
271
+ assert_np_equal(c.numpy(), val * 2.0)
272
+ assert_np_equal(s1.numpy(), np.array(3.0 * n))
273
+ assert_np_equal(s2.numpy(), np.array(-2.0 * n))
274
+
275
+ # ensure gradients are correct
276
+ assert_np_equal(tape.gradients[x].numpy(), -2.0 * val)
277
+
278
+
def gradcheck(func, func_name, inputs, device, eps=1e-4, tol=1e-2):
    """Compare a kernel's tape gradient against central finite differences.

    ``func`` is wrapped in a ``wp.Kernel`` that is launched with ``dim=1`` and
    a single one-element float32 output array.

    Args:
        func: Python function to build the kernel from.
        func_name: Key passed to ``wp.Kernel``.
        inputs: Sequence of Warp arrays used as kernel inputs. Each one has
            ``requires_grad`` set to ``True`` as a side effect.
        device: Device used for all arrays and launches.
        eps: Perturbation applied to every flattened input entry.
        tol: Tolerance forwarded to ``assert_np_equal``.
    """

    kernel = wp.Kernel(func=func, key=func_name)

    def evaluate(flat_values):
        # Untaped forward launch on fresh arrays built from the perturbed data.
        fresh_inputs = [
            wp.array(values, ndim=1, dtype=source.dtype, device=device)
            for values, source in zip(flat_values, inputs)
        ]
        result = wp.zeros(1, dtype=wp.float32, device=device)
        wp.launch(kernel, dim=1, inputs=fresh_inputs, outputs=[result], device=device)
        return result.numpy()[0]

    # host-side flattened copies of the inputs plus matching gradient buffers
    flat_inputs = []
    numerical_grad = []
    for source in inputs:
        flat = source.numpy().flatten().copy()
        flat_inputs.append(flat)
        numerical_grad.append(np.zeros_like(flat))
        source.requires_grad = True

    # numerical gradient: perturb one entry at a time in place, then restore it
    for flat, grad_buffer in zip(flat_inputs, numerical_grad):
        for j in range(len(flat)):
            flat[j] += eps
            y_plus = evaluate(flat_inputs)
            flat[j] -= 2 * eps
            y_minus = evaluate(flat_inputs)
            flat[j] += eps
            grad_buffer[j] = (y_plus - y_minus) / (2 * eps)

    # analytical gradient from the tape
    tape = wp.Tape()
    output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    with tape:
        wp.launch(kernel, dim=1, inputs=inputs, outputs=[output], device=device)

    tape.backward(loss=output)

    # compare gradients input by input
    for source, expected in zip(inputs, numerical_grad):
        analytical = tape.gradients[source]
        assert_np_equal(analytical.numpy(), expected, tol=tol)

    tape.zero()
325
+
326
+
327
def test_vector_math_grad(test, device):
    """Gradient-check ``length``, ``length_sq`` and ``normalize`` for vec2/vec3/vec4/quat."""
    rng = np.random.default_rng(123)
    num_trials = 5

    # test unary operations
    for dim, vec_type in ((2, wp.vec2), (3, wp.vec3), (4, wp.vec4), (4, wp.quat)):

        def check_length(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length(vs[0])

        def check_length_sq(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            out[0] = wp.length_sq(vs[0])

        def check_normalize(vs: wp.array(dtype=vec_type), out: wp.array(dtype=float)):
            # reduce the normalized vector to a scalar so the output stays one element
            out[0] = wp.length_sq(wp.normalize(vs[0]))

        type_name = vec_type.__name__
        checks = (
            (check_length, f"check_length_{type_name}"),
            (check_length_sq, f"check_length_sq_{type_name}"),
            (check_normalize, f"check_normalize_{type_name}"),
        )

        # every trial draws a fresh random input and runs all three checks on it
        for _ in range(num_trials):
            x = wp.array(rng.random(size=(1, dim), dtype=np.float32), dtype=vec_type, device=device)
            for kernel_func, key in checks:
                gradcheck(kernel_func, key, [x], device)
348
+
349
+
350
def test_matrix_math_grad(test, device):
    """Gradient-check ``determinant`` and ``trace`` for mat22, mat33 and mat44.

    Each kernel is keyed after the operation it exercises
    (``check_determinant_<type>`` / ``check_trace_<type>``) so that the key
    handed to ``gradcheck`` matches the kernel being checked.
    """
    rng = np.random.default_rng(123)

    # test unary operations
    for dim, mat_type in [(2, wp.mat22), (3, wp.mat33), (4, wp.mat44)]:

        def check_determinant(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
            out[0] = wp.determinant(vs[0])

        def check_trace(vs: wp.array(dtype=mat_type), out: wp.array(dtype=float)):
            out[0] = wp.trace(vs[0])

        # run the tests with 5 different random inputs
        for _ in range(5):
            x = wp.array(rng.random(size=(1, dim, dim), dtype=np.float32), ndim=1, dtype=mat_type, device=device)
            # keys previously reused the vector tests' "check_length*" prefixes
            gradcheck(check_determinant, f"check_determinant_{mat_type.__name__}", [x], device)
            gradcheck(check_trace, f"check_trace_{mat_type.__name__}", [x], device)
367
+
368
+
369
def test_3d_math_grad(test, device):
    """Gradient-check cross/dot products, mat33 construction and quaternion rotations.

    Every kernel reads two ``wp.vec3`` values (``vs[0]`` and ``vs[1]``) and writes one scalar.
    """
    rng = np.random.default_rng(123)

    # test binary operations
    def check_cross(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        out[0] = wp.length(wp.cross(vs[0], vs[1]))

    def check_dot(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        out[0] = wp.dot(vs[0], vs[1])

    def check_mat33(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        a = vs[0]
        b = vs[1]
        c = wp.cross(a, b)
        # components are passed as (a[i], b[i], c[i]) triples
        m = wp.mat33(
            a[0],
            b[0],
            c[0],
            a[1],
            b[1],
            c[1],
            a[2],
            b[2],
            c[2],
        )
        out[0] = wp.determinant(m)

    def check_trace_diagonal(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        a = vs[0]
        b = vs[1]
        c = wp.cross(a, b)
        # only the diagonal entries are non-zero
        m = wp.mat33(
            1.0 / (a[0] + 10.0),
            0.0,
            0.0,
            0.0,
            1.0 / (b[1] + 10.0),
            0.0,
            0.0,
            0.0,
            1.0 / (c[2] + 10.0),
        )
        out[0] = wp.trace(m)

    def check_rot_rpy(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = vs[0]
        q = wp.quat_rpy(v[0], v[1], v[2])
        out[0] = wp.length(wp.quat_rotate(q, vs[1]))

    def check_rot_axis_angle(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = wp.normalize(vs[0])
        q = wp.quat_from_axis_angle(v, 0.5)
        out[0] = wp.length(wp.quat_rotate(q, vs[1]))

    def check_rot_quat_inv(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        v = vs[0]
        q = wp.normalize(wp.quat(v[0], v[1], v[2], 1.0))
        out[0] = wp.length(wp.quat_rotate_inv(q, vs[1]))

    # (kernel function, kernel key, extra gradcheck keyword arguments)
    checks = (
        (check_cross, "check_cross_3d", {}),
        (check_dot, "check_dot_3d", {}),
        (check_mat33, "check_mat33_3d", {"eps": 2e-2}),
        (check_trace_diagonal, "check_trace_diagonal_3d", {}),
        (check_rot_rpy, "check_rot_rpy_3d", {}),
        (check_rot_axis_angle, "check_rot_axis_angle_3d", {}),
        (check_rot_quat_inv, "check_rot_quat_inv_3d", {}),
    )

    # five random inputs, each run through every check in order
    for _ in range(5):
        x = wp.array(
            rng.standard_normal(size=(2, 3), dtype=np.float32),
            dtype=wp.vec3,
            device=device,
            requires_grad=True,
        )
        for kernel_func, key, overrides in checks:
            gradcheck(kernel_func, key, [x], device, **overrides)
430
+
431
+
432
def test_multi_valued_function_grad(test, device):
    """Gradient-check a kernel that calls a ``wp.func`` returning three values."""
    rng = np.random.default_rng(123)

    @wp.func
    def multi_valued(x: float, y: float, z: float):
        return wp.sin(x), wp.cos(y) * z, wp.sqrt(wp.abs(z)) / wp.abs(x)

    # test multi-valued functions
    def check_multi_valued(vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        tid = wp.tid()
        v = vs[tid]
        # unpack the three returned values and fold them into one scalar
        a, b, c = multi_valued(v[0], v[1], v[2])
        out[tid] = a + b + c

    # five independent random draws
    num_trials = 5
    for _ in range(num_trials):
        samples = rng.standard_normal(size=(2, 3), dtype=np.float32)
        x = wp.array(samples, dtype=wp.vec3, device=device, requires_grad=True)
        gradcheck(check_multi_valued, "check_multi_valued_3d", [x], device)
452
+
453
+
454
def test_mesh_grad(test, device):
    """Compare the tape gradient of a mesh's summed triangle area with finite differences.

    The mesh has four vertices and four triangles; the loss is the sum of all
    triangle areas accumulated with ``wp.atomic_add``.
    """
    vertices = [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
    pos = wp.array(vertices, dtype=wp.vec3, device=device, requires_grad=True)
    indices = wp.array(
        [0, 1, 2, 0, 2, 3, 0, 3, 1, 1, 3, 2],
        dtype=wp.int32,
        device=device,
    )

    mesh = wp.Mesh(points=pos, indices=indices)

    @wp.func
    def compute_triangle_area(mesh_id: wp.uint64, tri_id: int):
        mesh = wp.mesh_get(mesh_id)
        i, j, k = mesh.indices[tri_id * 3 + 0], mesh.indices[tri_id * 3 + 1], mesh.indices[tri_id * 3 + 2]
        a = mesh.points[i]
        b = mesh.points[j]
        c = mesh.points[k]
        return wp.length(wp.cross(b - a, c - a)) * 0.5

    @wp.kernel
    def compute_area(mesh_id: wp.uint64, out: wp.array(dtype=wp.float32)):
        # one thread per triangle, all accumulating into out[0]
        wp.atomic_add(out, 0, compute_triangle_area(mesh_id, wp.tid()))

    num_tris = len(indices) // 3

    # compute analytical gradient
    tape = wp.Tape()
    output = wp.zeros(1, dtype=wp.float32, device=device, requires_grad=True)
    with tape:
        wp.launch(compute_area, dim=num_tris, inputs=[mesh.id], outputs=[output], device=device)

    tape.backward(loss=output)

    ad_grad = mesh.points.grad.numpy()

    def total_area(points_np):
        # rebuild the mesh from the perturbed host points and re-run the forward kernel
        perturbed = wp.array(points_np, dtype=wp.vec3, device=device)
        perturbed_mesh = wp.Mesh(points=perturbed, indices=indices)
        output.zero_()
        wp.launch(compute_area, dim=num_tris, inputs=[perturbed_mesh.id], outputs=[output], device=device)
        return output.numpy()[0]

    # compute finite differences
    eps = 1e-3
    pos_np = pos.numpy()
    fd_grad = np.zeros_like(ad_grad)

    for i, j in np.ndindex(len(pos), 3):
        pos_np[i, j] += eps
        f1 = total_area(pos_np)
        pos_np[i, j] -= 2 * eps
        f2 = total_area(pos_np)
        pos_np[i, j] += eps  # restore the component before moving on
        fd_grad[i, j] = (f1 - f2) / (2 * eps)

    np.testing.assert_allclose(ad_grad, fd_grad, atol=1e-3)
522
+
523
+
524
# Forward function whose gradient is replaced by the custom adjoint registered
# with @wp.func_grad(name_clash).
@wp.func
def name_clash(a: float, b: float) -> float:
    return a + b
527
+
528
+
529
@wp.func_grad(name_clash)
def adj_name_clash(a: float, b: float, adj_ret: float):
    # The locals are deliberately called `adj_a` and `adj_b`: they must not
    # clash with the arguments of the generated function.
    adj_a = 0.0
    adj_b = 0.0
    # pass the incoming adjoint to `a` only when a is negative
    if a < 0.0:
        adj_a = adj_ret
    # pass the incoming adjoint to `b` only when b is positive
    if b > 0.0:
        adj_b = adj_ret

    wp.adjoint[a] += adj_a
    wp.adjoint[b] += adj_b
541
+
542
+
543
@wp.kernel
def name_clash_kernel(
    input_a: wp.array(dtype=float),
    input_b: wp.array(dtype=float),
    output: wp.array(dtype=float),
):
    # element-wise name_clash over the two input arrays
    i = wp.tid()
    output[i] = name_clash(input_a[i], input_b[i])
551
+
552
+
553
def test_name_clash(test, device):
    """Check that custom gradient code may use variable names such as ``adj_a`` without clashes."""
    with wp.ScopedDevice(device):
        input_a = wp.array([1.0, -2.0, 3.0], dtype=wp.float32, requires_grad=True)
        input_b = wp.array([4.0, 5.0, -6.0], dtype=wp.float32, requires_grad=True)
        output = wp.zeros(3, dtype=wp.float32, requires_grad=True)
        count = len(input_a)

        tape = wp.Tape()
        with tape:
            wp.launch(name_clash_kernel, dim=count, inputs=[input_a, input_b], outputs=[output])

        # seed the backward pass with ones for every output element
        seed = wp.array(np.ones(count, dtype=np.float32))
        tape.backward(grads={output: seed})

        # input_a receives gradient only where it is negative, input_b only where it is positive
        expected_a = np.array([0.0, 1.0, 0.0])
        expected_b = np.array([1.0, 1.0, 0.0])
        assert_np_equal(input_a.grad.numpy(), expected_a)
        assert_np_equal(input_b.grad.numpy(), expected_b)
568
+
569
+
570
# Inner struct stored in ParentStruct.n; holds a single 2D vector.
@wp.struct
class NestedStruct:
    v: wp.vec2
573
+
574
+
575
# Outer struct combining a scalar with a nested struct.
@wp.struct
class ParentStruct:
    a: float  # scalar member
    n: NestedStruct  # nested struct member
579
+
580
+
581
# Generic function that accepts any argument and does nothing with it.
@wp.func
def noop(a: Any):
    pass
584
+
585
+
586
# Sum of the two components of a vec2.
@wp.func
def sum2(v: wp.vec2):
    return v[0] + v[1]
589
+
590
+
591
@wp.kernel
def test_struct_attribute_gradient_kernel(src: wp.array(dtype=float), res: wp.array(dtype=float)):
    tid = wp.tid()

    # p.a holds src[tid]; p.n.v is a vec2 built from 2 * src[tid]
    p = ParentStruct(src[tid], NestedStruct(wp.vec2(2.0 * src[tid])))

    # attribute accesses at every nesting level must not lose gradients
    noop(p.a)
    noop(p.n)
    noop(p.n.v)

    res[tid] = p.a + sum2(p.n.v)
603
+
604
+
605
def test_struct_attribute_gradient(test, device):
    """Check that gradients flow through (nested) struct attribute accesses."""
    with wp.ScopedDevice(device):
        src = wp.array([1], dtype=float, requires_grad=True)
        res = wp.empty_like(src)

        tape = wp.Tape()
        with tape:
            wp.launch(test_struct_attribute_gradient_kernel, dim=1, inputs=[src, res])

        # seed the output adjoint by hand, then replay the tape backwards
        res.grad.fill_(1.0)
        tape.backward()

        src_grad = src.grad.numpy()
        test.assertEqual(src_grad[0], 5.0)
618
+
619
+
620
@wp.kernel
def copy_kernel(a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32)):
    i = wp.tid()
    # load, copy into a second local, then store the copy
    loaded = a[i]
    copied = loaded
    b[i] = copied
626
+
627
+
628
def test_copy(test, device):
    """Check that the adjoint of a local-variable copy passes gradients through unchanged."""
    with wp.ScopedDevice(device):
        a = wp.array([-1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=wp.float32, requires_grad=True)

        # forward launch
        wp.launch(copy_kernel, dim=1, inputs=[a, b])

        # adjoint launch over all elements, using the arrays' own .grad buffers
        b.grad = wp.array([1.0, 1.0, 1.0], dtype=wp.float32)
        num_elements = a.shape[0]
        wp.launch(copy_kernel, num_elements, inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        expected = np.array([1.0, 1.0, 1.0])
        assert_np_equal(a.grad.numpy(), expected)
639
+
640
+
641
@wp.kernel
def aliasing_kernel(a: wp.array(dtype=wp.float32), b: wp.array(dtype=wp.float32)):
    i = wp.tid()
    value = a[i]

    # `result` starts as an alias of `value` and is then reassigned in either branch
    result = value
    if result > 0.0:
        result = value * value
    else:
        result = value * value * value

    b[i] = result
653
+
654
+
655
def test_aliasing(test, device):
    """Check gradients when a local aliases another local before being reassigned."""
    with wp.ScopedDevice(device):
        a = wp.array([-1.0, 2.0, 3.0], dtype=wp.float32, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=wp.float32, requires_grad=True)

        # forward launch
        wp.launch(aliasing_kernel, dim=1, inputs=[a, b])

        # adjoint launch over all elements, using the arrays' own .grad buffers
        b.grad = wp.array([1.0, 1.0, 1.0], dtype=wp.float32)
        num_elements = a.shape[0]
        wp.launch(aliasing_kernel, num_elements, inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        # d(x^3)/dx = 3 at x = -1; d(x^2)/dx = 4 and 6 at x = 2 and 3
        expected = np.array([3.0, 4.0, 6.0])
        assert_np_equal(a.grad.numpy(), expected)
666
+
667
+
668
@wp.kernel
def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # element-wise square: y[i] = x[i]^2
    i = wp.tid()
    y[i] = x[i] ** 2.0
672
+
673
+
674
@wp.kernel
def square_slice_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float), row_idx: int):
    # squares the entries of row `row_idx` only, going through 1D slices of x and y
    i = wp.tid()
    x_row = x[row_idx]
    y_row = y[row_idx]
    y_row[i] = x_row[i] ** 2.0
680
+
681
+
682
@wp.kernel
def square_slice_3d_1d_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), slice_idx: int):
    # squares the entries of the 2D slice x[slice_idx] (one index applied to a 3D array)
    row, col = wp.tid()
    x_plane = x[slice_idx]
    y_plane = y[slice_idx]
    y_plane[row, col] = x_plane[row, col] ** 2.0
688
+
689
+
690
@wp.kernel
def square_slice_3d_2d_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), slice_i: int, slice_j: int):
    # squares the entries of the 1D slice x[slice_i, slice_j] (two indices applied to a 3D array)
    k = wp.tid()
    x_line = x[slice_i, slice_j]
    y_line = y[slice_i, slice_j]
    y_line[k] = x_line[k] ** 2.0
696
+
697
+
698
def test_gradient_internal(test, device):
    """Adjoint launch using the arrays' own ``.grad`` buffers (``adj_inputs`` all ``None``)."""
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)
        count = a.size

        wp.launch(square_kernel, dim=count, inputs=[a, b])

        # internal gradients: seed b.grad and pass no explicit adjoint arrays
        b.grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, dim=count, inputs=[a, b], adjoint=True, adj_inputs=[None, None])

        # d(x^2)/dx = 2x
        expected = np.array([2.0, 4.0, 6.0])
        assert_np_equal(a.grad.numpy(), expected)
710
+
711
+
712
def test_gradient_external(test, device):
    """Adjoint launch using gradient arrays supplied through ``adj_inputs``."""
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=False)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=False)
        count = a.size

        wp.launch(square_kernel, dim=count, inputs=[a, b])

        # external gradients: neither array requires grad, so both adjoints are passed explicitly
        a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
        b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, dim=count, inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])

        # d(x^2)/dx = 2x, accumulated into the external buffer
        expected = np.array([2.0, 4.0, 6.0])
        assert_np_equal(a_grad.numpy(), expected)
725
+
726
+
727
def test_gradient_precedence(test, device):
    """Explicit ``adj_inputs`` must take precedence over the arrays' own ``.grad`` buffers."""
    with wp.ScopedDevice(device):
        a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
        b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)
        count = a.size

        wp.launch(square_kernel, dim=count, inputs=[a, b])

        # Both internal (.grad) and external gradients exist here; the external
        # ones win because the user passed them explicitly in adj_inputs.
        a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
        b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
        wp.launch(square_kernel, dim=count, inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])

        used = a_grad.numpy()
        unused = a.grad.numpy()
        assert_np_equal(used, np.array([2.0, 4.0, 6.0]))  # used
        assert_np_equal(unused, np.array([0.0, 0.0, 0.0]))  # unused
742
+
743
+
744
def test_gradient_slice_2d(test, device):
    """Gradients written through a row slice of a 2D array must land in that row only."""
    row_idx = 1
    with wp.ScopedDevice(device):
        a = wp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=float, requires_grad=True)
        b = wp.zeros_like(a, requires_grad=False)
        b.grad = wp.ones_like(a, requires_grad=False)
        row_length = a.shape[1]

        wp.launch(square_slice_2d_kernel, dim=row_length, inputs=[a, b, row_idx])

        # internal gradients (.grad); the array entries of adj_inputs are None
        wp.launch(
            square_slice_2d_kernel,
            dim=row_length,
            inputs=[a, b, row_idx],
            adjoint=True,
            adj_inputs=[None, None, row_idx],
        )

        # only row 1 receives 2 * a; the other rows stay zero
        expected = np.array([[0.0, 0.0], [6.0, 8.0], [0.0, 0.0]])
        assert_np_equal(a.grad.numpy(), expected)
756
+
757
+
758
def test_gradient_slice_3d_1d(test, device):
    """Gradients written through a 2D slice of a 3D array must land in that slice only."""
    slice_idx = 1
    with wp.ScopedDevice(device):
        # data[s][r][c] == 10 * s + 3 * r + c + 1, i.e. the values 1..9, 11..19 and 21..29
        data = [[[10 * s + 3 * r + c + 1 for c in range(3)] for r in range(3)] for s in range(3)]
        a = wp.array(data, dtype=float, requires_grad=True)
        b = wp.zeros_like(a, requires_grad=False)
        b.grad = wp.ones_like(a, requires_grad=False)
        launch_dim = a.shape[1:]

        wp.launch(square_slice_3d_1d_kernel, dim=launch_dim, inputs=[a, b, slice_idx])

        # internal gradients (.grad); the array entries of adj_inputs are None
        wp.launch(
            square_slice_3d_1d_kernel,
            dim=launch_dim,
            inputs=[a, b, slice_idx],
            adjoint=True,
            adj_inputs=[None, None, slice_idx],
        )

        # zero everywhere except the selected slice, which holds 2 * data
        reference = np.array(data)
        expected_grad = np.zeros_like(reference)
        expected_grad[slice_idx] = 2 * reference[slice_idx]
        assert_np_equal(a.grad.numpy(), expected_grad)
806
+
807
+
808
def test_gradient_slice_3d_2d(test, device):
    """Gradients written through a 1D slice of a 3D array must land in that line only."""
    slice_i = 1
    slice_j = 1
    with wp.ScopedDevice(device):
        # data[s][r][c] == 10 * s + 3 * r + c + 1, i.e. the values 1..9, 11..19 and 21..29
        data = [[[10 * s + 3 * r + c + 1 for c in range(3)] for r in range(3)] for s in range(3)]
        a = wp.array(data, dtype=float, requires_grad=True)
        b = wp.zeros_like(a, requires_grad=False)
        b.grad = wp.ones_like(a, requires_grad=False)
        launch_dim = a.shape[2]

        wp.launch(square_slice_3d_2d_kernel, dim=launch_dim, inputs=[a, b, slice_i, slice_j])

        # internal gradients (.grad); the array entries of adj_inputs are None
        wp.launch(
            square_slice_3d_2d_kernel,
            dim=launch_dim,
            inputs=[a, b, slice_i, slice_j],
            adjoint=True,
            adj_inputs=[None, None, slice_i, slice_j],
        )

        # zero everywhere except a[slice_i, slice_j], which holds 2 * data
        reference = np.array(data)
        expected_grad = np.zeros_like(reference)
        expected_grad[slice_i, slice_j] = 2 * reference[slice_i, slice_j]
        assert_np_equal(a.grad.numpy(), expected_grad)
856
+
857
+
858
# Devices passed to add_function_test for the tests registered below.
devices = get_test_devices()
859
+
860
+
861
class TestGrad(unittest.TestCase):
    """Empty test case; the test methods are attached via ``add_function_test``."""
863
+
864
+
865
# Registration table: (test name, test function, devices to run on).
# test_while_loop_grad is currently disabled:
# add_function_test(TestGrad, "test_while_loop_grad", test_while_loop_grad, devices=devices)
_grad_tests = (
    ("test_for_loop_nested_for_grad", test_for_loop_nested_for_grad, devices),
    ("test_scalar_grad", test_scalar_grad, devices),
    ("test_for_loop_grad", test_for_loop_grad, devices),
    # this test is registered for the selected CUDA devices only
    ("test_for_loop_graph_grad", test_for_loop_graph_grad, get_selected_cuda_test_devices()),
    ("test_for_loop_nested_if_grad", test_for_loop_nested_if_grad, devices),
    ("test_preserve_outputs_grad", test_preserve_outputs_grad, devices),
    ("test_vector_math_grad", test_vector_math_grad, devices),
    ("test_matrix_math_grad", test_matrix_math_grad, devices),
    ("test_3d_math_grad", test_3d_math_grad, devices),
    ("test_multi_valued_function_grad", test_multi_valued_function_grad, devices),
    ("test_mesh_grad", test_mesh_grad, devices),
    ("test_name_clash", test_name_clash, devices),
    ("test_struct_attribute_gradient", test_struct_attribute_gradient, devices),
    ("test_copy", test_copy, devices),
    ("test_aliasing", test_aliasing, devices),
    ("test_gradient_internal", test_gradient_internal, devices),
    ("test_gradient_external", test_gradient_external, devices),
    ("test_gradient_precedence", test_gradient_precedence, devices),
    ("test_gradient_slice_2d", test_gradient_slice_2d, devices),
    ("test_gradient_slice_3d_1d", test_gradient_slice_3d_1d, devices),
    ("test_gradient_slice_3d_2d", test_gradient_slice_3d_2d, devices),
)

for _test_name, _test_func, _test_devices in _grad_tests:
    add_function_test(TestGrad, _test_name, _test_func, devices=_test_devices)

# keep the module namespace free of the registration helpers
del _grad_tests, _test_name, _test_func, _test_devices


if __name__ == "__main__":
    # start from an empty kernel cache when the module is run directly
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)