warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/torch.py ADDED
@@ -0,0 +1,391 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+
18
+ import numpy
19
+
20
+ import warp
21
+ import warp.context
22
+
23
+
24
+ # return the warp device corresponding to a torch device
25
def device_from_torch(torch_device) -> warp.context.Device:
    """Return the Warp device corresponding to a Torch device.

    Args:
        torch_device (`torch.device` or `str`): Torch device identifier.
            A bare ``"cuda"`` string, or a ``torch.device`` of type ``"cuda"``
            whose ``index`` is None, resolves to Warp's current CUDA device.

    Returns:
        warp.context.Device: The matching Warp device.

    Raises:
        RuntimeError: Torch device does not have a corresponding Warp device.
        ValueError: ``torch_device`` is neither a string nor a ``torch.device``.
    """
    if isinstance(torch_device, str):
        warp_device = warp.context.runtime.device_map.get(torch_device)
        if warp_device is not None:
            return warp_device
        if torch_device == "cuda":
            # "cuda" without an ordinal: fall back to Warp's current CUDA device
            return warp.context.runtime.get_current_cuda_device()
        raise RuntimeError(f"Unsupported Torch device {torch_device}")

    try:
        if torch_device.type == "cuda":
            if torch_device.index is None:
                # torch.device("cuda") carries no ordinal; indexing cuda_devices
                # with None would fail, so resolve it like the "cuda" string above
                return warp.context.runtime.get_current_cuda_device()
            return warp.context.runtime.cuda_devices[torch_device.index]
        elif torch_device.type == "cpu":
            return warp.context.runtime.cpu_device
        else:
            raise RuntimeError(f"Unsupported Torch device type {torch_device.type}")
    except Exception as e:
        # torch is imported lazily, only to distinguish a wrong argument type
        # from a genuine lookup failure
        import torch

        if not isinstance(torch_device, torch.device):
            raise ValueError("Argument must be a torch.device object or a string") from e
        raise
56
+
57
+
58
def device_to_torch(warp_device: warp.context.Devicelike) -> str:
    """Map a Warp device to a device string usable by PyTorch.

    Args:
        warp_device: Anything :func:`warp.get_device` can resolve to a
            :class:`warp.context.Device`.

    Returns:
        str: ``str(device)`` for CPU and primary-context devices, or
        ``"cuda:<ordinal>"`` for a non-primary CUDA context with UVA.

    Raises:
        RuntimeError: The Warp device is not compatible with PyTorch.
    """
    dev = warp.get_device(warp_device)

    # CPU devices and primary CUDA contexts are returned by their string name.
    if dev.is_cpu or dev.is_primary:
        return str(dev)

    # A non-primary context still works for Torch when UVA lets it
    # dereference the data pointer directly.
    if dev.is_cuda and dev.is_uva:
        return f"cuda:{dev.ordinal}"

    raise RuntimeError(f"Warp device {dev} is not compatible with torch")
74
+
75
+
76
def dtype_to_torch(warp_dtype):
    """Look up the ``torch.dtype`` matching a Warp scalar dtype.

    Args:
        warp_dtype: A Warp data type that has a corresponding ``torch.dtype``.
            ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped
            to the signed integer ``torch.dtype`` of the same width.

    Returns:
        torch.dtype: The matching Torch data type.

    Raises:
        TypeError: Unable to find a corresponding PyTorch data type.
    """
    table = dtype_to_torch.type_map
    if table is None:
        # Built on first use so that importing this module does not import torch.
        import torch

        table = {
            warp.float16: torch.float16,
            warp.float32: torch.float32,
            warp.float64: torch.float64,
            warp.int8: torch.int8,
            warp.int16: torch.int16,
            warp.int32: torch.int32,
            warp.int64: torch.int64,
            warp.uint8: torch.uint8,
            # torch doesn't support unsigned ints bigger than 8 bits, so these
            # alias the signed type of the same width
            warp.uint16: torch.int16,
            warp.uint32: torch.int32,
            warp.uint64: torch.int64,
            warp.bool: torch.bool,
        }
        dtype_to_torch.type_map = table

    try:
        return table[warp_dtype]
    except KeyError:
        raise TypeError(f"Cannot convert {warp_dtype} to a Torch type") from None
111
+
112
+
113
def dtype_from_torch(torch_dtype):
    """Look up the Warp dtype matching a ``torch.dtype``.

    Args:
        torch_dtype: A ``torch.dtype`` that has a corresponding Warp data type.
            Currently ``torch.bfloat16``, ``torch.complex64``, and
            ``torch.complex128`` are not supported.

    Returns:
        The matching Warp data type.

    Raises:
        TypeError: Unable to find a corresponding Warp data type.
    """
    table = dtype_from_torch.type_map
    if table is None:
        # Built on first use so that importing this module does not import torch.
        import torch

        table = {
            torch.float16: warp.float16,
            torch.float32: warp.float32,
            torch.float64: warp.float64,
            torch.int8: warp.int8,
            torch.int16: warp.int16,
            torch.int32: warp.int32,
            torch.int64: warp.int64,
            torch.uint8: warp.uint8,
            torch.bool: warp.bool,
            # Not mapped (currently unsupported by Warp):
            # torch.bfloat16, torch.complex64, torch.complex128
        }
        dtype_from_torch.type_map = table

    try:
        return table[torch_dtype]
    except KeyError:
        raise TypeError(f"Cannot convert {torch_dtype} to a Warp type") from None
150
+
151
+
152
def dtype_is_compatible(torch_dtype, warp_dtype) -> bool:
    """Report whether a Torch tensor of ``torch_dtype`` may be viewed as Warp ``warp_dtype``.

    Integer tensors may be aliased as signed or unsigned Warp integers of the
    same width, and ``torch.bool`` may also be viewed as 8-bit integers. A Warp
    type exposing ``_wp_scalar_type_`` (vector or matrix) is judged by that
    scalar type.

    Returns:
        bool: True when the pairing is allowed, False otherwise (including
        for Torch dtypes with no entry in the table).
    """
    sets = dtype_is_compatible.compatible_sets
    if sets is None:
        # Built on first use so that importing this module does not import torch.
        import torch

        sets = {
            torch.float64: {warp.float64},
            torch.float32: {warp.float32},
            torch.float16: {warp.float16},
            # allow aliasing integer tensors as signed or unsigned integer arrays
            torch.int64: {warp.int64, warp.uint64},
            torch.int32: {warp.int32, warp.uint32},
            torch.int16: {warp.int16, warp.uint16},
            torch.int8: {warp.int8, warp.uint8},
            torch.uint8: {warp.uint8, warp.int8},
            torch.bool: {warp.bool, warp.uint8, warp.int8},
            # Not mapped (currently unsupported by Warp):
            # torch.bfloat16, torch.complex64, torch.complex128
        }
        dtype_is_compatible.compatible_sets = sets

    allowed = sets.get(torch_dtype)
    if allowed is None:
        return False

    if warp_dtype in allowed:
        return True

    # Vector/matrix types: compare their scalar component type instead.
    return hasattr(warp_dtype, "_wp_scalar_type_") and warp_dtype._wp_scalar_type_ in allowed
185
+
186
+
187
# Lazily-built lookup tables: each starts as None and is filled in on the
# first call of the function it is attached to.
dtype_to_torch.type_map = None
dtype_from_torch.type_map = None
dtype_is_compatible.compatible_sets = None
191
+
192
+
193
+ # wrap a torch tensor to a wp array, data is not copied
194
def from_torch(t, dtype=None, requires_grad=None, grad=None, return_ctype=False):
    """Convert a Torch tensor to a Warp array without copying the data.

    Args:
        t (torch.Tensor): The torch tensor to wrap.
        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
        grad (torch.Tensor or warp.array, optional): An explicit gradient buffer to use instead of ``t.grad``.
            A Torch tensor is wrapped with the same ``dtype``. When ``return_ctype`` is True, its strides must match those of the array.
        return_ctype (bool, optional): Whether to return a low-level array descriptor instead of a ``wp.array`` object (faster). The descriptor can be passed to Warp kernels.

    Returns:
        warp.array: The wrapped array or array descriptor.

    Raises:
        TypeError: ``dtype`` is None and ``t.dtype`` has no Warp equivalent.
        RuntimeError: ``dtype`` is incompatible with ``t.dtype``; the tensor's inner
            shape or strides do not fit a vector/matrix ``dtype``; or a gradient's
            strides do not match the array's strides.
    """
    if dtype is None:
        dtype = dtype_from_torch(t.dtype)
    elif not dtype_is_compatible(t.dtype, dtype):
        raise RuntimeError(f"Cannot convert Torch type {t.dtype} to Warp type {dtype}")

    # get size of underlying data type to compute strides
    ctype_size = ctypes.sizeof(dtype._type_)

    shape = tuple(t.shape)
    # t.stride() counts scalar elements; Warp strides are in bytes
    strides = tuple(s * ctype_size for s in t.stride())

    # if target is a vector or matrix type
    # then check if trailing dimensions match
    # the target type and update the shape
    if hasattr(dtype, "_shape_"):
        dtype_shape = dtype._shape_
        dtype_dims = len(dtype_shape)
        # ensure inner shape matches
        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
            )
        # ensure inner strides are contiguous
        if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]):
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
            )
        # byte size of one whole vector/matrix element
        element_size = ctype_size
        for extent in dtype_shape:
            element_size *= extent
        # trim shape and strides; a tensor holding exactly one element becomes a
        # length-1 array whose stride must span the full element, not a single
        # scalar component (these strides are reused for gradient allocation
        # and gradient stride checks below)
        shape = shape[:-dtype_dims] or (1,)
        strides = strides[:-dtype_dims] or (element_size,)

    # gradient
    # - if return_ctype is False, we set `grad` to a wp.array or None
    # - if return_ctype is True, we set `grad_ptr` and set `grad` as the owner (wp.array or torch.Tensor)
    requires_grad = t.requires_grad if requires_grad is None else requires_grad
    grad_ptr = 0
    if grad is not None:
        if isinstance(grad, warp.array):
            if return_ctype:
                if grad.strides != strides:
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {strides} but got {grad.strides}"
                    )
                grad_ptr = grad.ptr
        else:
            # assume grad is a torch.Tensor
            if return_ctype:
                if t.stride() != grad.stride():
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
                    )
                grad_ptr = grad.data_ptr()
            else:
                grad = from_torch(grad, dtype=dtype, requires_grad=False)
    elif requires_grad:
        # wrap the tensor gradient, allocate if necessary
        if t.grad is not None:
            if return_ctype:
                grad = t.grad
                if t.stride() != grad.stride():
                    raise RuntimeError(
                        f"Gradient strides must match array strides, expected {t.stride()} but got {grad.stride()}"
                    )
                grad_ptr = grad.data_ptr()
            else:
                grad = from_torch(t.grad, dtype=dtype, requires_grad=False)
        else:
            # allocate a zero-filled gradient if it doesn't exist
            # Note: we use Warp to allocate the shared gradient with compatible strides
            grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device_from_torch(t.device))
            t.grad = to_torch(grad, requires_grad=False)
            grad_ptr = grad.ptr

    if return_ctype:
        ptr = t.data_ptr()

        # create array descriptor
        array_ctype = warp.types.array_t(ptr, grad_ptr, len(shape), shape, strides)

        # keep data and gradient alive
        array_ctype._ref = t
        array_ctype._gradref = grad

        return array_ctype

    else:
        a = warp.array(
            ptr=t.data_ptr(),
            dtype=dtype,
            shape=shape,
            strides=strides,
            device=device_from_torch(t.device),
            copy=False,
            grad=grad,
            requires_grad=requires_grad,
        )

        # save a reference to the source tensor, otherwise it may get deallocated
        a._tensor = t

        return a
307
+
308
+
309
def to_torch(a, requires_grad=None):
    """
    Convert a Warp array to a Torch tensor without copying the data.

    Args:
        a (warp.array): The Warp array to convert.
        requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.

    Returns:
        torch.Tensor: The converted tensor.

    Raises:
        RuntimeError: The array has a struct dtype, or lives on a device that is
            neither CPU nor CUDA.
    """
    import torch

    if requires_grad is None:
        requires_grad = a.requires_grad

    # Torch has no equivalent of structured arrays
    if isinstance(a.dtype, warp.codegen.Struct):
        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")

    device = a.device
    if device.is_cpu:
        # Torch has an issue wrapping CPU objects that expose __array_interface__,
        # so go through a NumPy ndarray first,
        # see https://pearu.github.io/array_interface_pytorch.html
        def wrap(arr):
            return torch.as_tensor(numpy.asarray(arr))

    elif device.is_cuda:
        # Torch consumes __cuda_array_interface__ directly.
        # NOTE(review): this relies on torch.as_tensor keeping the source array
        # referenced so the allocation outlives the tensor — confirm.
        torch_device = device_to_torch(device)

        def wrap(arr):
            return torch.as_tensor(arr, device=torch_device)

    else:
        raise RuntimeError("Unsupported device")

    t = wrap(a)
    t.requires_grad = requires_grad
    if requires_grad and a.requires_grad:
        t.grad = wrap(a.grad)
    return t
352
+
353
+
354
def stream_from_torch(stream_or_device=None):
    """Wrap a Torch CUDA stream as a Warp CUDA stream.

    Args:
        stream_or_device: A ``torch.cuda.Stream``; anything else (a Torch device,
            or None) is passed to ``torch.cuda.current_stream`` to pick the stream.

    Returns:
        warp.Stream: A Warp stream built on the same ``cuda_stream`` handle.
    """
    import torch

    torch_stream = (
        stream_or_device
        if isinstance(stream_or_device, torch.cuda.Stream)
        else torch.cuda.current_stream(stream_or_device)
    )

    warp_stream = warp.Stream(device_from_torch(torch_stream.device), cuda_stream=torch_stream.cuda_stream)

    # Keep the Torch stream referenced so it is not destroyed while Warp uses its handle.
    warp_stream._torch_stream = torch_stream
    return warp_stream
372
+
373
+
374
def stream_to_torch(stream_or_device=None):
    """Wrap a Warp CUDA stream as a Torch CUDA stream.

    Args:
        stream_or_device: A ``warp.Stream``; anything else (a Warp device
            identifier, or None) is resolved with ``warp.get_device`` and that
            device's ``stream`` is used.

    Returns:
        torch.cuda.ExternalStream: A Torch stream on the same ``cuda_stream`` handle.
    """
    import torch

    if isinstance(stream_or_device, warp.Stream):
        warp_stream = stream_or_device
    else:
        warp_stream = warp.get_device(stream_or_device).stream

    torch_stream = torch.cuda.ExternalStream(warp_stream.cuda_stream, device=device_to_torch(warp_stream.device))

    # Keep the Warp stream referenced so it is not destroyed while Torch uses its handle.
    torch_stream._warp_stream = warp_stream
    return torch_stream