warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,44 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.tests.unittest_utils import *
20
+
21
devices = get_test_devices()


class TestModuleLite(unittest.TestCase):
    """Lightweight checks for ``wp.load_module`` and the module-options API."""

    def test_module_lite_load(self):
        # Load the current module (no module argument)
        wp.load_module()

        # Load a module passed as a Python module object
        wp.load_module(wp.config)

        # Load a module given by its name (string)
        wp.load_module("warp.config")

        # Load a module recursively (recursive=True)
        wp.load_module(wp.config, recursive=True)

    def test_module_lite_options(self):
        # Set an option without naming a module, then read the options back
        wp.set_module_options({"max_unroll": 8})
        module_options = wp.get_module_options()
        self.assertIsInstance(module_options, dict)
        self.assertEqual(module_options["max_unroll"], 8)


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,252 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
# Sample periodic Perlin noise at the centre of each cell of a grid that is W
# cells wide (cell spacing 0.2). The raw value goes to noise_values, and a copy
# rescaled by ((n + 1) / 2) * 255 goes to pixel_values.
@wp.kernel
def pnoise(
    kernel_seed: int, W: int, px: int, py: int, noise_values: wp.array(dtype=float), pixel_values: wp.array(dtype=float)
):
    idx = wp.tid()
    rng = wp.rand_init(kernel_seed)

    col = float(idx % W)
    row = float(idx / W)
    sample_point = wp.vec2((col + 0.5) * 0.2, (row + 0.5) * 0.2)

    value = wp.pnoise(rng, sample_point, px, py)
    noise_values[idx] = value
    pixel_values[idx] = ((value + 1.0) / 2.0) * 255.0
41
+
42
+
43
# Evaluate 2D curl noise at the centre of each cell of a grid that is W cells
# wide (cell spacing 0.2). Record both the sample position and the noise vector.
@wp.kernel
def curlnoise(kernel_seed: int, W: int, noise_coords: wp.array(dtype=wp.vec2), noise_vectors: wp.array(dtype=wp.vec2)):
    idx = wp.tid()
    rng = wp.rand_init(kernel_seed)

    sample_point = wp.vec2(
        (float(idx % W) + 0.5) * 0.2,
        (float(idx / W) + 0.5) * 0.2,
    )
    curl = wp.curlnoise(rng, sample_point)

    noise_coords[idx] = sample_point
    noise_vectors[idx] = curl
57
+
58
+
59
def test_pnoise(test, device):
    """Compare periodic Perlin noise on a W x H grid against the stored golden image."""
    import os

    # image dim
    W = 256
    H = 256
    N = W * H
    seed = 42

    # periodic perlin noise test
    px = 16
    py = 16

    noise_values = wp.zeros(N, dtype=float, device=device)
    pixel_values = wp.zeros(N, dtype=float, device=device)

    wp.launch(kernel=pnoise, dim=N, inputs=[seed, W, px, py, noise_values, pixel_values], outputs=[], device=device)

    # The kernel linearises the grid with W columns (x = tid % W, y = tid / W),
    # so the row-major image has H rows of W pixels.
    img = pixel_values.numpy().reshape((H, W))

    # To view the image: plt.imshow(img.astype(np.uint8), "gray")
    # To regenerate the golden image:
    # np.save(os.path.join(os.path.dirname(__file__), "assets/pnoise_golden.npy"), img)

    # Golden image comparison
    img_true = np.load(os.path.join(os.path.dirname(__file__), "assets/pnoise_golden.npy"))
    test.assertEqual(img.shape, img_true.shape)
    err = np.max(np.abs(img - img_true))
    tolerance = 1.5e-3
    test.assertTrue(err < tolerance, f"err is {err} which is >= {tolerance}")
97
+
98
+
99
def test_curlnoise(test, device):
    """Compare curl-noise sample positions and normalised vectors against the golden data."""
    import os

    # image dim
    W = 128
    H = 128
    N = W * H
    seed = 42

    # curl noise test
    quiver_coords_host = wp.zeros(N, dtype=wp.vec2, device="cpu")
    quiver_coords = wp.zeros(N, dtype=wp.vec2, device=device)

    quiver_arrows_host = wp.zeros(N, dtype=wp.vec2, device="cpu")
    quiver_arrows = wp.zeros(N, dtype=wp.vec2, device=device)

    wp.launch(kernel=curlnoise, dim=N, inputs=[seed, W, quiver_coords, quiver_arrows], outputs=[], device=device)

    wp.copy(quiver_coords_host, quiver_coords)
    wp.copy(quiver_arrows_host, quiver_arrows)

    wp.synchronize()

    xy_coords = quiver_coords_host.numpy()
    uv_coords = quiver_arrows_host.numpy()

    # Normalise every vector by the largest vector length
    norms = uv_coords[:, 0] * uv_coords[:, 0] + uv_coords[:, 1] * uv_coords[:, 1]
    uv_coords = uv_coords / np.sqrt(np.max(norms))

    # To view: ax.quiver(xy_coords[:, 0], xy_coords[:, 1], uv_coords[:, 0], uv_coords[:, 1])
    result = np.stack((xy_coords, uv_coords))
    # To regenerate the golden data:
    # np.save(os.path.join(os.path.dirname(__file__), "assets/curlnoise_golden.npy"), result)

    # Golden data comparison. Shapes are compared with assertEqual, because
    # assertTrue(a, b) treats b as a message and passes for any non-empty shape.
    result_true = np.load(os.path.join(os.path.dirname(__file__), "assets/curlnoise_golden.npy"))
    test.assertEqual(result.shape, result_true.shape)
    err = np.max(np.abs(result - result_true))
    tolerance = 1e-04
    test.assertTrue(err < tolerance, f"err is {err} which is >= {tolerance}")
148
+
149
+
150
# Evaluate scalar noise at each query position. Store the value per point and
# atomically add it into noise_loss[0], the scalar used as the loss.
@wp.kernel
def noise_loss_kernel(
    kernel_seed: int,
    query_positions: wp.array(dtype=wp.vec2),
    noise_values: wp.array(dtype=float),
    noise_loss: wp.array(dtype=float),
):
    idx = wp.tid()
    rng = wp.rand_init(kernel_seed)
    position = query_positions[idx]

    value = wp.noise(rng, position)

    noise_values[idx] = value
    wp.atomic_add(noise_loss, 0, value)
166
+
167
+
168
# Estimate the gradient of wp.noise at each query position by central
# differences, using a step of 1e-3 along x and along y.
@wp.kernel
def noise_cd(kernel_seed: int, query_positions: wp.array(dtype=wp.vec2), gradients: wp.array(dtype=wp.vec2)):
    idx = wp.tid()
    rng = wp.rand_init(kernel_seed)
    q = query_positions[idx]

    step = 1.0e-3

    left = wp.noise(rng, wp.vec2(q[0] - step, q[1]))
    right = wp.noise(rng, wp.vec2(q[0] + step, q[1]))
    down = wp.noise(rng, wp.vec2(q[0], q[1] - step))
    up = wp.noise(rng, wp.vec2(q[0], q[1] + step))

    grad_x = (right - left) / (2.0 * step)
    grad_y = (up - down) / (2.0 * step)

    gradients[idx] = wp.vec2(grad_x, grad_y)
190
+
191
+
192
def test_adj_noise(test, device):
    """Check the tape gradient of summed ``wp.noise`` against central differences.

    The loss is the sum of noise values over the query points. Its gradient with
    respect to each point, taken from the tape, must match the ``noise_cd``
    finite-difference estimate to within a squared error of 1e-3 per point.
    """
    seed = 42

    # 3 x 3 grid of query points around the origin
    positions = np.array(
        [
            [-0.1, -0.1],
            [0.0, -0.1],
            [0.1, -0.1],
            [-0.1, 0.0],
            [0.0, 0.0],
            [0.1, 0.0],
            [-0.1, 0.1],
            [0.0, 0.1],
            [0.1, 0.1],
        ]
    )
    N = len(positions)

    tape = wp.Tape()

    with tape:
        query_positions = wp.array(positions, dtype=wp.vec2, device=device, requires_grad=True)
        noise_values = wp.zeros(N, dtype=float, device=device)
        noise_loss = wp.zeros(n=1, dtype=float, device=device, requires_grad=True)

        wp.launch(
            kernel=noise_loss_kernel, dim=N, inputs=[seed, query_positions, noise_values, noise_loss], device=device
        )

    # analytic
    tape.backward(loss=noise_loss)
    analytic = tape.gradients[query_positions].numpy()

    # central difference
    gradients = wp.zeros(N, dtype=wp.vec2, device=device)
    wp.launch(kernel=noise_cd, dim=N, inputs=[seed, query_positions, gradients], device=device)

    diff = analytic - gradients.numpy()
    sq_err = np.sum(diff * diff, axis=1)

    tolerance = 1.0e-3
    # Comparisons with NaN are False, so a NaN gradient fails this check
    # instead of being silently ignored.
    test.assertTrue(
        np.all(sq_err <= tolerance),
        f"max squared gradient error is {np.max(sq_err)}, which exceeds {tolerance}",
    )
236
+
237
+
238
devices = get_test_devices()


class TestNoise(unittest.TestCase):
    """Noise tests; the test methods are attached below with ``add_function_test``."""


# Register each device-parameterised test function on TestNoise
add_function_test(TestNoise, "test_pnoise", test_pnoise, devices=devices)
add_function_test(TestNoise, "test_curlnoise", test_curlnoise, devices=devices)
add_function_test(TestNoise, "test_adj_noise", test_adj_noise, devices=devices)


if __name__ == "__main__":
    # Clear the kernel cache before running the suite standalone
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,299 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.tests.unittest_utils import *
20
+
21
+
22
@wp.kernel
def test_operators_scalar_float():
    # Exercise the binary arithmetic operators on float literals inside a kernel.
    lhs = 1.0
    rhs = 2.0

    product = lhs * rhs
    total = lhs + rhs
    quotient = lhs / rhs
    difference = lhs - rhs
    power = rhs**8.0
    # float floor division: 10.0 // 3.0 is expected to give 3.0
    floored = 10.0 // 3.0

    expect_eq(product, 2.0)
    expect_eq(total, 3.0)
    expect_eq(quotient, 0.5)
    expect_eq(difference, -1.0)
    expect_eq(power, 256.0)
    expect_eq(floored, 3.0)
40
+
41
+
42
@wp.kernel
def test_operators_scalar_int():
    # Exercise arithmetic, shift and bitwise operators on integer values.
    lhs = 1
    rhs = 2

    product = lhs * rhs
    total = lhs + rhs
    # integer `/` truncates here: 1 / 2 is checked against 0 below
    quotient = lhs / rhs
    difference = lhs - rhs
    # Integer `**` is not implemented, so rhs**8 == 256 is not exercised.
    floored = 10 // 3
    remainder = 10 % 3
    shifted_left = 2 << 3
    shifted_right = 16 >> 1

    expect_eq(product, 2)
    expect_eq(total, 3)
    expect_eq(quotient, 0)
    expect_eq(difference, -1)
    expect_eq(floored, 3)
    expect_eq(remainder, 1)
    expect_eq(shifted_left, 16)
    expect_eq(shifted_right, 8)

    # Two disjoint single-bit flags: OR equals their sum and AND is empty.
    bit0 = wp.uint32(1 << 0)
    bit3 = wp.uint32(1 << 3)
    expect_eq(bit0 | bit3, bit0 + bit3)
    expect_eq(bit0 & bit3, wp.uint32(0))

    # Setting all eight bits one at a time must equal the complement of zero.
    mask = wp.uint8(0)
    for bit in range(8):
        mask |= wp.uint8(1 << bit)
    expect_eq(mask, ~wp.uint8(0))
76
+
77
+
78
@wp.kernel
def test_operators_vector_index():
    # Every component of a vec4 is readable through a constant subscript.
    input_vec = wp.vec4(1.0, 2.0, 3.0, 4.0)

    expect_eq(input_vec[0], 1.0)
    expect_eq(input_vec[1], 2.0)
    expect_eq(input_vec[2], 3.0)
    expect_eq(input_vec[3], 4.0)
86
+
87
+
88
@wp.kernel
def test_operators_matrix_index():
    # Every element of a mat22 is readable through a [row, col] subscript.
    # The checks below expect the 4-scalar constructor to fill row by row.
    input_mat = wp.mat22(1.0, 2.0, 3.0, 4.0)

    expect_eq(input_mat[0, 0], 1.0)
    expect_eq(input_mat[0, 1], 2.0)
    expect_eq(input_mat[1, 0], 3.0)
    expect_eq(input_mat[1, 1], 4.0)
96
+
97
+
98
@wp.kernel
def test_operators_vec3():
    # Covers scalar * vec3 in both operand orders, mat33 * basis vector, and
    # single-scalar vec3 construction. Types are referenced through `wp.` like the
    # other kernels in this file, rather than relying on a star import.
    v = wp.vec3(1.0, 2.0, 3.0)

    r0 = v * 3.0
    r1 = 3.0 * v

    expect_eq(r0, wp.vec3(3.0, 6.0, 9.0))
    expect_eq(r1, wp.vec3(3.0, 6.0, 9.0))

    col0 = wp.vec3(1.0, 0.0, 0.0)
    col1 = wp.vec3(0.0, 2.0, 0.0)
    col2 = wp.vec3(0.0, 0.0, 3.0)

    # m is diag(1, 2, 3), so it is the same matrix whether the vectors are taken as
    # rows or columns; multiplying by the i-th basis vector must return col_i.
    m = wp.mat33(col0, col1, col2)

    expect_eq(m * wp.vec3(1.0, 0.0, 0.0), col0)
    expect_eq(m * wp.vec3(0.0, 1.0, 0.0), col1)
    expect_eq(m * wp.vec3(0.0, 0.0, 1.0), col2)

    # vec3(1.0) fills all three components with the scalar
    two = wp.vec3(1.0) * 2.0
    expect_eq(two, wp.vec3(2.0, 2.0, 2.0))
120
+
121
+
122
@wp.kernel
def test_operators_vec4():
    # Covers scalar * vec4 in both operand orders, mat44 * basis vector, and
    # single-scalar vec4 construction. Types are referenced through `wp.` like the
    # other kernels in this file, rather than relying on a star import.
    v = wp.vec4(1.0, 2.0, 3.0, 4.0)

    r0 = v * 3.0
    r1 = 3.0 * v

    expect_eq(r0, wp.vec4(3.0, 6.0, 9.0, 12.0))
    expect_eq(r1, wp.vec4(3.0, 6.0, 9.0, 12.0))

    col0 = wp.vec4(1.0, 0.0, 0.0, 0.0)
    col1 = wp.vec4(0.0, 2.0, 0.0, 0.0)
    col2 = wp.vec4(0.0, 0.0, 3.0, 0.0)
    col3 = wp.vec4(0.0, 0.0, 0.0, 4.0)

    # m is diag(1, 2, 3, 4), so it is the same matrix whether the vectors are taken
    # as rows or columns; multiplying by the i-th basis vector must return col_i.
    m = wp.mat44(col0, col1, col2, col3)

    expect_eq(m * wp.vec4(1.0, 0.0, 0.0, 0.0), col0)
    expect_eq(m * wp.vec4(0.0, 1.0, 0.0, 0.0), col1)
    expect_eq(m * wp.vec4(0.0, 0.0, 1.0, 0.0), col2)
    expect_eq(m * wp.vec4(0.0, 0.0, 0.0, 1.0), col3)

    # vec4(1.0) fills all four components with the scalar
    two = wp.vec4(1.0) * 2.0
    expect_eq(two, wp.vec4(2.0, 2.0, 2.0, 2.0))
146
+
147
+
148
@wp.kernel
def test_operators_mat22():
    # Covers scalar * mat22 in both operand orders, plus [row, col] and row indexing.
    # The 4-scalar constructor fills row by row (checked via r0[0, 1] == 6.0 below).
    # Types are referenced through `wp.` consistently (this kernel already used wp.vec2).
    m = wp.mat22(1.0, 2.0, 3.0, 4.0)
    r = wp.mat22(3.0, 6.0, 9.0, 12.0)

    r0 = m * 3.0
    r1 = 3.0 * m

    expect_eq(r0, r)
    expect_eq(r1, r)

    expect_eq(r0[0, 0], 3.0)
    expect_eq(r0[0, 1], 6.0)
    expect_eq(r0[1, 0], 9.0)
    expect_eq(r0[1, 1], 12.0)

    # a single index selects a whole row as a vector
    expect_eq(r0[0], wp.vec2(3.0, 6.0))
    expect_eq(r0[1], wp.vec2(9.0, 12.0))
166
+
167
+
168
@wp.kernel
def test_operators_mat33():
    # Covers scalar * mat33 in both operand orders, plus [row, col] and row indexing.
    # The 9-scalar constructor fills row by row (checked via r0[1, 0] == 12.0 below).
    # Types are referenced through `wp.` consistently (this kernel already used wp.vec3).
    m = wp.mat33(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)

    r = wp.mat33(3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0)

    r0 = m * 3.0
    r1 = 3.0 * m

    expect_eq(r0, r)
    expect_eq(r1, r)

    expect_eq(r0[0, 0], 3.0)
    expect_eq(r0[0, 1], 6.0)
    expect_eq(r0[0, 2], 9.0)

    expect_eq(r0[1, 0], 12.0)
    expect_eq(r0[1, 1], 15.0)
    expect_eq(r0[1, 2], 18.0)

    expect_eq(r0[2, 0], 21.0)
    expect_eq(r0[2, 1], 24.0)
    expect_eq(r0[2, 2], 27.0)

    # a single index selects a whole row as a vector
    expect_eq(r0[0], wp.vec3(3.0, 6.0, 9.0))
    expect_eq(r0[1], wp.vec3(12.0, 15.0, 18.0))
    expect_eq(r0[2], wp.vec3(21.0, 24.0, 27.0))
195
+
196
+
197
@wp.kernel
def test_operators_mat44():
    # Covers scalar * mat44 in both operand orders, plus [row, col] and row indexing.
    # The 16-scalar constructor fills row by row (checked via r0[1, 0] == 15.0 below).
    # Types are referenced through `wp.` consistently (this kernel already used wp.vec4).
    m = wp.mat44(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0)

    r = wp.mat44(3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0, 39.0, 42.0, 45.0, 48.0)

    r0 = m * 3.0
    r1 = 3.0 * m

    expect_eq(r0, r)
    expect_eq(r1, r)

    expect_eq(r0[0, 0], 3.0)
    expect_eq(r0[0, 1], 6.0)
    expect_eq(r0[0, 2], 9.0)
    expect_eq(r0[0, 3], 12.0)

    expect_eq(r0[1, 0], 15.0)
    expect_eq(r0[1, 1], 18.0)
    expect_eq(r0[1, 2], 21.0)
    expect_eq(r0[1, 3], 24.0)

    expect_eq(r0[2, 0], 27.0)
    expect_eq(r0[2, 1], 30.0)
    expect_eq(r0[2, 2], 33.0)
    expect_eq(r0[2, 3], 36.0)

    expect_eq(r0[3, 0], 39.0)
    expect_eq(r0[3, 1], 42.0)
    expect_eq(r0[3, 2], 45.0)
    expect_eq(r0[3, 3], 48.0)

    # a single index selects a whole row as a vector
    expect_eq(r0[0], wp.vec4(3.0, 6.0, 9.0, 12.0))
    expect_eq(r0[1], wp.vec4(15.0, 18.0, 21.0, 24.0))
    expect_eq(r0[2], wp.vec4(27.0, 30.0, 33.0, 36.0))
    expect_eq(r0[3], wp.vec4(39.0, 42.0, 45.0, 48.0))
233
+
234
+
235
@wp.struct
class Complex:
    """Complex number with float components; operands of the ``add``/``mul`` overloads."""

    # Field order defines the positional constructor Complex(real, imag).
    real: float  # real part
    imag: float  # imaginary part
239
+
240
+
241
@wp.func
def add(
    a: Complex,
    b: Complex,
) -> Complex:
    # Component-wise complex sum; test_operators_overload expects `a + b` on
    # Complex values to use this function.
    real_part = a.real + b.real
    imag_part = a.imag + b.imag
    return Complex(real_part, imag_part)
250
+
251
+
252
@wp.func
def mul(
    a: Complex,
    b: Complex,
) -> Complex:
    # Complex product (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br);
    # test_operators_overload expects `a * b` on Complex values to use this function.
    real_part = a.real * b.real - a.imag * b.imag
    imag_part = a.real * b.imag + a.imag * b.real
    return Complex(real_part, imag_part)
261
+
262
+
263
@wp.kernel
def test_operators_overload():
    # `+` and `*` between Complex structs are expected to resolve to the
    # user-defined `add` / `mul` wp.funcs.
    lhs = Complex(1.0, 2.0)
    rhs = Complex(3.0, 4.0)

    # (1 + 2i) + (3 + 4i) = 4 + 6i
    total = lhs + rhs
    expect_eq(total.real, 4.0)
    expect_eq(total.imag, 6.0)

    # (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i^2 = -5 + 10i
    product = lhs * rhs
    expect_eq(product.real, -5.0)
    expect_eq(product.imag, 10.0)
275
+
276
+
277
devices = get_test_devices()


class TestOperators(unittest.TestCase):
    """Container class; the operator kernels are attached to it by add_kernel_test below."""


# Each kernel runs once (dim=1); registration order matches the list.
for _kernel in (
    test_operators_scalar_float,
    test_operators_scalar_int,
    test_operators_matrix_index,
    test_operators_vector_index,
    test_operators_vec3,
    test_operators_vec4,
    test_operators_mat22,
    test_operators_mat33,
    test_operators_mat44,
    test_operators_overload,
):
    add_kernel_test(TestOperators, _kernel, dim=1, devices=devices)
# keep the module namespace free of the loop variable
del _kernel


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,129 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import contextlib
17
+ import io
18
+ import unittest
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # y[0] = x[0]**2, so dy/dx = 2 * x[0]. There is no per-kernel enable_backward
    # override, so whether a backward pass exists follows the module-level option.
    y[0] = x[0] ** 2.0
30
+
31
+
32
@wp.kernel(enable_backward=True)
def scale_1(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # Same computation as `scale`. test_options_3 expects the kernel-level
    # enable_backward=True to win over a module-level False.
    y[0] = x[0] ** 2.0
38
+
39
+
40
@wp.kernel(enable_backward=False)
def scale_2(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # Same computation as `scale`. test_options_4 expects the kernel-level
    # enable_backward=False to win over a module-level True.
    y[0] = x[0] ** 2.0
46
+
47
+
48
def test_options_1(test, device):
    """Module-level ``enable_backward=False`` disables gradients for ``scale``.

    Backpropagating through a tape that recorded ``scale`` must print a warning naming
    the module option, and the gradient of ``x`` must stay zero.
    """
    x = wp.array([3.0], dtype=float, requires_grad=True, device=device)
    y = wp.zeros_like(x)

    # NOTE(review): presumably applies to the calling module (this file) and persists
    # after the test; every test here sets the option it needs before launching.
    wp.set_module_options({"enable_backward": False})

    tape = wp.Tape()
    with tape:
        wp.launch(scale, dim=1, inputs=[x, y], device=device)

    # The warning goes to stdout; capture it to compare the exact text.
    with contextlib.redirect_stdout(io.StringIO()) as f:
        tape.backward(y)

    expected = f"Warp UserWarning: Running the tape backwards may produce incorrect gradients because recorded kernel {scale.key} is defined in a module with the option 'enable_backward=False' set.\n"

    # TestCase assertion instead of a bare `assert`, which `python -O` strips out.
    test.assertEqual(f.getvalue(), expected)
    assert_np_equal(tape.gradients[x].numpy(), np.array(0.0))
65
+
66
+
67
def test_options_2(test, device):
    """Module-level ``enable_backward=True`` makes ``scale`` differentiable.

    ``scale`` computes x**2, so at x = 3 the tape must produce dy/dx = 6.
    """
    src = wp.array([3.0], dtype=float, requires_grad=True, device=device)
    dst = wp.zeros_like(src)

    wp.set_module_options({"enable_backward": True})

    tape = wp.Tape()
    with tape:
        wp.launch(scale, dim=1, inputs=[src, dst], device=device)
    tape.backward(dst)

    grad_src = tape.gradients[src].numpy()
    assert_np_equal(grad_src, np.array(6.0))
79
+
80
+
81
def test_options_3(test, device):
    """A kernel-level ``enable_backward=True`` overrides a module-level ``False``.

    ``scale_1`` is still differentiated, so at x = 3 the tape must produce dy/dx = 6.
    """
    src = wp.array([3.0], dtype=float, requires_grad=True, device=device)
    dst = wp.zeros_like(src)

    wp.set_module_options({"enable_backward": False})

    tape = wp.Tape()
    with tape:
        wp.launch(scale_1, dim=1, inputs=[src, dst], device=device)
    tape.backward(dst)

    grad_src = tape.gradients[src].numpy()
    assert_np_equal(grad_src, np.array(6.0))
93
+
94
+
95
def test_options_4(test, device):
    """A kernel-level ``enable_backward=False`` overrides a module-level ``True``.

    Backpropagating through a tape that recorded ``scale_2`` must print a warning naming
    the kernel option, and the gradient of ``x`` must stay zero.
    """
    x = wp.array([3.0], dtype=float, requires_grad=True, device=device)
    y = wp.zeros_like(x)

    # NOTE(review): presumably applies to the calling module (this file) and persists
    # after the test; every test here sets the option it needs before launching.
    wp.set_module_options({"enable_backward": True})

    tape = wp.Tape()
    with tape:
        wp.launch(scale_2, dim=1, inputs=[x, y], device=device)

    # The warning goes to stdout; capture it to compare the exact text.
    with contextlib.redirect_stdout(io.StringIO()) as f:
        tape.backward(y)

    expected = f"Warp UserWarning: Running the tape backwards may produce incorrect gradients because recorded kernel {scale_2.key} is configured with the option 'enable_backward=False'.\n"

    # TestCase assertion instead of a bare `assert`, which `python -O` strips out.
    test.assertEqual(f.getvalue(), expected)
    assert_np_equal(tape.gradients[x].numpy(), np.array(0.0))
112
+
113
+
114
devices = get_test_devices()


class TestOptions(unittest.TestCase):
    """Container class; the enable_backward option tests are attached to it by add_function_test below."""


# (test method name, test function) pairs, registered in this order.
for _test_name, _test_func in (
    ("test_options_1", test_options_1),
    ("test_options_2", test_options_2),
    ("test_options_3", test_options_3),
    ("test_options_4", test_options_4),
):
    add_function_test(TestOptions, _test_name, _test_func, devices=devices)
# keep the module namespace free of the loop variables
del _test_name, _test_func


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)