warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,116 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example jax_callable()
18
+ #
19
+ # Examples of calling annotated Python functions from JAX.
20
+ ###########################################################################
21
+
22
+ from functools import partial
23
+
24
+ import jax
25
+ import jax.numpy as jnp
26
+
27
+ import warp as wp
28
+ from warp.jax_experimental.ffi import jax_callable
29
+
30
+
31
# Multiply every element of a float array by the scalar s (one thread per element).
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
35
+
36
+
37
# Multiply every wp.vec2 element of an array by the scalar s (one thread per element).
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
41
+
42
+
43
def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    """Scale ``a`` into ``c`` and ``b`` into ``d`` by the scalar ``s``.

    This is the plain Python function wrapped by ``jax_callable`` in the examples
    below. Its parameters carry Warp-style annotations, just like a Warp kernel.
    The trailing ``c`` and ``d`` are the outputs (the examples wrap it with
    ``num_outputs=2``) and are filled by the kernel launches.
    """
    # (kernel, source array, destination array), launched in this order
    launches = (
        (scale_kernel, a, c),
        (scale_vec_kernel, b, d),
    )
    for kernel, src, dst in launches:
        wp.launch(kernel, dim=src.shape, inputs=[src, s], outputs=[dst])
56
+
57
+
58
def example1():
    """Call ``example_func`` through ``jax_callable`` from a jitted function that builds its own inputs."""
    jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")

    @jax.jit
    def f():
        scalars = jnp.arange(10, dtype=jnp.float32)
        # the same ten values viewed as five wp.vec2 elements
        vectors = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))
        scale = 2.0

        # each output has the same shape as its corresponding input
        c, d = jax_func(
            scalars,
            vectors,
            scale,
            output_dims={"c": scalars.shape, "d": vectors.shape},
        )
        return c, d

    for result in f():
        print(result)
78
+
79
+
80
def example2():
    """Call ``example_func`` through ``jax_callable`` with inputs passed into a jitted function.

    The scalar ``s`` is marked static because scalar arguments must be
    compile-time constants.
    """
    jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, b, s):
        # each output has the same shape as its corresponding input
        c, d = jax_func(a, b, s, output_dims=dict(c=a.shape, d=b.shape))
        return c, d

    values = jnp.arange(10, dtype=jnp.float32)
    vecs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2

    for result in f(values, vecs, 3.0):
        print(result)
101
+
102
+
103
def main():
    """Initialize Warp, load this module's kernels on the default device, then run every example."""
    wp.init()
    # called directly here (not from a helper) so Warp loads the module that contains main()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="
    for example in (example1, example2):
        print(banner)
        print(f"{example.__name__}:")
        example()


if __name__ == "__main__":
    main()
@@ -0,0 +1,132 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example register_ffi_callback()
18
+ #
19
+ # Examples of calling Python functions from JAX.
20
+ # Target functions must have the form func(inputs, outputs, attrs, ctx).
21
+ ###########################################################################
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ import warp as wp
28
+ from warp.jax import get_jax_device
29
+ from warp.jax_experimental.ffi import register_ffi_callback
30
+
31
+
32
# Multiply every element of a float array by the scalar s (one thread per element).
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
36
+
37
+
38
# Multiply every wp.vec2 element of an array by the scalar s (one thread per element).
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
42
+
43
+
44
def example1():
    """Register a Python FFI callback that prints its arguments, then invoke it once from JAX.

    The callback receives ``(inputs, outputs, attrs, ctx)``. It prints the dtype,
    shape and data address of every input and output buffer, followed by the
    attribute dictionary. It never writes its (1, 2, 3) int8 output.
    """

    # the Python function to call
    def print_args(inputs, outputs, attrs, ctx):
        def buffer_to_string(b):
            return str(b.dtype) + str(list(b.shape)) + " @%x" % b.data

        print("Inputs: ", ", ".join([buffer_to_string(b) for b in inputs]))
        print("Outputs: ", ", ".join([buffer_to_string(b) for b in outputs]))
        print("Attributes: ", "".join(["\n %s: %s" % (k, str(v)) for k, v in attrs.items()]))

    # register callback
    register_ffi_callback("print_args", print_args)

    # set up call: a single (1, 2, 3) int8 output
    call = jax.ffi.ffi_call("print_args", jax.ShapeDtypeStruct((1, 2, 3), jnp.int8))

    # call it
    result = call(
        jnp.arange(16),
        jnp.arange(32.0).reshape((4, 8)),
        str_attr="hi",
        f32_attr=np.float32(4.2),
        dict_attr={"a": 1, "b": 6.4},
    )

    # JAX dispatches asynchronously and the result is otherwise unused; wait for it so the
    # callback has finished (and printed) before this example returns
    result.block_until_ready()
68
+
69
+
70
def example2():
    """Register a Warp-based FFI callback and call it from JAX.

    The callback scales a float array and an array of wp.vec2 by the ``scale``
    attribute. Both kernels are launched on the CUDA stream handed over by JAX
    through ``ctx.stream``.
    """

    # the Python function to call
    def warp_func(inputs, outputs, attrs, ctx):
        # two input buffers, two output buffers and one scalar attribute
        a, b = inputs[0], inputs[1]
        c, d = outputs[0], outputs[1]
        s = attrs["scale"]

        # launch on JAX's device, using the stream JAX passed in
        device = wp.device_from_jax(get_jax_device())
        with wp.ScopedStream(wp.Stream(device, cuda_stream=ctx.stream)):
            # arrays of scalars
            wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])

            # arrays of vec2: the buffer shapes come from JAX and include the trailing
            # vector dimension, so only the leading dimension is used as the launch dim
            wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

    register_ffi_callback("warp_func", warp_func)

    n = 10
    a = jnp.arange(n, dtype=jnp.float32)
    b = jnp.arange(n, dtype=jnp.float32).reshape((n // 2, 2))  # array of wp.vec2

    # each output mirrors the shape of the matching input; b's output is an array of wp.vec2
    out_types = [jax.ShapeDtypeStruct(x.shape, jnp.float32) for x in (a, b)]
    call = jax.ffi.ffi_call("warp_func", out_types)

    c, d = call(a, b, scale=2.0)

    print(c)
    print(d)
117
+
118
+
119
def main():
    """Initialize Warp, load this module's kernels on the default device, then run every example."""
    wp.init()
    # called directly here (not from a helper) so Warp loads the module that contains main()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="
    for example in (example1, example2):
        print(banner)
        print(f"{example.__name__}:")
        example()


if __name__ == "__main__":
    main()
@@ -0,0 +1,205 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example jax_kernel()
18
+ #
19
+ # Examples of calling a Warp kernel from JAX.
20
+ ###########################################################################
21
+
22
+ import math
23
+ from functools import partial
24
+
25
+ import jax
26
+ import jax.numpy as jnp
27
+
28
+ import warp as wp
29
+ from warp.jax_experimental.ffi import jax_kernel
30
+
31
+
32
# Element-wise integer addition, one thread per element: output[i] = a[i] + b[i].
@wp.kernel
def add_kernel(a: wp.array(dtype=int), b: wp.array(dtype=int), output: wp.array(dtype=int)):
    i = wp.tid()
    total = a[i] + b[i]
    output[i] = total
36
+
37
+
38
# Per element, write the sine and cosine of the input angle into two separate outputs.
@wp.kernel
def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
    i = wp.tid()
    theta = angle[i]
    sin_out[i] = wp.sin(theta)
    cos_out[i] = wp.cos(theta)
43
+
44
+
45
# Fill every element of the output with the constant diagonal matrix diag(1, 2, 3).
@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    i = wp.tid()
    # entries listed row by row
    m = wp.mat33(
        1.0, 0.0, 0.0,
        0.0, 2.0, 0.0,
        0.0, 0.0, 3.0,
    )
    output[i] = m
49
+
50
+
51
# Naive dense matrix product c = a @ b, one thread per output element.
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK
    b: wp.array2d(dtype=float),  # KxM
    c: wp.array2d(dtype=float),  # NxM
):
    # launch dims should be (N, M); threads outside that range do nothing
    row, col = wp.tid()
    n_rows = a.shape[0]
    inner = a.shape[1]
    n_cols = b.shape[1]
    if row < n_rows and col < n_cols:
        acc = wp.float32(0)
        for k in range(inner):
            acc += a[row, k] * b[k, col]
        c[row, col] = acc
67
+
68
+
69
# Multiply every wp.vec2 element of an array by the scalar s (one thread per element).
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    i = wp.tid()
    scaled = a[i] * s
    output[i] = scaled
73
+
74
+
75
def example1():
    """Two inputs and one output: add two int32 arrays with ``add_kernel`` from JAX."""
    jax_add = jax_kernel(add_kernel)

    @jax.jit
    def f():
        size = 10
        lhs = jnp.arange(size, dtype=jnp.int32)
        rhs = jnp.ones(size, dtype=jnp.int32)
        return jax_add(lhs, rhs)

    result = f()
    print(result)
87
+
88
+
89
def example2():
    """One input and two outputs: sine and cosine of evenly spaced angles via ``sincos_kernel``.

    The angles are created as float32 explicitly, matching the kernel's ``float``
    (float32) arrays and the explicit dtypes used by the other examples. Without
    it, ``jnp.linspace`` returns float64 when JAX's 64-bit mode (``jax_enable_x64``)
    is enabled, and the input would no longer match the kernel.
    """
    jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)

    @jax.jit
    def f():
        n = 32
        a = jnp.linspace(0, 2 * math.pi, n, dtype=jnp.float32)
        return jax_sincos(a)

    s, c = f()
    print(s)
    print()
    print(c)
103
+
104
+
105
def example3():
    """Scale an array of wp.vec2 by a scalar that is a constant inside the jitted function."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    @jax.jit
    def f():
        vecs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # array of vec2
        return jax_scale_vec(vecs, 2.0)

    print(f())
117
+
118
+
119
def example4():
    """Scale an array of wp.vec2 by a scalar passed as a static argument of the jitted function."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, s):
        return jax_scale_vec(a, s)

    vecs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # array of vec2
    print(f(vecs, 3.0))
133
+
134
+
135
def example5():
    """Matrix multiply with default launch dimensions fixed when the kernel is wrapped."""
    rows, cols, inner = 3, 4, 2

    # default launch dims: one thread per element of the (rows, cols) result
    jax_matmul = jax_kernel(matmul_kernel, launch_dims=(rows, cols))

    @jax.jit
    def f():
        lhs = jnp.full((rows, inner), 2, dtype=jnp.float32)
        rhs = jnp.full((inner, cols), 3, dtype=jnp.float32)
        return jax_matmul(lhs, rhs)  # relies on the default launch dims

    print(f())
150
+
151
+
152
def example6():
    """Matrix multiply with launch dimensions supplied at each call instead of at wrap time."""
    jax_matmul = jax_kernel(matmul_kernel)

    @jax.jit
    def f():
        def run(rows, cols, inner):
            lhs = jnp.full((rows, inner), 2, dtype=jnp.float32)
            rhs = jnp.full((inner, cols), 3, dtype=jnp.float32)
            # custom launch dims: one thread per element of the (rows, cols) result
            return jax_matmul(lhs, rhs, launch_dims=(rows, cols))

        return run(3, 4, 2), run(4, 3, 2)

    first, second = f()
    print(first)
    print()
    print(second)
178
+
179
+
180
def example7():
    """No inputs and one output: the launch dimensions alone determine the output size."""
    jax_diagonal = jax_kernel(diagonal_kernel)

    @jax.jit
    def f():
        return jax_diagonal(launch_dims=4)

    print(f())
190
+
191
+
192
def main():
    """Initialize Warp, load this module's kernels on the default device, then run every example."""
    wp.init()
    # called directly here (not from a helper) so Warp loads the module that contains main()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="
    for example in (example1, example2, example3, example4, example5, example6, example7):
        print(banner)
        print(f"{example.__name__}:")
        example()


if __name__ == "__main__":
    main()
@@ -0,0 +1,266 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Sim Grad Bounce
18
+ #
19
+ # Shows how to use Warp to optimize the initial velocity of a particle
20
+ # such that it bounces off the wall and floor in order to hit a target.
21
+ #
22
+ # This example uses the built-in wp.Tape() object to compute gradients of
23
+ # the distance to target (loss) w.r.t. the initial velocity, followed by
24
+ # a simple gradient-descent optimization step.
25
+ #
26
+ ###########################################################################
27
+
28
+ import numpy as np
29
+
30
+ import warp as wp
31
+ import warp.sim
32
+ import warp.sim.render
33
+
34
+
35
@wp.kernel
def loss_kernel(pos: wp.array(dtype=wp.vec3), target: wp.vec3, loss: wp.array(dtype=float)):
    # Squared distance between the first particle position and the target,
    # written to loss[0]. Only index 0 is read/written, so launch with dim=1.
    offset = pos[0] - target
    loss[0] = wp.dot(offset, offset)
40
+
41
+
42
@wp.kernel
def step_kernel(x: wp.array(dtype=wp.vec3), grad: wp.array(dtype=wp.vec3), alpha: float):
    # One thread per element: plain gradient-descent update x <- x - grad * alpha,
    # performed in place on x.
    i = wp.tid()
    g = grad[i]
    x[i] = x[i] - g * alpha
48
+
49
+
50
class Example:
    """Optimize a particle's initial velocity so that it bounces into a target.

    The forward pass steps a differentiable ``wp.sim`` model for a fixed
    duration and writes the squared distance between the particle's final
    position and ``self.target`` into ``self.loss``. ``wp.Tape`` back-propagates
    that loss to the initial velocity (``states[0].particle_qd``), which is then
    updated with a gradient-descent step.
    """

    def __init__(self, stage_path="example_bounce.usd", verbose=False):
        """Build the model and trajectory states; on CUDA, capture forward+backward in a graph.

        Args:
            stage_path: Output USD path for the renderer; a falsy value disables rendering.
            verbose: If ``True``, :meth:`step` prints the loss, velocity and gradient.
        """
        self.verbose = verbose

        # seconds
        sim_duration = 0.6

        # control frequency
        fps = 60
        self.frame_dt = 1.0 / fps
        frame_steps = int(sim_duration / self.frame_dt)

        # sim frequency
        self.sim_substeps = 8
        self.sim_steps = frame_steps * self.sim_substeps
        self.sim_dt = self.frame_dt / self.sim_substeps

        self.iter = 0
        self.render_time = 0.0

        self.train_rate = 0.02

        # contact coefficients, used both for the wall box and for the
        # model's soft-contact settings below
        ke = 1.0e4
        kf = 0.0
        kd = 1.0e1
        mu = 0.2

        builder = wp.sim.ModelBuilder()
        builder.add_particle(pos=wp.vec3(-0.5, 1.0, 0.0), vel=wp.vec3(5.0, -5.0, 0.0), mass=1.0)
        builder.add_shape_box(body=-1, pos=wp.vec3(2.0, 1.0, 0.0), hx=0.25, hy=1.0, hz=1.0, ke=ke, kf=kf, kd=kd, mu=mu)

        # use `requires_grad=True` to create a model for differentiable simulation
        self.model = builder.finalize(requires_grad=True)
        self.model.ground = True

        self.model.soft_contact_ke = ke
        self.model.soft_contact_kf = kf
        self.model.soft_contact_kd = kd
        self.model.soft_contact_mu = mu
        self.model.soft_contact_margin = 10.0
        self.model.soft_contact_restitution = 1.0

        self.integrator = wp.sim.SemiImplicitIntegrator()

        self.target = (-2.0, 1.5, 0.0)
        self.loss = wp.zeros(1, dtype=wp.float32, requires_grad=True)

        # allocate sim states for trajectory: one per sim step plus the initial
        # state, so step i reads states[i] and writes states[i + 1] without
        # overwriting anything earlier in the trajectory
        self.states = [self.model.state() for _ in range(self.sim_steps + 1)]

        # one-shot contact creation (valid if we're doing simple collision against a constant normal plane)
        wp.sim.collide(self.model, self.states[0])

        if stage_path:
            self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=1.0)
        else:
            self.renderer = None

        # capture forward/backward passes
        self.use_cuda_graph = wp.get_device().is_cuda
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self.tape = wp.Tape()
                with self.tape:
                    self.forward()
                self.tape.backward(self.loss)
            self.graph = capture.graph

    def forward(self):
        """Simulate the full trajectory from ``states[0]`` and compute the loss.

        Returns:
            ``self.loss``, whose element 0 holds the squared distance between the
            particle's final position and ``self.target``.
        """
        # run control loop
        for i in range(self.sim_steps):
            self.states[i].clear_forces()
            self.integrator.simulate(self.model, self.states[i], self.states[i + 1], self.sim_dt)

        # compute loss on final state
        wp.launch(loss_kernel, dim=1, inputs=[self.states[-1].particle_q, self.target, self.loss])

        return self.loss

    def step(self):
        """Run one forward/backward pass, then update the initial velocity by gradient descent."""
        with wp.ScopedTimer("step"):
            if self.use_cuda_graph:
                # replay the forward+backward passes captured in __init__
                wp.capture_launch(self.graph)
            else:
                self.tape = wp.Tape()
                with self.tape:
                    self.forward()
                self.tape.backward(self.loss)

            # gradient descent step
            x = self.states[0].particle_qd
            wp.launch(step_kernel, dim=len(x), inputs=[x, x.grad, self.train_rate])

            x_grad = self.tape.gradients[self.states[0].particle_qd]

            if self.verbose:
                print(f"Iter: {self.iter} Loss: {self.loss}")
                print(f" x: {x} g: {x_grad}")

            # clear grads for next iteration
            self.tape.zero()

            self.iter = self.iter + 1

    def render(self):
        """Write one USD frame per sampled trajectory state, plus the target box and trajectory line.

        Does nothing when no renderer was created (``stage_path`` was falsy).
        """
        if self.renderer is None:
            return

        with wp.ScopedTimer("render"):
            from pxr import Gf, UsdGeom

            # the loss does not change while rendering: read it back from the
            # device once instead of once per frame
            traj_color = wp.render.bourke_color_map(0.0, 7.0, self.loss.numpy()[0])

            # draw trajectory
            traj_verts = [self.states[0].particle_q.numpy()[0].tolist()]

            for i in range(0, self.sim_steps, self.sim_substeps):
                traj_verts.append(self.states[i].particle_q.numpy()[0].tolist())

                self.renderer.begin_frame(self.render_time)
                self.renderer.render(self.states[i])
                self.renderer.render_box(
                    pos=self.target,
                    rot=wp.quat_identity(),
                    extents=(0.1, 0.1, 0.1),
                    name="target",
                    color=(0.0, 0.0, 0.0),
                )
                self.renderer.render_line_strip(
                    vertices=traj_verts,
                    color=traj_color,
                    radius=0.02,
                    name=f"traj_{self.iter - 1}",
                )
                self.renderer.end_frame()

                # the particle prim is looked up per frame because it is only
                # created by the renderer.render() call above
                particles_prim = self.renderer.stage.GetPrimAtPath("/root/particles")
                particles = UsdGeom.Points.Get(self.renderer.stage, particles_prim.GetPath())
                particles.CreateDisplayColorAttr().Set([Gf.Vec3f(1.0, 1.0, 1.0)], time=self.renderer.time)

                self.render_time += self.frame_dt

    def check_grad(self):
        """Print central-difference and tape-based gradients of the loss w.r.t. the initial velocity.

        The initial velocity is restored to its original value before the
        analytic gradient is computed, and the tape's gradients are zeroed at the end.
        """
        param = self.states[0].particle_qd

        # initial value
        x_c = param.numpy().flatten()

        # central-difference step size
        eps = 1.0e-3

        # compute numeric gradient
        x_grad_numeric = np.zeros_like(x_c)

        for i in range(len(x_c)):
            offset = np.zeros_like(x_c)
            offset[i] = eps

            param.assign(x_c + offset)
            l_1 = self.forward().numpy()[0]

            param.assign(x_c - offset)
            l_0 = self.forward().numpy()[0]

            x_grad_numeric[i] = (l_1 - l_0) / (eps * 2.0)

        # reset initial state
        param.assign(x_c)

        # compute analytic gradient
        tape = wp.Tape()
        with tape:
            loss = self.forward()

        tape.backward(loss)

        x_grad_analytic = tape.gradients[param]

        print(f"numeric grad: {x_grad_numeric}")
        print(f"analytic grad: {x_grad_analytic}")

        tape.zero()
236
+
237
+
238
if __name__ == "__main__":
    import argparse

    def _optional_path(value):
        # the literal string "None" disables USD output
        return None if value == "None" else str(value)

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument(
        "--stage_path",
        type=_optional_path,
        default="example_bounce.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument("--train_iters", type=int, default=250, help="Total number of training iterations.")
    parser.add_argument("--verbose", action="store_true", help="Print out additional status messages during execution.")

    # unknown arguments are ignored rather than rejected
    args, _unknown = parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path, verbose=args.verbose)

        # compare tape gradients against finite differences before training
        example.check_grad()

        # replay and optimize, rendering every 16th iteration
        for iteration in range(args.train_iters):
            example.step()
            if iteration % 16 == 0:
                example.render()

        if example.renderer:
            example.renderer.save()