warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,166 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import time
17
+
18
+ import torch
19
+
20
+ import warp as wp
21
+
22
+
23
def create_simple_kernel(dtype):
    """Build a Warp kernel that takes five arrays of element type ``dtype`` and does nothing.

    The kernel body is empty, so launching it times only argument handling and launch
    overhead, not any computation.

    Args:
        dtype: Warp element type used for all five array parameters.

    Returns:
        A ``wp.Kernel`` wrapping the inner ``simple_kernel`` function.
    """

    # The annotations close over ``dtype``, so each call produces a kernel specialized for
    # the requested element type.
    def simple_kernel(
        a: wp.array(dtype=dtype),
        b: wp.array(dtype=dtype),
        c: wp.array(dtype=dtype),
        d: wp.array(dtype=dtype),
        e: wp.array(dtype=dtype),
    ):
        pass

    # Wrap the plain Python function explicitly instead of decorating it.
    return wp.Kernel(simple_kernel)
34
+
35
+
36
+ def test_from_torch(kernel, num_iters, array_size, device, warp_dtype=None):
37
+ warp_device = wp.get_device(device)
38
+ torch_device = wp.device_to_torch(warp_device)
39
+
40
+ if hasattr(warp_dtype, "_shape_"):
41
+ torch_shape = (array_size, *warp_dtype._shape_)
42
+ torch_dtype = wp.dtype_to_torch(warp_dtype._wp_scalar_type_)
43
+ else:
44
+ torch_shape = (array_size,)
45
+ torch_dtype = torch.float32 if warp_dtype is None else wp.dtype_to_torch(warp_dtype)
46
+
47
+ _a = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
48
+ _b = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
49
+ _c = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
50
+ _d = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
51
+ _e = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
52
+
53
+ wp.synchronize()
54
+
55
+ # profiler = Profiler(interval=0.000001)
56
+ # profiler.start()
57
+
58
+ t1 = time.time_ns()
59
+
60
+ for _ in range(num_iters):
61
+ a = wp.from_torch(_a, dtype=warp_dtype)
62
+ b = wp.from_torch(_b, dtype=warp_dtype)
63
+ c = wp.from_torch(_c, dtype=warp_dtype)
64
+ d = wp.from_torch(_d, dtype=warp_dtype)
65
+ e = wp.from_torch(_e, dtype=warp_dtype)
66
+ wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])
67
+
68
+ t2 = time.time_ns()
69
+ print(f"{(t2 - t1) / 1_000_000:8.0f} ms from_torch(...)")
70
+
71
+ # profiler.stop()
72
+ # profiler.print()
73
+
74
+
75
+ def test_array_ctype_from_torch(kernel, num_iters, array_size, device, warp_dtype=None):
76
+ warp_device = wp.get_device(device)
77
+ torch_device = wp.device_to_torch(warp_device)
78
+
79
+ if hasattr(warp_dtype, "_shape_"):
80
+ torch_shape = (array_size, *warp_dtype._shape_)
81
+ torch_dtype = wp.dtype_to_torch(warp_dtype._wp_scalar_type_)
82
+ else:
83
+ torch_shape = (array_size,)
84
+ torch_dtype = torch.float32 if warp_dtype is None else wp.dtype_to_torch(warp_dtype)
85
+
86
+ _a = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
87
+ _b = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
88
+ _c = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
89
+ _d = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
90
+ _e = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
91
+
92
+ wp.synchronize()
93
+
94
+ # profiler = Profiler(interval=0.000001)
95
+ # profiler.start()
96
+
97
+ t1 = time.time_ns()
98
+
99
+ for _ in range(num_iters):
100
+ a = wp.from_torch(_a, dtype=warp_dtype, return_ctype=True)
101
+ b = wp.from_torch(_b, dtype=warp_dtype, return_ctype=True)
102
+ c = wp.from_torch(_c, dtype=warp_dtype, return_ctype=True)
103
+ d = wp.from_torch(_d, dtype=warp_dtype, return_ctype=True)
104
+ e = wp.from_torch(_e, dtype=warp_dtype, return_ctype=True)
105
+ wp.launch(kernel, dim=array_size, inputs=[a, b, c, d, e])
106
+
107
+ t2 = time.time_ns()
108
+ print(f"{(t2 - t1) / 1_000_000:8.0f} ms from_torch(..., return_ctype=True)")
109
+
110
+ # profiler.stop()
111
+ # profiler.print()
112
+
113
+
114
+ def test_direct_from_torch(kernel, num_iters, array_size, device, warp_dtype=None):
115
+ warp_device = wp.get_device(device)
116
+ torch_device = wp.device_to_torch(warp_device)
117
+
118
+ if hasattr(warp_dtype, "_shape_"):
119
+ torch_shape = (array_size, *warp_dtype._shape_)
120
+ torch_dtype = wp.dtype_to_torch(warp_dtype._wp_scalar_type_)
121
+ else:
122
+ torch_shape = (array_size,)
123
+ torch_dtype = torch.float32 if warp_dtype is None else wp.dtype_to_torch(warp_dtype)
124
+
125
+ _a = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
126
+ _b = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
127
+ _c = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
128
+ _d = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
129
+ _e = torch.zeros(torch_shape, dtype=torch_dtype, device=torch_device)
130
+
131
+ wp.synchronize()
132
+
133
+ # profiler = Profiler(interval=0.000001)
134
+ # profiler.start()
135
+
136
+ t1 = time.time_ns()
137
+
138
+ for _ in range(num_iters):
139
+ wp.launch(kernel, dim=array_size, inputs=[_a, _b, _c, _d, _e])
140
+
141
+ t2 = time.time_ns()
142
+ print(f"{(t2 - t1) / 1_000_000:8.0f} ms direct from torch")
143
+
144
+ # profiler.stop()
145
+ # profiler.print()
146
+
147
+
148
if __name__ == "__main__":
    # Guarded so that importing this module does not launch several hundred thousand
    # kernels, matching the other benchmark scripts.
    wp.init()

    # Each entry pairs the dtype forwarded to the benchmark functions (None = default
    # float32 handling) with a kernel whose array parameters use the matching element type.
    params = [
        # (warp_dtype arg, kernel)
        (None, create_simple_kernel(wp.float32)),
        (wp.float32, create_simple_kernel(wp.float32)),
        (wp.vec3f, create_simple_kernel(wp.vec3f)),
        (wp.mat22f, create_simple_kernel(wp.mat22f)),
    ]

    # Load (and if necessary build) this module's kernels before any timing starts.
    wp.load_module()

    num_iters = 100000
    array_size = 10
    device = "cuda:0"

    for warp_dtype, kernel in params:
        print(f"\ndtype={wp.context.type_str(warp_dtype)}")
        test_from_torch(kernel, num_iters, array_size, device, warp_dtype=warp_dtype)
        test_array_ctype_from_torch(kernel, num_iters, array_size, device, warp_dtype=warp_dtype)
        test_direct_from_torch(kernel, num_iters, array_size, device, warp_dtype=warp_dtype)
@@ -0,0 +1,301 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Benchmarks for kernel launches with different types of args
18
+ ###########################################################################
19
+
20
+ import warp as wp
21
+
22
+
23
# Struct with no fields; passed to ``ks0`` to time launches whose only argument is an
# empty struct.
@wp.struct
class S0:
    pass
26
+
27
+
28
# Struct counterpart of the three scalar ``float`` arguments of kernel ``kf``; passed to
# ``ksf``.
@wp.struct
class Sf:
    x: float
    y: float
    z: float
33
+
34
+
35
# Struct counterpart of the three ``wp.vec3`` arguments of kernel ``kv``; passed to ``ksv``.
@wp.struct
class Sv:
    u: wp.vec3
    v: wp.vec3
    w: wp.vec3
40
+
41
+
42
# Struct counterpart of the three ``wp.mat33`` arguments of kernel ``km``; passed to ``ksm``.
@wp.struct
class Sm:
    M: wp.mat33
    N: wp.mat33
    O: wp.mat33
47
+
48
+
49
# Struct counterpart of the three float-array arguments of kernel ``ka``; passed to ``ksa``.
@wp.struct
class Sa:
    a: wp.array(dtype=float)
    b: wp.array(dtype=float)
    c: wp.array(dtype=float)
54
+
55
+
56
# Struct counterpart of the mixed argument list of kernel ``kz`` (three float arrays,
# three floats, three vec3s); passed to ``ksz``.
@wp.struct
class Sz:
    # array fields
    a: wp.array(dtype=float)
    b: wp.array(dtype=float)
    c: wp.array(dtype=float)
    # scalar fields
    x: float
    y: float
    z: float
    # vector fields
    u: wp.vec3
    v: wp.vec3
    w: wp.vec3
67
+
68
+
69
# Kernel with no arguments. Its body only reads the thread index (deliberately unused,
# hence the noqa), so its launches time the bare launch path.
@wp.kernel
def k0():
    tid = wp.tid()  # noqa: F841
72
+
73
+
74
# No-op kernel taking three scalar float arguments; the direct counterpart of ``ksf``.
@wp.kernel
def kf(x: float, y: float, z: float):
    tid = wp.tid()  # noqa: F841
77
+
78
+
79
# No-op kernel taking three vec3 arguments; the direct counterpart of ``ksv``.
@wp.kernel
def kv(u: wp.vec3, v: wp.vec3, w: wp.vec3):
    tid = wp.tid()  # noqa: F841
82
+
83
+
84
# No-op kernel taking three mat33 arguments; the direct counterpart of ``ksm``.
@wp.kernel
def km(M: wp.mat33, N: wp.mat33, O: wp.mat33):
    tid = wp.tid()  # noqa: F841
87
+
88
+
89
# No-op kernel taking three float arrays; the direct counterpart of ``ksa``.
@wp.kernel
def ka(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float)):
    tid = wp.tid()  # noqa: F841
92
+
93
+
94
# No-op kernel taking a mix of arrays, scalars and vectors (nine arguments); the direct
# counterpart of ``ksz``.
@wp.kernel
def kz(
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    c: wp.array(dtype=float),
    x: float,
    y: float,
    z: float,
    u: wp.vec3,
    v: wp.vec3,
    w: wp.vec3,
):
    tid = wp.tid()  # noqa: F841
107
+
108
+
109
# No-op kernel whose single argument is the empty struct ``S0``; compared against ``k0``.
@wp.kernel
def ks0(s: S0):
    tid = wp.tid()  # noqa: F841
112
+
113
+
114
# No-op kernel receiving the three floats packed in struct ``Sf``; compared against ``kf``.
@wp.kernel
def ksf(s: Sf):
    tid = wp.tid()  # noqa: F841
117
+
118
+
119
# No-op kernel receiving the three vec3s packed in struct ``Sv``; compared against ``kv``.
@wp.kernel
def ksv(s: Sv):
    tid = wp.tid()  # noqa: F841
122
+
123
+
124
# No-op kernel receiving the three mat33s packed in struct ``Sm``; compared against ``km``.
@wp.kernel
def ksm(s: Sm):
    tid = wp.tid()  # noqa: F841
127
+
128
+
129
# No-op kernel receiving the three arrays packed in struct ``Sa``; compared against ``ka``.
@wp.kernel
def ksa(s: Sa):
    tid = wp.tid()  # noqa: F841
132
+
133
+
134
# No-op kernel receiving the mixed fields packed in struct ``Sz``; compared against ``kz``.
@wp.kernel
def ksz(s: Sz):
    tid = wp.tid()  # noqa: F841
137
+
138
+
139
def _time_launches(name, kernel, inputs, num_launches):
    """Time ``num_launches`` single-thread launches of ``kernel`` with ``inputs``.

    The current device is synchronized first, so work left over from a previous
    measurement is not attributed to this one.

    Returns:
        The ``wp.ScopedTimer`` named ``name``, so the caller can read its ``elapsed`` value.
    """
    wp.synchronize_device()
    timer = wp.ScopedTimer(name)
    with timer:
        for _ in range(num_launches):
            wp.launch(kernel, dim=1, inputs=inputs)
    return timer


def _print_summary(timers):
    """Print one table row per ``(direct_timer, struct_timer)`` pair.

    Every column width comes from a format spec, so the header, data rows and rules all
    have the same width. Timer names longer than 4 characters would widen their row.
    """
    rule = "-" * 32
    print(rule)
    print(f"| {'args':<4} |{'direct':>10} |{'struct':>10} |")
    print(rule)
    for tk, ts in timers:
        print(f"| {tk.name:<4} |{tk.elapsed:10.0f} |{ts.elapsed:10.0f} |")
    print(rule)


if __name__ == "__main__":
    wp.clear_kernel_cache()

    devices = wp.get_devices()
    num_launches = 100000

    for device in devices:
        with wp.ScopedDevice(device):
            print(f"\n=================== Device '{device}' ===================")

            wp.force_load(device)

            n = 1
            a = wp.zeros(n, dtype=float)
            b = wp.zeros(n, dtype=float)
            c = wp.zeros(n, dtype=float)
            x = 17.0
            y = 42.0
            z = 99.0
            u = wp.vec3(1, 2, 3)
            v = wp.vec3(10, 20, 30)
            w = wp.vec3(100, 200, 300)
            M = wp.mat33()
            N = wp.mat33()
            O = wp.mat33()

            s0 = S0()

            sf = Sf()
            sf.x = x
            sf.y = y
            sf.z = z

            sv = Sv()
            sv.u = u
            sv.v = v
            sv.w = w

            sm = Sm()
            sm.M = M
            sm.N = N
            sm.O = O

            sa = Sa()
            sa.a = a
            sa.b = b
            sa.c = c

            sz = Sz()
            sz.a = a
            sz.b = b
            sz.c = c
            sz.x = x
            sz.y = y
            sz.z = z
            sz.u = u
            sz.v = v
            sz.w = w

            # One row per argument set: (direct-argument case, struct-argument case),
            # each given as (timer name, kernel, launch inputs).
            cases = [
                (("k0", k0, []), ("s0", ks0, [s0])),
                (("kf", kf, [x, y, z]), ("sf", ksf, [sf])),
                (("kv", kv, [u, v, w]), ("sv", ksv, [sv])),
                (("km", km, [M, N, O]), ("sm", ksm, [sm])),
                (("ka", ka, [a, b, c]), ("sa", ksa, [sa])),
                (("kz", kz, [a, b, c, x, y, z, u, v, w]), ("sz", ksz, [sz])),
            ]

            # All direct-argument benchmarks run first, then all struct-argument ones.
            direct_timers = [_time_launches(*direct, num_launches) for direct, _ in cases]
            struct_timers = [_time_launches(*struct, num_launches) for _, struct in cases]

            wp.synchronize_device()

            _print_summary(list(zip(direct_timers, struct_timers)))
@@ -0,0 +1,103 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import numpy as np
17
+
18
+ import warp as wp
19
+
20
+ BLOCK_DIM = 128
21
+
22
+ TILE = 32
23
+
24
+
25
def create_test_kernel(storage_type: str):
    """Build a tiled Warp kernel that copies ``a`` into ``b`` through TILE x TILE tiles.

    Args:
        storage_type: Tile storage passed to ``wp.tile_load``; must be either
            ``"shared"`` or ``"register"``.

    Returns:
        A Warp kernel ``load_store(a, b)`` over two float32 2D arrays. Thread
        index ``(i, j)`` loads the tile at offset ``(i * TILE, j * TILE)`` of
        ``a`` and stores it at the same offset of ``b``.

    Raises:
        ValueError: If ``storage_type`` is not one of the supported values.
    """
    # Reject unknown values up front: otherwise anything that is not "shared"
    # would silently fall through to the "register" branch below and the
    # benchmark would report numbers under the wrong label.
    if storage_type not in ("shared", "register"):
        raise ValueError(f"Unsupported storage_type {storage_type!r}; expected 'shared' or 'register'")

    @wp.kernel
    def load_store(a: wp.array2d(dtype=wp.float32), b: wp.array2d(dtype=wp.float32)):
        i, j = wp.tid()

        # wp.static evaluates the condition during code generation, so only the
        # selected tile_load variant ends up in the compiled kernel.
        if wp.static(storage_type == "shared"):
            a_tile = wp.tile_load(a, shape=(TILE, TILE), offset=(i * TILE, j * TILE), storage="shared")
        else:
            a_tile = wp.tile_load(a, shape=(TILE, TILE), offset=(i * TILE, j * TILE), storage="register")

        wp.tile_store(b, a_tile, offset=(i * TILE, j * TILE))

    return load_store
38
+
39
+
40
if __name__ == "__main__":
    # Benchmark tile load/store bandwidth with shared- vs register-backed tiles,
    # and compare both against a plain wp.copy of the same array.
    wp.config.quiet = True
    wp.init()
    wp.clear_kernel_cache()
    wp.set_module_options({"fast_math": True, "enable_backward": False})

    iterations = 100
    warmup_iterations = 5
    rng = np.random.default_rng(42)

    storage_types = ("shared", "register")
    # Build each kernel variant once rather than re-creating it for every array size.
    kernels = {storage_type: create_test_kernel(storage_type) for storage_type in storage_types}

    # storage type -> {array capacity in bytes -> bandwidth in GiB/s}
    benchmark_data = {storage_type: {} for storage_type in storage_types}
    memcpy_benchmark_data = {}

    def average_bandwidth(num_bytes, timer):
        """Return the mean bandwidth in GiB/s over the activities recorded by ``timer``.

        Each timed operation reads and writes ``num_bytes`` bytes, hence the
        factor of 2; ``elapsed`` values are treated as milliseconds.
        """
        timing_results = [result.elapsed for result in timer.timing_results]
        return 2.0 * (num_bytes / (1024 * 1024 * 1024)) / (1e-3 * np.mean(timing_results))

    sizes = list(range(128, 4097, 128))

    header = f"{'Transfer Size (Bytes)':<23s} {'Shared (GiB/s)':<16s} {'Register (GiB/s)':<18s} {'memcpy (GiB/s)':<16s}"
    print(header)
    print("-" * len(header))

    for size in sizes:
        a = wp.array(rng.random((size, size), dtype=np.float32), dtype=wp.float32)
        b = wp.empty_like(a)

        for storage_type, load_store in kernels.items():
            cmd = wp.launch_tiled(
                load_store,
                dim=(a.shape[0] // TILE, a.shape[1] // TILE),
                inputs=[a],
                outputs=[b],
                block_dim=BLOCK_DIM,
                record_cmd=True,
            )

            # Clear the output so the correctness check below validates this
            # variant's writes, not data left in ``b`` by the previous variant.
            b.zero_()

            # Warmup
            for _ in range(warmup_iterations):
                cmd.launch()

            with wp.ScopedTimer("benchmark", cuda_filter=wp.TIMING_KERNEL, print=False, synchronize=True) as timer:
                for _ in range(iterations):
                    cmd.launch()

            np.testing.assert_equal(a.numpy(), b.numpy())

            benchmark_data[storage_type][a.capacity] = average_bandwidth(a.capacity, timer)

        # Compare with memcpy, warmed up the same way as the kernels.
        for _ in range(warmup_iterations):
            wp.copy(b, a)

        with wp.ScopedTimer("benchmark", cuda_filter=wp.TIMING_MEMCPY, print=False, synchronize=True) as timer:
            for _ in range(iterations):
                wp.copy(b, a)

        memcpy_benchmark_data[a.capacity] = average_bandwidth(a.capacity, timer)

        # Print results
        print(
            f"{a.capacity:<23d} {benchmark_data['shared'][a.capacity]:<#16.4g} "
            f"{benchmark_data['register'][a.capacity]:<#18.4g} {memcpy_benchmark_data[a.capacity]:<#16.4g}"
        )
@@ -0,0 +1,37 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import subprocess
18
+ import sys
19
+
20
+
21
def open_file(filename):
    """Open ``filename`` (a file or directory) with the platform's default handler.

    Uses ``os.startfile`` on Windows, the ``open`` command on macOS, and
    ``xdg-open`` on other platforms. The launcher's exit status is not checked;
    failure to start it (e.g. the command is not installed) raises ``OSError``.
    """
    if sys.platform == "win32":
        os.startfile(filename)
    elif sys.platform == "darwin":
        # macOS provides ``open`` rather than the freedesktop ``xdg-open`` utility.
        subprocess.call(["open", filename])
    else:
        subprocess.call(["xdg-open", filename])
26
+
27
+
28
if __name__ == "__main__":
    import warp.examples

    source_dir = warp.examples.get_source_directory()
    print(f"Example source directory: {source_dir}")

    # Opening a file browser is best-effort (there may be no handler available),
    # so report a failure instead of crashing or silently discarding it.
    try:
        open_file(source_dir)
    except Exception as e:
        print(f"Unable to open the example source directory: {e}")
@@ -0,0 +1,86 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example CuPy
18
+ #
19
+ # The example demonstrates interoperability with CuPy on CUDA devices
20
+ # and NumPy on CPU devices.
21
+ ###########################################################################
22
+
23
+ import warp as wp
24
+
25
+
26
# In-place update y[i] = a * x[i] + y[i], one kernel thread per array element.
@wp.kernel
def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
    # The thread index selects the single element this thread updates.
    idx = wp.tid()
    scaled = a * x[idx]
    y[idx] = scaled + y[idx]
30
+
31
+
32
class Example:
    """Run the Warp ``saxpy`` kernel directly on CuPy (CUDA) or NumPy (CPU) arrays.

    Args:
        n: Number of elements in ``x`` and ``y``.
        a: Scale factor in the update ``y = a * x + y``.
    """

    def __init__(self, n=10, a=1.0):
        # Remember the device the arrays are allocated for, so step() launches
        # there even if the current Warp device changes after construction.
        self.device = wp.get_device()

        self.n = n
        self.a = a

        if self.device.is_cuda:
            # use CuPy arrays on CUDA devices
            import cupy as cp

            print(f"Using CuPy on device {self.device}")

            # tell CuPy to use the same device
            with cp.cuda.Device(self.device.ordinal):
                self.x = cp.arange(self.n, dtype=cp.float32)
                self.y = cp.ones(self.n, dtype=cp.float32)
        else:
            # use NumPy arrays on CPU
            import numpy as np

            print("Using NumPy on CPU")

            self.x = np.arange(self.n, dtype=np.float32)
            self.y = np.ones(self.n, dtype=np.float32)

    def step(self):
        """Apply ``y = a * x + y`` in place on the pre-allocated arrays."""
        # Launch a Warp kernel on the pre-allocated arrays.
        # When running on a CUDA device, these are CuPy arrays.
        # When running on the CPU, these are NumPy arrays.
        #
        # Note that the arrays can be passed to Warp kernels directly. Under the hood,
        # Warp uses the __cuda_array_interface__ and __array_interface__ protocols to
        # access the data.
        wp.launch(saxpy, dim=self.n, inputs=[self.x, self.y, self.a], device=self.device)

    def render(self):
        """Print the current contents of ``y``."""
        print(self.y)
70
+
71
+
72
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    arg_parser.add_argument("--num_frames", type=int, default=10, help="Total number of frames.")

    # Unrecognized command-line flags are ignored rather than rejected.
    args, _unrecognized = arg_parser.parse_known_args()

    with wp.ScopedDevice(args.device):
        example = Example()

        for _frame in range(args.num_frames):
            example.step()
            example.render()