warp-lang 1.7.0 (py3-none-manylinux_2_34_aarch64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic.
Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/tests/interop/test_torch.py
@@ -0,0 +1,1001 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import unittest
+
+ import numpy as np
+
+ import warp as wp
+ from warp.tests.unittest_utils import *
+
+
+ @wp.kernel
+ def op_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
+     tid = wp.tid()
+     y[tid] = 0.5 - x[tid] * 2.0
+
+
+ @wp.kernel
+ def inc(a: wp.array(dtype=float)):
+     tid = wp.tid()
+     a[tid] = a[tid] + 1.0
+
+
+ @wp.kernel
+ def inc_vector(a: wp.array(dtype=wp.vec3f)):
+     tid = wp.tid()
+     a[tid] = a[tid] + wp.vec3f(1.0)
+
+
+ @wp.kernel
+ def inc_matrix(a: wp.array(dtype=wp.mat22f)):
+     tid = wp.tid()
+     a[tid] = a[tid] + wp.mat22f(1.0)
+
+
+ @wp.kernel
+ def arange(start: int, step: int, a: wp.array(dtype=int)):
+     tid = wp.tid()
+     a[tid] = start + step * tid
+
+
+ # copy elements between non-contiguous 1d arrays of float
+ @wp.kernel
+ def copy1d_float_kernel(dst: wp.array(dtype=float), src: wp.array(dtype=float)):
+     i = wp.tid()
+     dst[i] = src[i]
+
+
+ # copy elements between non-contiguous 2d arrays of float
+ @wp.kernel
+ def copy2d_float_kernel(dst: wp.array2d(dtype=float), src: wp.array2d(dtype=float)):
+     i, j = wp.tid()
+     dst[i, j] = src[i, j]
+
+
+ # copy elements between non-contiguous 3d arrays of float
+ @wp.kernel
+ def copy3d_float_kernel(dst: wp.array3d(dtype=float), src: wp.array3d(dtype=float)):
+     i, j, k = wp.tid()
+     dst[i, j, k] = src[i, j, k]
+
+
+ # copy elements between non-contiguous 2d arrays of vec3
+ @wp.kernel
+ def copy2d_vec3_kernel(dst: wp.array2d(dtype=wp.vec3), src: wp.array2d(dtype=wp.vec3)):
+     i, j = wp.tid()
+     dst[i, j] = src[i, j]
+
+
+ # copy elements between non-contiguous 2d arrays of mat22
+ @wp.kernel
+ def copy2d_mat22_kernel(dst: wp.array2d(dtype=wp.mat22), src: wp.array2d(dtype=wp.mat22)):
+     i, j = wp.tid()
+     dst[i, j] = src[i, j]
+
+
+ def test_dtype_from_torch(test, device):
+     import torch
+
+     def test_conversions(torch_type, warp_type):
+         test.assertEqual(wp.dtype_from_torch(torch_type), warp_type)
+
+     test_conversions(torch.float16, wp.float16)
+     test_conversions(torch.float32, wp.float32)
+     test_conversions(torch.float64, wp.float64)
+     test_conversions(torch.int8, wp.int8)
+     test_conversions(torch.int16, wp.int16)
+     test_conversions(torch.int32, wp.int32)
+     test_conversions(torch.int64, wp.int64)
+     test_conversions(torch.uint8, wp.uint8)
+     test_conversions(torch.bool, wp.bool)
+
+
+ def test_dtype_to_torch(test, device):
+     import torch
+
+     def test_conversions(warp_type, torch_type):
+         test.assertEqual(wp.dtype_to_torch(warp_type), torch_type)
+
+     test_conversions(wp.float16, torch.float16)
+     test_conversions(wp.float32, torch.float32)
+     test_conversions(wp.float64, torch.float64)
+     test_conversions(wp.int8, torch.int8)
+     test_conversions(wp.int16, torch.int16)
+     test_conversions(wp.int32, torch.int32)
+     test_conversions(wp.int64, torch.int64)
+     test_conversions(wp.uint8, torch.uint8)
+     test_conversions(wp.uint16, torch.int16)
+     test_conversions(wp.uint32, torch.int32)
+     test_conversions(wp.uint64, torch.int64)
+     test_conversions(wp.bool, torch.bool)
+
+
+ def test_device_conversion(test, device):
+     torch_device = wp.device_to_torch(device)
+     warp_device = wp.device_from_torch(torch_device)
+     test.assertEqual(warp_device, device)
+
+
+ def test_torch_zerocopy(test, device):
+     import torch
+
+     a = wp.zeros(10, dtype=wp.float32, device=device)
+     t = wp.to_torch(a)
+     assert a.ptr == t.data_ptr()
+
+     torch_device = wp.device_to_torch(device)
+
+     t = torch.zeros(10, dtype=torch.float32, device=torch_device)
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+
+
+ def test_from_torch(test, device):
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     # automatically determine warp dtype
+     def wrap_scalar_tensor_implicit(torch_dtype, expected_warp_dtype):
+         t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
+         a = wp.from_torch(t)
+         assert a.dtype == expected_warp_dtype
+         assert a.shape == tuple(t.shape)
+
+     wrap_scalar_tensor_implicit(torch.float64, wp.float64)
+     wrap_scalar_tensor_implicit(torch.float32, wp.float32)
+     wrap_scalar_tensor_implicit(torch.float16, wp.float16)
+     wrap_scalar_tensor_implicit(torch.int64, wp.int64)
+     wrap_scalar_tensor_implicit(torch.int32, wp.int32)
+     wrap_scalar_tensor_implicit(torch.int16, wp.int16)
+     wrap_scalar_tensor_implicit(torch.int8, wp.int8)
+     wrap_scalar_tensor_implicit(torch.uint8, wp.uint8)
+     wrap_scalar_tensor_implicit(torch.bool, wp.bool)
+
+     # explicitly specify warp dtype
+     def wrap_scalar_tensor_explicit(torch_dtype, expected_warp_dtype):
+         t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
+         a = wp.from_torch(t, expected_warp_dtype)
+         assert a.dtype == expected_warp_dtype
+         assert a.shape == tuple(t.shape)
+
+     wrap_scalar_tensor_explicit(torch.float64, wp.float64)
+     wrap_scalar_tensor_explicit(torch.float32, wp.float32)
+     wrap_scalar_tensor_explicit(torch.float16, wp.float16)
+     wrap_scalar_tensor_explicit(torch.int64, wp.int64)
+     wrap_scalar_tensor_explicit(torch.int64, wp.uint64)
+     wrap_scalar_tensor_explicit(torch.int32, wp.int32)
+     wrap_scalar_tensor_explicit(torch.int32, wp.uint32)
+     wrap_scalar_tensor_explicit(torch.int16, wp.int16)
+     wrap_scalar_tensor_explicit(torch.int16, wp.uint16)
+     wrap_scalar_tensor_explicit(torch.int8, wp.int8)
+     wrap_scalar_tensor_explicit(torch.int8, wp.uint8)
+     wrap_scalar_tensor_explicit(torch.uint8, wp.uint8)
+     wrap_scalar_tensor_explicit(torch.uint8, wp.int8)
+     wrap_scalar_tensor_explicit(torch.bool, wp.uint8)
+     wrap_scalar_tensor_explicit(torch.bool, wp.int8)
+     wrap_scalar_tensor_explicit(torch.bool, wp.bool)
+
+     def wrap_vec_tensor(n, desired_warp_dtype):
+         t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, desired_warp_dtype)
+         assert a.dtype == desired_warp_dtype
+         assert a.shape == (10,)
+
+     wrap_vec_tensor(2, wp.vec2)
+     wrap_vec_tensor(3, wp.vec3)
+     wrap_vec_tensor(4, wp.vec4)
+     wrap_vec_tensor(6, wp.spatial_vector)
+     wrap_vec_tensor(7, wp.transform)
+
+     def wrap_mat_tensor(n, m, desired_warp_dtype):
+         t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, desired_warp_dtype)
+         assert a.dtype == desired_warp_dtype
+         assert a.shape == (10,)
+
+     wrap_mat_tensor(2, 2, wp.mat22)
+     wrap_mat_tensor(3, 3, wp.mat33)
+     wrap_mat_tensor(4, 4, wp.mat44)
+     wrap_mat_tensor(6, 6, wp.spatial_matrix)
+
+     def wrap_vec_tensor_with_grad(n, desired_warp_dtype):
+         t = torch.zeros((10, n), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
+         assert a.dtype == desired_warp_dtype
+         assert a.shape == (10,)
+
+     wrap_vec_tensor_with_grad(2, wp.vec2)
+     wrap_vec_tensor_with_grad(3, wp.vec3)
+     wrap_vec_tensor_with_grad(4, wp.vec4)
+     wrap_vec_tensor_with_grad(6, wp.spatial_vector)
+     wrap_vec_tensor_with_grad(7, wp.transform)
+
+     def wrap_mat_tensor_with_grad(n, m, desired_warp_dtype):
+         t = torch.zeros((10, n, m), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, desired_warp_dtype, requires_grad=True)
+         assert a.dtype == desired_warp_dtype
+         assert a.shape == (10,)
+
+     wrap_mat_tensor_with_grad(2, 2, wp.mat22)
+     wrap_mat_tensor_with_grad(3, 3, wp.mat33)
+     wrap_mat_tensor_with_grad(4, 4, wp.mat44)
+     wrap_mat_tensor_with_grad(6, 6, wp.spatial_matrix)
+
+
+ def test_array_ctype_from_torch(test, device):
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     # automatically determine warp dtype
+     def wrap_scalar_tensor_implicit(torch_dtype):
+         t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
+         a = wp.from_torch(t, return_ctype=True)
+         warp_dtype = wp.dtype_from_torch(torch_dtype)
+         ctype_size = ctypes.sizeof(warp_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == 0
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_scalar_tensor_implicit(torch.float64)
+     wrap_scalar_tensor_implicit(torch.float32)
+     wrap_scalar_tensor_implicit(torch.float16)
+     wrap_scalar_tensor_implicit(torch.int64)
+     wrap_scalar_tensor_implicit(torch.int32)
+     wrap_scalar_tensor_implicit(torch.int16)
+     wrap_scalar_tensor_implicit(torch.int8)
+     wrap_scalar_tensor_implicit(torch.uint8)
+     wrap_scalar_tensor_implicit(torch.bool)
+
+     # explicitly specify warp dtype
+     def wrap_scalar_tensor_explicit(torch_dtype, warp_dtype):
+         t = torch.zeros(10, dtype=torch_dtype, device=torch_device)
+         a = wp.from_torch(t, dtype=warp_dtype, return_ctype=True)
+         ctype_size = ctypes.sizeof(warp_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == 0
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_scalar_tensor_explicit(torch.float64, wp.float64)
+     wrap_scalar_tensor_explicit(torch.float32, wp.float32)
+     wrap_scalar_tensor_explicit(torch.float16, wp.float16)
+     wrap_scalar_tensor_explicit(torch.int64, wp.int64)
+     wrap_scalar_tensor_explicit(torch.int64, wp.uint64)
+     wrap_scalar_tensor_explicit(torch.int32, wp.int32)
+     wrap_scalar_tensor_explicit(torch.int32, wp.uint32)
+     wrap_scalar_tensor_explicit(torch.int16, wp.int16)
+     wrap_scalar_tensor_explicit(torch.int16, wp.uint16)
+     wrap_scalar_tensor_explicit(torch.int8, wp.int8)
+     wrap_scalar_tensor_explicit(torch.int8, wp.uint8)
+     wrap_scalar_tensor_explicit(torch.uint8, wp.uint8)
+     wrap_scalar_tensor_explicit(torch.uint8, wp.int8)
+     wrap_scalar_tensor_explicit(torch.bool, wp.uint8)
+     wrap_scalar_tensor_explicit(torch.bool, wp.int8)
+     wrap_scalar_tensor_explicit(torch.bool, wp.bool)
+
+     def wrap_vec_tensor(vec_dtype):
+         t = torch.zeros((10, vec_dtype._length_), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, dtype=vec_dtype, return_ctype=True)
+         ctype_size = ctypes.sizeof(vec_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == 0
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_vec_tensor(wp.vec2)
+     wrap_vec_tensor(wp.vec3)
+     wrap_vec_tensor(wp.vec4)
+     wrap_vec_tensor(wp.spatial_vector)
+     wrap_vec_tensor(wp.transform)
+
+     def wrap_mat_tensor(mat_dtype):
+         t = torch.zeros((10, *mat_dtype._shape_), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, dtype=mat_dtype, return_ctype=True)
+         ctype_size = ctypes.sizeof(mat_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == 0
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_mat_tensor(wp.mat22)
+     wrap_mat_tensor(wp.mat33)
+     wrap_mat_tensor(wp.mat44)
+     wrap_mat_tensor(wp.spatial_matrix)
+
+     def wrap_vec_tensor_with_existing_grad(vec_dtype):
+         t = torch.zeros((10, vec_dtype._length_), dtype=torch.float32, device=torch_device, requires_grad=True)
+         t.grad = torch.zeros((10, vec_dtype._length_), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, dtype=vec_dtype, return_ctype=True)
+         ctype_size = ctypes.sizeof(vec_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == t.grad.data_ptr()
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_vec_tensor_with_existing_grad(wp.vec2)
+     wrap_vec_tensor_with_existing_grad(wp.vec3)
+     wrap_vec_tensor_with_existing_grad(wp.vec4)
+     wrap_vec_tensor_with_existing_grad(wp.spatial_vector)
+     wrap_vec_tensor_with_existing_grad(wp.transform)
+
+     def wrap_vec_tensor_with_new_grad(vec_dtype):
+         t = torch.zeros((10, vec_dtype._length_), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, dtype=vec_dtype, requires_grad=True, return_ctype=True)
+         ctype_size = ctypes.sizeof(vec_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == t.grad.data_ptr()
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_vec_tensor_with_new_grad(wp.vec2)
+     wrap_vec_tensor_with_new_grad(wp.vec3)
+     wrap_vec_tensor_with_new_grad(wp.vec4)
+     wrap_vec_tensor_with_new_grad(wp.spatial_vector)
+     wrap_vec_tensor_with_new_grad(wp.transform)
+
+     def wrap_vec_tensor_with_torch_grad(vec_dtype):
+         t = torch.zeros((10, vec_dtype._length_), dtype=torch.float32, device=torch_device)
+         grad = torch.zeros((10, vec_dtype._length_), dtype=torch.float32, device=torch_device)
+         a = wp.from_torch(t, dtype=vec_dtype, grad=grad, return_ctype=True)
+         ctype_size = ctypes.sizeof(vec_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == grad.data_ptr()
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_vec_tensor_with_torch_grad(wp.vec2)
+     wrap_vec_tensor_with_torch_grad(wp.vec3)
+     wrap_vec_tensor_with_torch_grad(wp.vec4)
+     wrap_vec_tensor_with_torch_grad(wp.spatial_vector)
+     wrap_vec_tensor_with_torch_grad(wp.transform)
+
+     def wrap_vec_tensor_with_warp_grad(vec_dtype):
+         t = torch.zeros((10, vec_dtype._length_), dtype=torch.float32, device=torch_device)
+         grad = wp.zeros(10, dtype=vec_dtype, device=device)
+         a = wp.from_torch(t, dtype=vec_dtype, grad=grad, return_ctype=True)
+         ctype_size = ctypes.sizeof(vec_dtype._type_)
+         assert a.data == t.data_ptr()
+         assert a.grad == grad.ptr
+         assert a.ndim == 1
+         assert a.shape[0] == t.shape[0]
+         assert a.strides[0] == t.stride()[0] * ctype_size
+
+     wrap_vec_tensor_with_warp_grad(wp.vec2)
+     wrap_vec_tensor_with_warp_grad(wp.vec3)
+     wrap_vec_tensor_with_warp_grad(wp.vec4)
+     wrap_vec_tensor_with_warp_grad(wp.spatial_vector)
+     wrap_vec_tensor_with_warp_grad(wp.transform)
+
+
+ def test_cuda_array_interface(test, device):
+     # We should be able to construct Torch tensors from Warp arrays via __cuda_array_interface__ on GPU.
+     # Note that Torch does not support __array_interface__ on CPU.
+
+     torch_device = wp.device_to_torch(device)
+     n = 10
+
+     # test the types supported by both Warp and Torch
+     scalar_types = [wp.float16, wp.float32, wp.float64, wp.int8, wp.int16, wp.int32, wp.int64, wp.uint8]
+
+     for dtype in scalar_types:
+         # test round trip
+         a1 = wp.zeros(n, dtype=dtype, device=device)
+         t = torch.tensor(a1, device=torch_device)
+         a2 = wp.array(t, device=device)
+
+         assert a1.dtype == a2.dtype
+         assert a1.shape == a2.shape
+         assert a1.strides == a2.strides
+
+
+ @wp.kernel
+ def vec_sum_kernel(x: wp.array(dtype=wp.vec3), y: wp.array(dtype=wp.vec3), z: wp.array(dtype=wp.vec3)):
+     tid = wp.tid()
+     z[tid] = x[tid] + y[tid]
+
+
+ # ensure torch arrays passed to Warp kernels are unchanged by Tape.backward()
+ def test_tensor_in_warp_kernel(test, device):
+     torch_device = wp.device_to_torch(device)
+
+     x = torch.ones((10, 3), dtype=torch.float32, device=torch_device)
+     y = torch.ones((10, 3), dtype=torch.float32, device=torch_device)
+     wp_y = wp.from_torch(y, dtype=wp.vec3, requires_grad=True)
+     z = torch.zeros((10, 3), dtype=torch.float32, device=torch_device)
+     wp_z = wp.from_torch(z, dtype=wp.vec3, requires_grad=True)
+
+     tape = wp.Tape()
+
+     with tape:
+         wp.launch(vec_sum_kernel, dim=10, inputs=[x, wp_y], outputs=[wp_z], device=device)
+
+     assert_np_equal(x.cpu().numpy(), np.ones((10, 3), dtype=float))
+
+     tape.backward(grads={wp_z: wp.ones_like(wp_z)})
+
+     # x is unchanged by Tape.backward()
+     assert_np_equal(x.cpu().numpy(), np.ones((10, 3), dtype=float))
+
+     # we can still compute the gradient of y because Warp created an array for it
+     assert_np_equal(y.grad.cpu().numpy(), np.ones((10, 3), dtype=float))
+
+
+ def test_to_torch(test, device):
+     import torch
+
+     def wrap_scalar_array(warp_dtype, expected_torch_dtype):
+         a = wp.zeros(10, dtype=warp_dtype, device=device)
+         t = wp.to_torch(a)
+         assert t.dtype == expected_torch_dtype
+         assert tuple(t.shape) == a.shape
+
+     wrap_scalar_array(wp.float64, torch.float64)
+     wrap_scalar_array(wp.float32, torch.float32)
+     wrap_scalar_array(wp.float16, torch.float16)
+     wrap_scalar_array(wp.int64, torch.int64)
+     wrap_scalar_array(wp.int32, torch.int32)
+     wrap_scalar_array(wp.int16, torch.int16)
+     wrap_scalar_array(wp.int8, torch.int8)
+     wrap_scalar_array(wp.uint8, torch.uint8)
+     wrap_scalar_array(wp.bool, torch.bool)
+
+     # not supported by torch
+     # wrap_scalar_array(wp.uint64, torch.int64)
+     # wrap_scalar_array(wp.uint32, torch.int32)
+     # wrap_scalar_array(wp.uint16, torch.int16)
+
+     def wrap_vec_array(n, warp_dtype):
+         a = wp.zeros(10, dtype=warp_dtype, device=device)
+         t = wp.to_torch(a)
+         assert t.dtype == torch.float32
+         assert tuple(t.shape) == (10, n)
+
+     wrap_vec_array(2, wp.vec2)
+     wrap_vec_array(3, wp.vec3)
+     wrap_vec_array(4, wp.vec4)
+     wrap_vec_array(6, wp.spatial_vector)
+     wrap_vec_array(7, wp.transform)
+
+     def wrap_mat_array(n, m, warp_dtype):
+         a = wp.zeros(10, dtype=warp_dtype, device=device)
+         t = wp.to_torch(a)
+         assert t.dtype == torch.float32
+         assert tuple(t.shape) == (10, n, m)
+
+     wrap_mat_array(2, 2, wp.mat22)
+     wrap_mat_array(3, 3, wp.mat33)
+     wrap_mat_array(4, 4, wp.mat44)
+     wrap_mat_array(6, 6, wp.spatial_matrix)
+
+
+ def test_from_torch_slices(test, device):
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     # 1D slice, contiguous
+     t_base = torch.arange(10, dtype=torch.float32, device=torch_device)
+     t = t_base[2:9]
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+     assert a.is_contiguous
+     assert a.shape == tuple(t.shape)
+     assert_np_equal(a.numpy(), t.cpu().numpy())
+
+     # 1D slice with non-contiguous stride
+     t_base = torch.arange(10, dtype=torch.float32, device=torch_device)
+     t = t_base[2:9:2]
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape)
+     # copy contents to contiguous array
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy1d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+     # 2D slices (non-contiguous)
+     t_base = torch.arange(24, dtype=torch.float32, device=torch_device).reshape((4, 6))
+     t = t_base[1:3, 2:5]
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape)
+     # copy contents to contiguous array
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy2d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+     # 3D slices (non-contiguous)
+     t_base = torch.arange(36, dtype=torch.float32, device=torch_device).reshape((4, 3, 3))
+     t = t_base[::2, 0:1, 1:2]
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape)
+     # copy contents to contiguous array
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+     # 2D slices of vec3 (inner contiguous, outer non-contiguous)
+     t_base = torch.arange(150, dtype=torch.float32, device=torch_device).reshape((10, 5, 3))
+     t = t_base[1:7:2, 2:5]
+     a = wp.from_torch(t, dtype=wp.vec3)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape[:-1])
+     # copy contents to contiguous array
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy2d_vec3_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+     # 2D slices of mat22 (inner contiguous, outer non-contiguous)
+     t_base = torch.arange(200, dtype=torch.float32, device=torch_device).reshape((10, 5, 2, 2))
+     t = t_base[1:7:2, 2:5]
+     a = wp.from_torch(t, dtype=wp.mat22)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape[:-2])
+     # copy contents to contiguous array
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy2d_mat22_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+
+ def test_from_torch_zero_strides(test, device):
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     t_base = torch.arange(9, dtype=torch.float32, device=torch_device).reshape((3, 3))
+
+     # expand outermost dimension
+     t = t_base.unsqueeze(0).expand(3, -1, -1)
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape)
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+     # expand middle dimension
+     t = t_base.unsqueeze(1).expand(-1, 3, -1)
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape)
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+     # expand innermost dimension
+     t = t_base.unsqueeze(2).expand(-1, -1, 3)
+     a = wp.from_torch(t)
+     assert a.ptr == t.data_ptr()
+     assert not a.is_contiguous
+     assert a.shape == tuple(t.shape)
+     a_contiguous = wp.empty_like(a)
+     wp.launch(copy3d_float_kernel, dim=a.shape, inputs=[a_contiguous, a], device=device)
+     assert_np_equal(a_contiguous.numpy(), t.cpu().numpy())
+
+
+ def test_torch_mgpu_from_torch(test, device):
+     import torch
+
+     n = 32
+
+     t0 = torch.arange(0, n, 1, dtype=torch.int32, device="cuda:0")
+     t1 = torch.arange(0, n * 2, 2, dtype=torch.int32, device="cuda:1")
+
+     a0 = wp.from_torch(t0, dtype=wp.int32)
+     a1 = wp.from_torch(t1, dtype=wp.int32)
+
+     assert a0.device == "cuda:0"
+     assert a1.device == "cuda:1"
+
+     expected0 = np.arange(0, n, 1)
+     expected1 = np.arange(0, n * 2, 2)
+
+     assert_np_equal(a0.numpy(), expected0)
+     assert_np_equal(a1.numpy(), expected1)
+
+
+ def test_torch_mgpu_to_torch(test, device):
+     n = 32
+
+     with wp.ScopedDevice("cuda:0"):
+         a0 = wp.empty(n, dtype=wp.int32)
+         wp.launch(arange, dim=a0.size, inputs=[0, 1, a0])
+
+     with wp.ScopedDevice("cuda:1"):
+         a1 = wp.empty(n, dtype=wp.int32)
+         wp.launch(arange, dim=a1.size, inputs=[0, 2, a1])
+
+     t0 = wp.to_torch(a0)
+     t1 = wp.to_torch(a1)
+
+     assert str(t0.device) == "cuda:0"
+     assert str(t1.device) == "cuda:1"
+
+     expected0 = np.arange(0, n, 1, dtype=np.int32)
+     expected1 = np.arange(0, n * 2, 2, dtype=np.int32)
+
+     assert_np_equal(t0.cpu().numpy(), expected0)
+     assert_np_equal(t1.cpu().numpy(), expected1)
+
+
+ def test_torch_mgpu_interop(test, device):
+     import torch
+
+     n = 1024 * 1024
+
+     with torch.cuda.device(0):
+         t0 = torch.arange(n, dtype=torch.float32, device="cuda")
+         a0 = wp.from_torch(t0)
+         wp.launch(inc, dim=a0.size, inputs=[a0], stream=wp.stream_from_torch())
+
+     with torch.cuda.device(1):
+         t1 = torch.arange(n, dtype=torch.float32, device="cuda")
+         a1 = wp.from_torch(t1)
+         wp.launch(inc, dim=a1.size, inputs=[a1], stream=wp.stream_from_torch())
+
+     assert a0.device == "cuda:0"
+     assert a1.device == "cuda:1"
+
+     expected = np.arange(n, dtype=int) + 1
+
+     # ensure the torch tensors were modified by warp
+     assert_np_equal(t0.cpu().numpy(), expected)
+     assert_np_equal(t1.cpu().numpy(), expected)
+
+
+ def test_torch_autograd(test, device):
+     """Test torch autograd with a custom Warp op"""
+
+     import torch
+
+     # custom autograd op
+     class TestFunc(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, x):
+             # allocate output array
+             y = torch.empty_like(x)
+
+             ctx.x = x
+             ctx.y = y
+
+             wp.launch(kernel=op_kernel, dim=len(x), inputs=[wp.from_torch(x)], outputs=[wp.from_torch(y)])
+
+             return y
+
+         @staticmethod
+         def backward(ctx, adj_y):
+             # adjoints should be allocated as zero initialized
+             adj_x = torch.zeros_like(ctx.x).contiguous()
+             adj_y = adj_y.contiguous()
+
+             wp_x = wp.from_torch(ctx.x, grad=adj_x)
+             wp_y = wp.from_torch(ctx.y, grad=adj_y)
+
+             wp.launch(
+                 kernel=op_kernel,
+                 dim=len(ctx.x),
+                 # fwd inputs
+                 inputs=[wp_x],
+                 outputs=[wp_y],
+                 # adj inputs (already stored in input/output arrays, passing null pointers)
+                 adj_inputs=[None],
+                 adj_outputs=[None],
+                 adjoint=True,
+             )
+
+             return adj_x
+
+     # run autograd on given device
+     with wp.ScopedDevice(device):
+         torch_device = wp.device_to_torch(device)
+
+         # input data
+         x = torch.ones(16, dtype=torch.float32, device=torch_device, requires_grad=True)
+
+         # execute op
+         y = TestFunc.apply(x)
+
+         # compute grads
+         l = y.sum()
+         l.backward()
+
+         passed = (x.grad == -2.0).all()
+         assert passed.item()
+
+
+ def test_torch_graph_torch_stream(test, device):
+     """Capture Torch graph on Torch stream"""
+
+     wp.load_module(device=device)
+
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     n = 1024 * 1024
+     t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+     a = wp.from_torch(t)
+
+     g = torch.cuda.CUDAGraph()
+
+     # create a device-specific torch stream to use for capture
+     # (otherwise torch.cuda.graph reuses its capture stream, which can be problematic if it's from a different device)
+     torch_stream = torch.cuda.Stream(device=torch_device)
+
+     # make warp use the same stream
+     warp_stream = wp.stream_from_torch(torch_stream)
+
+     # capture graph
+     with wp.ScopedStream(warp_stream), torch.cuda.graph(g, stream=torch_stream):
+         wp.capture_begin(force_module_load=False, external=True)
+         try:
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+         finally:
+             wp.capture_end()
+
+     # replay graph
+     num_iters = 10
+     for _i in range(num_iters):
+         g.replay()
+
+     passed = (t == num_iters * 4.0).all()
+     assert passed.item()
+
+
+ def test_torch_graph_warp_stream(test, device):
+     """Capture Torch graph on Warp stream"""
+
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     n = 1024 * 1024
+     t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+     a = wp.from_torch(t)
+
+     g = torch.cuda.CUDAGraph()
+
+     # make torch use the warp stream from the given device
+     torch_stream = wp.stream_to_torch(device)
+
+     # capture graph
+     with wp.ScopedDevice(device), torch.cuda.graph(g, stream=torch_stream):
+         wp.capture_begin(force_module_load=False, external=True)
+         try:
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+         finally:
+             wp.capture_end()
+
+     # replay graph
+     num_iters = 10
+     for _i in range(num_iters):
+         g.replay()
+
+     passed = (t == num_iters * 4.0).all()
+     assert passed.item()
+
+
+ def test_warp_graph_warp_stream(test, device):
+     """Capture Warp graph on Warp stream"""
+
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     n = 1024 * 1024
+     t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+     a = wp.from_torch(t)
+
+     # make torch use the warp stream from the given device
+     torch_stream = wp.stream_to_torch(device)
+
+     # capture graph
+     with wp.ScopedDevice(device), torch.cuda.stream(torch_stream):
+         wp.capture_begin(force_module_load=False)
+         try:
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+         finally:
+             g = wp.capture_end()
+
+     # replay graph
+     num_iters = 10
+     for _i in range(num_iters):
+         wp.capture_launch(g)
+
+     passed = (t == num_iters * 4.0).all()
+     assert passed.item()
+
+
+ def test_warp_graph_torch_stream(test, device):
+     """Capture Warp graph on Torch stream"""
+
+     wp.load_module(device=device)
+
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+
+     n = 1024 * 1024
+     t = torch.zeros(n, dtype=torch.float32, device=torch_device)
+     a = wp.from_torch(t)
+
+     # create a device-specific torch stream to use for capture
+     # (the default torch stream is not suitable for graph capture)
+     torch_stream = torch.cuda.Stream(device=torch_device)
+
+     # make warp use the same stream
+     warp_stream = wp.stream_from_torch(torch_stream)
+
+     # capture graph
+     with wp.ScopedStream(warp_stream), torch.cuda.stream(torch_stream):
+         wp.capture_begin(force_module_load=False)
+         try:
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+             t += 1.0
+             wp.launch(inc, dim=n, inputs=[a])
+         finally:
+             g = wp.capture_end()
+
+     # replay graph
+     num_iters = 10
+     for _i in range(num_iters):
+         wp.capture_launch(g)
+
+     passed = (t == num_iters * 4.0).all()
+     assert passed.item()
+
+
+ def test_direct(test, device):
+     """Pass Torch tensors to Warp kernels directly"""
+
+     import torch
+
+     torch_device = wp.device_to_torch(device)
+     n = 12
+
+     s = torch.arange(n, dtype=torch.float32, device=torch_device)
+     v = torch.arange(n, dtype=torch.float32, device=torch_device).reshape((n // 3, 3))
+     m = torch.arange(n, dtype=torch.float32, device=torch_device).reshape((n // 4, 2, 2))
+
+     wp.launch(inc, dim=n, inputs=[s], device=device)
+     wp.launch(inc_vector, dim=n // 3, inputs=[v], device=device)
+     wp.launch(inc_matrix, dim=n // 4, inputs=[m], device=device)
+
+     expected = torch.arange(1, n + 1, dtype=torch.float32, device=torch_device)
+
+     assert torch.equal(s, expected)
+     assert torch.equal(v.reshape(n), expected)
+     assert torch.equal(m.reshape(n), expected)
+
+
+ class TestTorch(unittest.TestCase):
+     pass
+
+
+ test_devices = get_test_devices()
+
+ try:
+     import torch
+
+     # check which Warp devices work with Torch
+     # CUDA devices may fail if Torch was not compiled with CUDA support
+     torch_compatible_devices = []
+     torch_compatible_cuda_devices = []
+
+     for d in test_devices:
+         try:
+             t = torch.arange(10, device=wp.device_to_torch(d))
+             t += 1
+             torch_compatible_devices.append(d)
+             if d.is_cuda:
+                 torch_compatible_cuda_devices.append(d)
+         except Exception as e:
+             print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
+
+     add_function_test(TestTorch, "test_dtype_from_torch", test_dtype_from_torch, devices=None)
+     add_function_test(TestTorch, "test_dtype_to_torch", test_dtype_to_torch, devices=None)
+
+     if torch_compatible_devices:
+         add_function_test(TestTorch, "test_device_conversion", test_device_conversion, devices=torch_compatible_devices)
+         add_function_test(TestTorch, "test_from_torch", test_from_torch, devices=torch_compatible_devices)
+         add_function_test(TestTorch, "test_from_torch_slices", test_from_torch_slices, devices=torch_compatible_devices)
+         add_function_test(
+             TestTorch, "test_array_ctype_from_torch", test_array_ctype_from_torch, devices=torch_compatible_devices
+         )
+         add_function_test(
+             TestTorch,
+             "test_from_torch_zero_strides",
+             test_from_torch_zero_strides,
+             devices=torch_compatible_devices,
+         )
+         add_function_test(TestTorch, "test_to_torch", test_to_torch, devices=torch_compatible_devices)
+         add_function_test(TestTorch, "test_torch_zerocopy", test_torch_zerocopy, devices=torch_compatible_devices)
+         add_function_test(TestTorch, "test_torch_autograd", test_torch_autograd, devices=torch_compatible_devices)
+         add_function_test(TestTorch, "test_direct", test_direct, devices=torch_compatible_devices)
+         add_function_test(
+             TestTorch, "test_tensor_in_warp_kernel", test_tensor_in_warp_kernel, devices=torch_compatible_devices
+         )
+
+     if torch_compatible_cuda_devices:
+         add_function_test(
+             TestTorch,
+             "test_torch_graph_torch_stream",
+             test_torch_graph_torch_stream,
+             devices=torch_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestTorch,
+             "test_torch_graph_warp_stream",
+             test_torch_graph_warp_stream,
+             devices=torch_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestTorch,
+             "test_warp_graph_warp_stream",
+             test_warp_graph_warp_stream,
+             devices=torch_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestTorch,
+             "test_warp_graph_torch_stream",
+             test_warp_graph_torch_stream,
+             devices=torch_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestTorch, "test_cuda_array_interface", test_cuda_array_interface, devices=torch_compatible_cuda_devices
+         )
+
+     # multi-GPU tests
+     if len(torch_compatible_cuda_devices) > 1:
+         add_function_test(TestTorch, "test_torch_mgpu_from_torch", test_torch_mgpu_from_torch)
+         add_function_test(TestTorch, "test_torch_mgpu_to_torch", test_torch_mgpu_to_torch)
+         add_function_test(TestTorch, "test_torch_mgpu_interop", test_torch_mgpu_interop)
+
+ except Exception as e:
+     print(f"Skipping Torch tests due to exception: {e}")
+
+
+ if __name__ == "__main__":
+     wp.clear_kernel_cache()
+     unittest.main(verbosity=2)
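
For readers skimming this release diff, the pattern the new test_torch.py exercises is the zero-copy bridge between Warp arrays and PyTorch tensors. Below is a minimal sketch of that round trip, not part of the package; it assumes a working torch install and uses only calls that appear in the diff above (wp.from_torch, wp.to_torch, wp.device_to_torch, wp.launch), plus wp.get_device() to pick the current device.

import torch
import warp as wp


@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 2.0 * x[tid]


# wrap an existing torch tensor without copying (the Warp array shares its memory)
device = wp.get_device()
t = torch.arange(10, dtype=torch.float32, device=wp.device_to_torch(device))
x = wp.from_torch(t)

# allocate an output array on the same device and run the kernel over 10 threads
y = wp.zeros(10, dtype=float, device=device)
wp.launch(scale, dim=10, inputs=[x], outputs=[y], device=device)

# view the result as a torch tensor, again without a copy
print(wp.to_torch(y))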