warp_lang-1.7.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
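The wheel ships prebuilt native libraries (warp/bin/warp.so and warp/bin/warp-clang.so), so no local compilation is needed at install time. As a quick post-install sanity check, a sketch along these lines can be used; it assumes the wheel was installed with `pip install warp-lang==1.7.0`, and the reported device list depends on the machine:

    import warp as wp

    wp.init()                 # loads the bundled warp/bin/warp.so and registers devices
    print(wp.config.version)  # expected to report 1.7.0
    print(wp.get_devices())   # e.g. a CPU device plus any CUDA devices found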
warp/utils.py ADDED
@@ -0,0 +1,1137 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import cProfile
19
+ import ctypes
20
+ import os
21
+ import sys
22
+ import time
23
+ import warnings
24
+ from typing import Any, Callable, Dict, List, Optional, Union
25
+
26
+ import numpy as np
27
+
28
+ import warp as wp
29
+ import warp.context
30
+ import warp.types
31
+ from warp.context import Devicelike
32
+
33
+ warnings_seen = set()
34
+
35
+
36
+ def warp_showwarning(message, category, filename, lineno, file=None, line=None):
37
+ """Version of warnings.showwarning that always prints to sys.stdout."""
38
+
39
+ if warp.config.verbose_warnings:
40
+ s = f"Warp {category.__name__}: {message} ({filename}:{lineno})\n"
41
+
42
+ if line is None:
43
+ try:
44
+ import linecache
45
+
46
+ line = linecache.getline(filename, lineno)
47
+ except Exception:
48
+ # When a warning is logged during Python shutdown, linecache
49
+ # and the import machinery don't work anymore
50
+ line = None
51
+ linecache = None
52
+
53
+ if line:
54
+ line = line.strip()
55
+ s += " %s\n" % line
56
+ else:
57
+ # simple warning
58
+ s = f"Warp {category.__name__}: {message}\n"
59
+
60
+ sys.stdout.write(s)
61
+
62
+
63
+ def warn(message, category=None, stacklevel=1):
64
+ if (category, message) in warnings_seen:
65
+ return
66
+
67
+ with warnings.catch_warnings():
68
+ warnings.simplefilter("default") # Change the filter in this process
69
+ warnings.showwarning = warp_showwarning
70
+ warnings.warn(
71
+ message,
72
+ category,
73
+ stacklevel=stacklevel + 1, # Increment stacklevel by 1 since we are in a wrapper
74
+ )
75
+
76
+ if category is DeprecationWarning:
77
+ warnings_seen.add((category, message))
78
+
79
+
80
+ # expand a 7-vec to a tuple of arrays
81
+ def transform_expand(t):
82
+ return wp.transform(np.array(t[0:3]), np.array(t[3:7]))
83
+
84
+
85
+ @wp.func
86
+ def quat_between_vectors(a: wp.vec3, b: wp.vec3) -> wp.quat:
87
+ """
88
+ Compute the quaternion that rotates vector a to vector b
89
+ """
90
+ a = wp.normalize(a)
91
+ b = wp.normalize(b)
92
+ c = wp.cross(a, b)
93
+ d = wp.dot(a, b)
94
+ q = wp.quat(c[0], c[1], c[2], 1.0 + d)
95
+ return wp.normalize(q)
96
+
97
+
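Since quat_between_vectors is a wp.func, it is intended to be called from Warp kernel code rather than directly from Python. A minimal sketch, in which the kernel name and data are illustrative and not part of the package:

    import numpy as np
    import warp as wp
    from warp.utils import quat_between_vectors

    wp.init()

    @wp.kernel
    def align_to_direction(dirs: wp.array(dtype=wp.vec3), out: wp.array(dtype=wp.quat)):
        i = wp.tid()
        # rotation taking the +Y axis onto each input direction
        out[i] = quat_between_vectors(wp.vec3(0.0, 1.0, 0.0), dirs[i])

    dirs = wp.array(np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32), dtype=wp.vec3)
    quats = wp.empty(2, dtype=wp.quat)
    wp.launch(align_to_direction, dim=2, inputs=[dirs, quats])
    print(quats.numpy())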
98
+ def array_scan(in_array, out_array, inclusive=True):
99
+ if in_array.device != out_array.device:
100
+ raise RuntimeError("Array storage devices do not match")
101
+
102
+ if in_array.size != out_array.size:
103
+ raise RuntimeError("Array storage sizes do not match")
104
+
105
+ if in_array.dtype != out_array.dtype:
106
+ raise RuntimeError("Array data types do not match")
107
+
108
+ if in_array.size == 0:
109
+ return
110
+
111
+ from warp.context import runtime
112
+
113
+ if in_array.device.is_cpu:
114
+ if in_array.dtype == wp.int32:
115
+ runtime.core.array_scan_int_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
116
+ elif in_array.dtype == wp.float32:
117
+ runtime.core.array_scan_float_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
118
+ else:
119
+ raise RuntimeError("Unsupported data type")
120
+ elif in_array.device.is_cuda:
121
+ if in_array.dtype == wp.int32:
122
+ runtime.core.array_scan_int_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
123
+ elif in_array.dtype == wp.float32:
124
+ runtime.core.array_scan_float_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
125
+ else:
126
+ raise RuntimeError("Unsupported data type")
127
+
128
+
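A minimal sketch of calling array_scan for an inclusive prefix sum on host data (the values are illustrative):

    import warp as wp
    from warp.utils import array_scan

    wp.init()
    values = wp.array([1, 2, 3, 4], dtype=wp.int32, device="cpu")
    prefix = wp.zeros(4, dtype=wp.int32, device="cpu")
    array_scan(values, prefix, inclusive=True)
    print(prefix.numpy())   # [1 3 6 10]; inclusive=False would give the exclusive scan [0 1 3 6]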
129
+ def radix_sort_pairs(keys, values, count: int):
130
+ if keys.device != values.device:
131
+ raise RuntimeError("Array storage devices do not match")
132
+
133
+ if count == 0:
134
+ return
135
+
136
+ if keys.size < 2 * count or values.size < 2 * count:
137
+ raise RuntimeError("Array storage must be large enough to contain 2*count elements")
138
+
139
+ from warp.context import runtime
140
+
141
+ if keys.device.is_cpu:
142
+ if keys.dtype == wp.int32 and values.dtype == wp.int32:
143
+ runtime.core.radix_sort_pairs_int_host(keys.ptr, values.ptr, count)
144
+ elif keys.dtype == wp.float32 and values.dtype == wp.int32:
145
+ runtime.core.radix_sort_pairs_float_host(keys.ptr, values.ptr, count)
146
+ elif keys.dtype == wp.int64 and values.dtype == wp.int32:
147
+ runtime.core.radix_sort_pairs_int64_host(keys.ptr, values.ptr, count)
148
+ else:
149
+ raise RuntimeError("Unsupported data type")
150
+ elif keys.device.is_cuda:
151
+ if keys.dtype == wp.int32 and values.dtype == wp.int32:
152
+ runtime.core.radix_sort_pairs_int_device(keys.ptr, values.ptr, count)
153
+ elif keys.dtype == wp.float32 and values.dtype == wp.int32:
154
+ runtime.core.radix_sort_pairs_float_device(keys.ptr, values.ptr, count)
155
+ elif keys.dtype == wp.int64 and values.dtype == wp.int32:
156
+ runtime.core.radix_sort_pairs_int64_device(keys.ptr, values.ptr, count)
157
+ else:
158
+ raise RuntimeError("Unsupported data type")
159
+
160
+
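Note that the size check in radix_sort_pairs above requires both arrays to hold 2*count elements; the back half is presumably used as scratch space during the sort. A minimal host-side sketch with illustrative data:

    import numpy as np
    import warp as wp
    from warp.utils import radix_sort_pairs

    wp.init()
    n = 4
    keys = wp.array(np.array([9, 2, 7, 4, 0, 0, 0, 0], dtype=np.int32), device="cpu")
    vals = wp.array(np.arange(2 * n, dtype=np.int32), device="cpu")
    radix_sort_pairs(keys, vals, n)
    print(keys.numpy()[:n], vals.numpy()[:n])   # keys sorted ascending, values reordered alongside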
161
+ def segmented_sort_pairs(
162
+ keys,
163
+ values,
164
+ count: int,
165
+ segment_start_indices: wp.array(dtype=wp.int32),
166
+ segment_end_indices: wp.array(dtype=wp.int32) = None,
167
+ ):
168
+ """Sort key-value pairs within segments.
169
+
170
+ This function performs a segmented sort of key-value pairs, where the sorting is done independently within each segment.
171
+ The segments are defined by their start and optionally end indices.
172
+
173
+ Args:
174
+ keys: Array of keys to sort. Must be of type int32 or float32.
175
+ values: Array of values to sort along with keys. Must be of type int32.
176
+ count: Number of elements to sort.
177
+ segment_start_indices: Array containing start index of each segment. Must be of type int32.
178
+ If segment_end_indices is None, this array must have length at least num_segments + 1,
179
+ and segment_end_indices will be inferred as segment_start_indices[1:].
180
+ If segment_end_indices is provided, this array must have length at least num_segments.
181
+ segment_end_indices: Optional array containing end index of each segment. Must be of type int32 if provided.
182
+ If None, segment_end_indices will be inferred from segment_start_indices[1:].
183
+ If provided, must have length at least num_segments.
184
+
185
+ Raises:
186
+ RuntimeError: If array storage devices don't match, if storage size is insufficient,
187
+ if segment_start_indices is not of type int32, or if data types are unsupported.
188
+ """
189
+ if keys.device != values.device:
190
+ raise RuntimeError("Array storage devices do not match")
191
+
192
+ if count == 0:
193
+ return
194
+
195
+ if keys.size < 2 * count or values.size < 2 * count:
196
+ raise RuntimeError("Array storage must be large enough to contain 2*count elements")
197
+
198
+ from warp.context import runtime
199
+
200
+ if segment_start_indices.dtype != wp.int32:
201
+ raise RuntimeError("segment_start_indices array must be of type int32")
202
+
203
+ # Handle case where segment_end_indices is not provided
204
+ if segment_end_indices is None:
205
+ num_segments = max(0, segment_start_indices.size - 1)
206
+
207
+ segment_end_indices = segment_start_indices[1:]
208
+ segment_end_indices_ptr = segment_end_indices.ptr
209
+ segment_start_indices_ptr = segment_start_indices.ptr
210
+ else:
211
+ if segment_end_indices.dtype != wp.int32:
212
+ raise RuntimeError("segment_end_indices array must be of type int32")
213
+
214
+ num_segments = segment_start_indices.size
215
+
216
+ segment_end_indices_ptr = segment_end_indices.ptr
217
+ segment_start_indices_ptr = segment_start_indices.ptr
218
+
219
+ if keys.device.is_cpu:
220
+ if keys.dtype == wp.int32 and values.dtype == wp.int32:
221
+ runtime.core.segmented_sort_pairs_int_host(
222
+ keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
223
+ )
224
+ elif keys.dtype == wp.float32 and values.dtype == wp.int32:
225
+ runtime.core.segmented_sort_pairs_float_host(
226
+ keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
227
+ )
228
+ else:
229
+ raise RuntimeError("Unsupported data type")
230
+ elif keys.device.is_cuda:
231
+ if keys.dtype == wp.int32 and values.dtype == wp.int32:
232
+ runtime.core.segmented_sort_pairs_int_device(
233
+ keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
234
+ )
235
+ elif keys.dtype == wp.float32 and values.dtype == wp.int32:
236
+ runtime.core.segmented_sort_pairs_float_device(
237
+ keys.ptr, values.ptr, count, segment_start_indices_ptr, segment_end_indices_ptr, num_segments
238
+ )
239
+ else:
240
+ raise RuntimeError("Unsupported data type")
241
+
242
+
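A minimal sketch of the segmented sort with two segments, passing only segment_start_indices so the end indices are inferred as described in the docstring (the data is illustrative):

    import numpy as np
    import warp as wp
    from warp.utils import segmented_sort_pairs

    wp.init()
    n = 6
    keys = wp.array(np.array([5, 1, 3, 9, 2, 7, 0, 0, 0, 0, 0, 0], dtype=np.int32), device="cpu")
    vals = wp.array(np.arange(2 * n, dtype=np.int32), device="cpu")
    starts = wp.array(np.array([0, 3, 6], dtype=np.int32), device="cpu")   # segments [0, 3) and [3, 6)
    segmented_sort_pairs(keys, vals, n, starts)
    print(keys.numpy()[:n])   # [1 3 5 2 7 9]: each segment is sorted independently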
243
+ def runlength_encode(values, run_values, run_lengths, run_count=None, value_count=None):
244
+ if run_values.device != values.device or run_lengths.device != values.device:
245
+ raise RuntimeError("Array storage devices do not match")
246
+
247
+ if value_count is None:
248
+ value_count = values.size
249
+
250
+ if run_values.size < value_count or run_lengths.size < value_count:
251
+ raise RuntimeError("Output array storage sizes must be at least equal to value_count")
252
+
253
+ if values.dtype != run_values.dtype:
254
+ raise RuntimeError("values and run_values data types do not match")
255
+
256
+ if run_lengths.dtype != wp.int32:
257
+ raise RuntimeError("run_lengths array must be of type int32")
258
+
259
+ # User can provide a device output array for storing the number of runs
260
+ # For convenience, if no such array is provided, number of runs is returned on host
261
+ if run_count is None:
262
+ if value_count == 0:
263
+ return 0
264
+ run_count = wp.empty(shape=(1,), dtype=int, device=values.device)
265
+ host_return = True
266
+ else:
267
+ if run_count.device != values.device:
268
+ raise RuntimeError("run_count storage device does not match other arrays")
269
+ if run_count.dtype != wp.int32:
270
+ raise RuntimeError("run_count array must be of type int32")
271
+ if value_count == 0:
272
+ run_count.zero_()
273
+ return 0
274
+ host_return = False
275
+
276
+ from warp.context import runtime
277
+
278
+ if values.device.is_cpu:
279
+ if values.dtype == wp.int32:
280
+ runtime.core.runlength_encode_int_host(
281
+ values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
282
+ )
283
+ else:
284
+ raise RuntimeError("Unsupported data type")
285
+ elif values.device.is_cuda:
286
+ if values.dtype == wp.int32:
287
+ runtime.core.runlength_encode_int_device(
288
+ values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
289
+ )
290
+ else:
291
+ raise RuntimeError("Unsupported data type")
292
+
293
+ if host_return:
294
+ return int(run_count.numpy()[0])
295
+
296
+
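A minimal sketch of runlength_encode; when no run_count array is supplied, the number of runs is returned to the host as shown here (the data is illustrative):

    import numpy as np
    import warp as wp
    from warp.utils import runlength_encode

    wp.init()
    values = wp.array(np.array([1, 1, 2, 2, 2, 3], dtype=np.int32), device="cpu")
    run_values = wp.zeros(6, dtype=wp.int32, device="cpu")
    run_lengths = wp.zeros(6, dtype=wp.int32, device="cpu")
    num_runs = runlength_encode(values, run_values, run_lengths)
    print(num_runs, run_values.numpy()[:num_runs], run_lengths.numpy()[:num_runs])   # 3 [1 2 3] [2 3 1]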
297
+ def array_sum(values, out=None, value_count=None, axis=None):
298
+ if value_count is None:
299
+ if axis is None:
300
+ value_count = values.size
301
+ else:
302
+ value_count = values.shape[axis]
303
+
304
+ if axis is None:
305
+ output_shape = (1,)
306
+ else:
307
+
308
+ def output_dim(ax, dim):
309
+ return 1 if ax == axis else dim
310
+
311
+ output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(values.shape))
312
+
313
+ type_length = wp.types.type_length(values.dtype)
314
+ scalar_type = wp.types.type_scalar_type(values.dtype)
315
+
316
+ # User can provide a device output array for storing the number of runs
317
+ # For convenience, if no such array is provided, number of runs is returned on host
318
+ if out is None:
319
+ host_return = True
320
+ out = wp.empty(shape=output_shape, dtype=values.dtype, device=values.device)
321
+ else:
322
+ host_return = False
323
+ if out.device != values.device:
324
+ raise RuntimeError("out storage device should match values array")
325
+ if out.dtype != values.dtype:
326
+ raise RuntimeError(f"out array should have type {values.dtype.__name__}")
327
+ if out.shape != output_shape:
328
+ raise RuntimeError(f"out array should have shape {output_shape}")
329
+
330
+ if value_count == 0:
331
+ out.zero_()
332
+ if axis is None and host_return:
333
+ return out.numpy()[0]
334
+ return out
335
+
336
+ from warp.context import runtime
337
+
338
+ if values.device.is_cpu:
339
+ if scalar_type == wp.float32:
340
+ native_func = runtime.core.array_sum_float_host
341
+ elif scalar_type == wp.float64:
342
+ native_func = runtime.core.array_sum_double_host
343
+ else:
344
+ raise RuntimeError("Unsupported data type")
345
+ elif values.device.is_cuda:
346
+ if scalar_type == wp.float32:
347
+ native_func = runtime.core.array_sum_float_device
348
+ elif scalar_type == wp.float64:
349
+ native_func = runtime.core.array_sum_double_device
350
+ else:
351
+ raise RuntimeError("Unsupported data type")
352
+
353
+ if axis is None:
354
+ stride = wp.types.type_size_in_bytes(values.dtype)
355
+ native_func(values.ptr, out.ptr, value_count, stride, type_length)
356
+
357
+ if host_return:
358
+ return out.numpy()[0]
359
+ else:
360
+ stride = values.strides[axis]
361
+ for idx in np.ndindex(output_shape):
362
+ out_offset = sum(i * s for i, s in zip(idx, out.strides))
363
+ val_offset = sum(i * s for i, s in zip(idx, values.strides))
364
+
365
+ native_func(
366
+ values.ptr + val_offset,
367
+ out.ptr + out_offset,
368
+ value_count,
369
+ stride,
370
+ type_length,
371
+ )
372
+
373
+ if host_return:
374
+ return out
375
+
376
+
377
+ def array_inner(a, b, out=None, count=None, axis=None):
378
+ if a.size != b.size:
379
+ raise RuntimeError("Array storage sizes do not match")
380
+
381
+ if a.device != b.device:
382
+ raise RuntimeError("Array storage devices do not match")
383
+
384
+ if a.dtype != b.dtype:
385
+ raise RuntimeError("Array data types do not match")
386
+
387
+ if count is None:
388
+ if axis is None:
389
+ count = a.size
390
+ else:
391
+ count = a.shape[axis]
392
+
393
+ if axis is None:
394
+ output_shape = (1,)
395
+ else:
396
+
397
+ def output_dim(ax, dim):
398
+ return 1 if ax == axis else dim
399
+
400
+ output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(a.shape))
401
+
402
+ type_length = wp.types.type_length(a.dtype)
403
+ scalar_type = wp.types.type_scalar_type(a.dtype)
404
+
405
+ # User can provide a device output array for storing the number of runs
406
+ # For convenience, if no such array is provided, number of runs is returned on host
407
+ if out is None:
408
+ host_return = True
409
+ out = wp.empty(shape=output_shape, dtype=scalar_type, device=a.device)
410
+ else:
411
+ host_return = False
412
+ if out.device != a.device:
413
+ raise RuntimeError("out storage device should match values array")
414
+ if out.dtype != scalar_type:
415
+ raise RuntimeError(f"out array should have type {scalar_type.__name__}")
416
+ if out.shape != output_shape:
417
+ raise RuntimeError(f"out array should have shape {output_shape}")
418
+
419
+ if count == 0:
420
+ if axis is None and host_return:
421
+ return 0.0
422
+ out.zero_()
423
+ return out
424
+
425
+ from warp.context import runtime
426
+
427
+ if a.device.is_cpu:
428
+ if scalar_type == wp.float32:
429
+ native_func = runtime.core.array_inner_float_host
430
+ elif scalar_type == wp.float64:
431
+ native_func = runtime.core.array_inner_double_host
432
+ else:
433
+ raise RuntimeError("Unsupported data type")
434
+ elif a.device.is_cuda:
435
+ if scalar_type == wp.float32:
436
+ native_func = runtime.core.array_inner_float_device
437
+ elif scalar_type == wp.float64:
438
+ native_func = runtime.core.array_inner_double_device
439
+ else:
440
+ raise RuntimeError("Unsupported data type")
441
+
442
+ if axis is None:
443
+ stride_a = wp.types.type_size_in_bytes(a.dtype)
444
+ stride_b = wp.types.type_size_in_bytes(b.dtype)
445
+ native_func(a.ptr, b.ptr, out.ptr, count, stride_a, stride_b, type_length)
446
+
447
+ if host_return:
448
+ return out.numpy()[0]
449
+ else:
450
+ stride_a = a.strides[axis]
451
+ stride_b = b.strides[axis]
452
+
453
+ for idx in np.ndindex(output_shape):
454
+ out_offset = sum(i * s for i, s in zip(idx, out.strides))
455
+ a_offset = sum(i * s for i, s in zip(idx, a.strides))
456
+ b_offset = sum(i * s for i, s in zip(idx, b.strides))
457
+
458
+ native_func(
459
+ a.ptr + a_offset,
460
+ b.ptr + b_offset,
461
+ out.ptr + out_offset,
462
+ count,
463
+ stride_a,
464
+ stride_b,
465
+ type_length,
466
+ )
467
+
468
+ if host_return:
469
+ return out
470
+
471
+
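A minimal sketch of array_sum and array_inner on float32 data; when no out array is given, both return the scalar result on the host (the values are illustrative):

    import numpy as np
    import warp as wp
    from warp.utils import array_inner, array_sum

    wp.init()
    a = wp.array(np.array([1.0, 2.0, 3.0], dtype=np.float32), device="cpu")
    b = wp.array(np.array([4.0, 5.0, 6.0], dtype=np.float32), device="cpu")
    print(array_sum(a))        # 6.0
    print(array_inner(a, b))   # 32.0 (dot product)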
472
+ @wp.kernel
473
+ def _array_cast_kernel(
474
+ dest: Any,
475
+ src: Any,
476
+ ):
477
+ i = wp.tid()
478
+ dest[i] = dest.dtype(src[i])
479
+
480
+
481
+ def array_cast(in_array, out_array, count=None):
482
+ if in_array.device != out_array.device:
483
+ raise RuntimeError("Array storage devices do not match")
484
+
485
+ in_array_data_shape = getattr(in_array.dtype, "_shape_", ())
486
+ out_array_data_shape = getattr(out_array.dtype, "_shape_", ())
487
+
488
+ if in_array.ndim != out_array.ndim or in_array_data_shape != out_array_data_shape:
489
+ # Number of dimensions or data type shape do not match.
490
+ # Flatten arrays and do cast at the scalar level
491
+ in_array = in_array.flatten()
492
+ out_array = out_array.flatten()
493
+
494
+ in_array_data_length = warp.types.type_length(in_array.dtype)
495
+ out_array_data_length = warp.types.type_length(out_array.dtype)
496
+ in_array_scalar_type = wp.types.type_scalar_type(in_array.dtype)
497
+ out_array_scalar_type = wp.types.type_scalar_type(out_array.dtype)
498
+
499
+ in_array = wp.array(
500
+ data=None,
501
+ ptr=in_array.ptr,
502
+ capacity=in_array.capacity,
503
+ device=in_array.device,
504
+ dtype=in_array_scalar_type,
505
+ shape=in_array.shape[0] * in_array_data_length,
506
+ )
507
+
508
+ out_array = wp.array(
509
+ data=None,
510
+ ptr=out_array.ptr,
511
+ capacity=out_array.capacity,
512
+ device=out_array.device,
513
+ dtype=out_array_scalar_type,
514
+ shape=out_array.shape[0] * out_array_data_length,
515
+ )
516
+
517
+ if count is not None:
518
+ count *= in_array_data_length
519
+
520
+ if count is None:
521
+ count = in_array.size
522
+
523
+ if in_array.ndim == 1:
524
+ dim = count
525
+ elif count < in_array.size:
526
+ raise RuntimeError("Partial cast is not supported for arrays with more than one dimension")
527
+ else:
528
+ dim = in_array.shape
529
+
530
+ if in_array.dtype == out_array.dtype:
531
+ # Same data type, can simply copy
532
+ wp.copy(dest=out_array, src=in_array, count=count)
533
+ else:
534
+ wp.launch(kernel=_array_cast_kernel, dim=dim, inputs=[out_array, in_array], device=out_array.device)
535
+
536
+
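A minimal sketch of array_cast converting float32 values to int32; the conversion is performed element-wise by the _array_cast_kernel above (the data is illustrative):

    import numpy as np
    import warp as wp
    from warp.utils import array_cast

    wp.init()
    src = wp.array(np.array([0.5, 1.5, 2.5, 3.5], dtype=np.float32), device="cpu")
    dst = wp.zeros(4, dtype=wp.int32, device="cpu")
    array_cast(src, dst)
    print(dst.numpy())   # [0 1 2 3], truncated as by wp.int32() in kernel code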
537
+ # code snippet for invoking cProfile
538
+ # cp = cProfile.Profile()
539
+ # cp.enable()
540
+ # for i in range(1000):
541
+ # self.state = self.integrator.forward(self.model, self.state, self.sim_dt)
542
+
543
+ # cp.disable()
544
+ # cp.print_stats(sort='tottime')
545
+ # exit(0)
546
+
547
+
548
+ # helper kernels for initializing NVDB volumes from a dense array
549
+ @wp.kernel
550
+ def copy_dense_volume_to_nano_vdb_v(volume: wp.uint64, values: wp.array(dtype=wp.vec3, ndim=3)):
551
+ i, j, k = wp.tid()
552
+ wp.volume_store_v(volume, i, j, k, values[i, j, k])
553
+
554
+
555
+ @wp.kernel
556
+ def copy_dense_volume_to_nano_vdb_f(volume: wp.uint64, values: wp.array(dtype=wp.float32, ndim=3)):
557
+ i, j, k = wp.tid()
558
+ wp.volume_store_f(volume, i, j, k, values[i, j, k])
559
+
560
+
561
+ @wp.kernel
562
+ def copy_dense_volume_to_nano_vdb_i(volume: wp.uint64, values: wp.array(dtype=wp.int32, ndim=3)):
563
+ i, j, k = wp.tid()
564
+ wp.volume_store_i(volume, i, j, k, values[i, j, k])
565
+
566
+
567
+ # represent an edge between v0, v1 with connected faces f0, f1, and opposite vertex o0, and o1
568
+ # winding is such that first tri can be reconstructed as {v0, v1, o0}, and second tri as { v1, v0, o1 }
569
+ class MeshEdge:
570
+ def __init__(self, v0, v1, o0, o1, f0, f1):
571
+ self.v0 = v0 # vertex 0
572
+ self.v1 = v1 # vertex 1
573
+ self.o0 = o0 # opposite vertex 1
574
+ self.o1 = o1 # opposite vertex 2
575
+ self.f0 = f0 # index of tri1
576
+ self.f1 = f1 # index of tri2
577
+
578
+
579
+ class MeshAdjacency:
580
+ def __init__(self, indices, num_tris):
581
+ # map edges (v0, v1) to faces (f0, f1)
582
+ self.edges = {}
583
+ self.indices = indices
584
+
585
+ for index, tri in enumerate(indices):
586
+ self.add_edge(tri[0], tri[1], tri[2], index)
587
+ self.add_edge(tri[1], tri[2], tri[0], index)
588
+ self.add_edge(tri[2], tri[0], tri[1], index)
589
+
590
+ def add_edge(self, i0, i1, o, f): # index1, index2, index3, index of triangle
591
+ key = (min(i0, i1), max(i0, i1))
592
+ edge = None
593
+
594
+ if key in self.edges:
595
+ edge = self.edges[key]
596
+
597
+ if edge.f1 != -1:
598
+ print("Detected non-manifold edge")
599
+ return
600
+ else:
601
+ # update other side of the edge
602
+ edge.o1 = o
603
+ edge.f1 = f
604
+ else:
605
+ # create new edge with opposite yet to be filled
606
+ edge = MeshEdge(i0, i1, o, -1, f, -1)
607
+
608
+ self.edges[key] = edge
609
+
610
+
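A minimal sketch of building the edge adjacency for two triangles sharing an edge (the indices are illustrative); boundary edges keep f1 == -1:

    from warp.utils import MeshAdjacency

    tris = [(0, 1, 2), (2, 1, 3)]
    adj = MeshAdjacency(tris, len(tris))
    for (v0, v1), e in adj.edges.items():
        print(v0, v1, e.f0, e.f1)   # the shared edge (1, 2) reports both faces; the others report f1 == -1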
611
+ def mem_report(): # pragma: no cover
612
+ def _mem_report(tensors, mem_type):
613
+ """Print the selected tensors of type
614
+ There are two major storage types in our major concern:
615
+ - GPU: tensors transferred to CUDA devices
616
+ - CPU: tensors remaining on the system memory (usually unimportant)
617
+ Args:
618
+ - tensors: the tensors of specified type
619
+ - mem_type: 'CPU' or 'GPU' in current implementation"""
620
+ total_numel = 0
621
+ total_mem = 0
622
+ visited_data = []
623
+ for tensor in tensors:
624
+ if tensor.is_sparse:
625
+ continue
626
+ # a data_ptr indicates a memory block allocated
627
+ data_ptr = tensor.storage().data_ptr()
628
+ if data_ptr in visited_data:
629
+ continue
630
+ visited_data.append(data_ptr)
631
+
632
+ numel = tensor.storage().size()
633
+ total_numel += numel
634
+ element_size = tensor.storage().element_size()
635
+ mem = numel * element_size / 1024 / 1024 # 32bit=4Byte, MByte
636
+ total_mem += mem
637
+ print("Type: %s Total Tensors: %d \tUsed Memory Space: %.2f MBytes" % (mem_type, total_numel, total_mem))
638
+
639
+ import gc
640
+
641
+ import torch
642
+
643
+ gc.collect()
644
+
645
+ LEN = 65
646
+ objects = gc.get_objects()
647
+ # print('%s\t%s\t\t\t%s' %('Element type', 'Size', 'Used MEM(MBytes)') )
648
+ tensors = [obj for obj in objects if torch.is_tensor(obj)]
649
+ cuda_tensors = [t for t in tensors if t.is_cuda]
650
+ host_tensors = [t for t in tensors if not t.is_cuda]
651
+ _mem_report(cuda_tensors, "GPU")
652
+ _mem_report(host_tensors, "CPU")
653
+ print("=" * LEN)
654
+
655
+
656
+ class ScopedDevice:
657
+ """A context manager to temporarily change the current default device.
658
+
659
+ For CUDA devices, this context manager makes the device's CUDA context
660
+ current and restores the previous CUDA context on exit. This is handy when
661
+ running Warp scripts as part of a bigger pipeline because it avoids any side
662
+ effects of changing the CUDA context in the enclosed code.
663
+
664
+ Attributes:
665
+ device (Device): The device that will temporarily become the default
666
+ device within the context.
667
+ saved_device (Device): The previous default device. This is restored as
668
+ the default device on exiting the context.
669
+ """
670
+
671
+ def __init__(self, device: Devicelike):
672
+ """Initializes the context manager with a device.
673
+
674
+ Args:
675
+ device: The device that will temporarily become the default device
676
+ within the context.
677
+ """
678
+ self.device = wp.get_device(device)
679
+
680
+ def __enter__(self):
681
+ # save the previous default device
682
+ self.saved_device = self.device.runtime.default_device
683
+
684
+ # make this the default device
685
+ self.device.runtime.default_device = self.device
686
+
687
+ # make it the current CUDA device so that device alias "cuda" will evaluate to this device
688
+ self.device.context_guard.__enter__()
689
+
690
+ return self.device
691
+
692
+ def __exit__(self, exc_type, exc_value, traceback):
693
+ # restore original CUDA context
694
+ self.device.context_guard.__exit__(exc_type, exc_value, traceback)
695
+
696
+ # restore original target device
697
+ self.device.runtime.default_device = self.saved_device
698
+
699
+
700
+ class ScopedStream:
701
+ """A context manager to temporarily change the current stream on a device.
702
+
703
+ Attributes:
704
+ stream (Stream or None): The stream that will temporarily become the device's
705
+ default stream within the context.
706
+ saved_stream (Stream): The device's previous current stream. This is
707
+ restored as the device's current stream on exiting the context.
708
+ sync_enter (bool): Whether to synchronize this context's stream with
709
+ the device's previous current stream on entering the context.
710
+ sync_exit (bool): Whether to synchronize the device's previous current
711
+ with this context's stream on exiting the context.
712
+ device (Device): The device associated with the stream.
713
+ """
714
+
715
+ def __init__(self, stream: Optional[wp.Stream], sync_enter: bool = True, sync_exit: bool = False):
716
+ """Initializes the context manager with a stream and synchronization options.
717
+
718
+ Args:
719
+ stream: The stream that will temporarily become the device's
720
+ default stream within the context.
721
+ sync_enter (bool): Whether to synchronize this context's stream with
722
+ the device's previous current stream on entering the context.
723
+ sync_exit (bool): Whether to synchronize the device's previous current
724
+ with this context's stream on exiting the context.
725
+ """
726
+
727
+ self.stream = stream
728
+ self.sync_enter = sync_enter
729
+ self.sync_exit = sync_exit
730
+ if stream is not None:
731
+ self.device = stream.device
732
+ self.device_scope = ScopedDevice(self.device)
733
+
734
+ def __enter__(self):
735
+ if self.stream is not None:
736
+ self.device_scope.__enter__()
737
+ self.saved_stream = self.device.stream
738
+ self.device.set_stream(self.stream, self.sync_enter)
739
+
740
+ return self.stream
741
+
742
+ def __exit__(self, exc_type, exc_value, traceback):
743
+ if self.stream is not None:
744
+ self.device.set_stream(self.saved_stream, self.sync_exit)
745
+ self.device_scope.__exit__(exc_type, exc_value, traceback)
746
+
747
+
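A minimal sketch of ScopedDevice and ScopedStream; it assumes a CUDA-capable device is present, and the array names are illustrative:

    import warp as wp
    from warp.utils import ScopedDevice, ScopedStream

    wp.init()

    with ScopedDevice("cuda:0"):
        a = wp.zeros(1024, dtype=wp.float32)   # allocated on cuda:0, the temporary default device

    stream = wp.Stream("cuda:0")
    with ScopedStream(stream):
        b = wp.zeros_like(a)                   # work issued on `stream` while the context is active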
748
+ TIMING_KERNEL = 1
749
+ TIMING_KERNEL_BUILTIN = 2
750
+ TIMING_MEMCPY = 4
751
+ TIMING_MEMSET = 8
752
+ TIMING_GRAPH = 16
753
+ TIMING_ALL = 0xFFFFFFFF
754
+
755
+
756
+ # timer utils
757
+ class ScopedTimer:
758
+ indent = -1
759
+
760
+ enabled = True
761
+
762
+ def __init__(
763
+ self,
764
+ name: str,
765
+ active: bool = True,
766
+ print: bool = True,
767
+ detailed: bool = False,
768
+ dict: Optional[Dict[str, List[float]]] = None,
769
+ use_nvtx: bool = False,
770
+ color: Union[int, str] = "rapids",
771
+ synchronize: bool = False,
772
+ cuda_filter: int = 0,
773
+ report_func: Optional[Callable[[List[TimingResult], str], None]] = None,
774
+ skip_tape: bool = False,
775
+ ):
776
+ """Context manager object for a timer
777
+
778
+ Parameters:
779
+ name: Name of timer
780
+ active: Enables this timer
781
+ print: At context manager exit, print elapsed time to ``sys.stdout``
782
+ detailed: Collects additional profiling data using cProfile and calls ``print_stats()`` at context exit
783
+ dict: A dictionary of lists to which the elapsed time will be appended using ``name`` as a key
784
+ use_nvtx: If true, timing functionality is replaced by an NVTX range
785
+ color: ARGB value (e.g. 0x00FFFF) or color name (e.g. 'cyan') associated with the NVTX range
786
+ synchronize: Synchronize the CPU thread with any outstanding CUDA work to return accurate GPU timings
787
+ cuda_filter: Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL``
788
+ report_func: A callback function to print the activity report.
789
+ If ``None``, :func:`wp.timing_print() <timing_print>` will be used.
790
+ skip_tape: If true, the timer will not be recorded in the tape
791
+
792
+ Attributes:
793
+ extra_msg (str): Can be set to a string that will be added to the printout at context exit.
794
+ elapsed (float): The duration of the ``with`` block used with this object
795
+ timing_results (List[TimingResult]): The list of activity timing results, if collection was requested using ``cuda_filter``
796
+ """
797
+ self.name = name
798
+ self.active = active and self.enabled
799
+ self.print = print
800
+ self.detailed = detailed
801
+ self.dict = dict
802
+ self.use_nvtx = use_nvtx
803
+ self.color = color
804
+ self.synchronize = synchronize
805
+ self.skip_tape = skip_tape
806
+ self.elapsed = 0.0
807
+ self.cuda_filter = cuda_filter
808
+ self.report_func = report_func or wp.timing_print
809
+ self.extra_msg = "" # Can be used to add to the message printed at manager exit
810
+
811
+ if self.dict is not None:
812
+ if name not in self.dict:
813
+ self.dict[name] = []
814
+
815
+ def __enter__(self):
816
+ if not self.skip_tape and warp.context.runtime is not None and warp.context.runtime.tape is not None:
817
+ warp.context.runtime.tape.record_scope_begin(self.name)
818
+ if self.active:
819
+ if self.synchronize:
820
+ wp.synchronize()
821
+
822
+ if self.cuda_filter:
823
+ # begin CUDA activity collection, synchronizing if needed
824
+ timing_begin(self.cuda_filter, synchronize=not self.synchronize)
825
+
826
+ if self.detailed:
827
+ self.cp = cProfile.Profile()
828
+ self.cp.clear()
829
+ self.cp.enable()
830
+
831
+ if self.use_nvtx:
832
+ import nvtx
833
+
834
+ self.nvtx_range_id = nvtx.start_range(self.name, color=self.color)
835
+
836
+ if self.print:
837
+ ScopedTimer.indent += 1
838
+
839
+ if warp.config.verbose:
840
+ indent = " " * ScopedTimer.indent
841
+ print(f"{indent}{self.name} ...", flush=True)
842
+
843
+ self.start = time.perf_counter_ns()
844
+
845
+ return self
846
+
847
+ def __exit__(self, exc_type, exc_value, traceback):
848
+ if not self.skip_tape and warp.context.runtime is not None and warp.context.runtime.tape is not None:
849
+ warp.context.runtime.tape.record_scope_end()
850
+ if self.active:
851
+ if self.synchronize:
852
+ wp.synchronize()
853
+
854
+ self.elapsed = (time.perf_counter_ns() - self.start) / 1000000.0
855
+
856
+ if self.use_nvtx:
857
+ import nvtx
858
+
859
+ nvtx.end_range(self.nvtx_range_id)
860
+
861
+ if self.detailed:
862
+ self.cp.disable()
863
+ self.cp.print_stats(sort="tottime")
864
+
865
+ if self.cuda_filter:
866
+ # end CUDA activity collection, synchronizing if needed
867
+ self.timing_results = timing_end(synchronize=not self.synchronize)
868
+ else:
869
+ self.timing_results = []
870
+
871
+ if self.dict is not None:
872
+ self.dict[self.name].append(self.elapsed)
873
+
874
+ if self.print:
875
+ indent = " " * ScopedTimer.indent
876
+
877
+ if self.timing_results:
878
+ self.report_func(self.timing_results, indent=indent)
879
+ print()
880
+
881
+ if self.extra_msg:
882
+ print(f"{indent}{self.name} took {self.elapsed:.2f} ms {self.extra_msg}")
883
+ else:
884
+ print(f"{indent}{self.name} took {self.elapsed:.2f} ms")
885
+
886
+ ScopedTimer.indent -= 1
887
+
888
+
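A minimal sketch of ScopedTimer wrapping a kernel launch, with CUDA activity collection enabled via cuda_filter; the kernel is illustrative and a CUDA device is assumed:

    import warp as wp
    from warp.utils import ScopedTimer

    wp.init()

    @wp.kernel
    def scale(a: wp.array(dtype=float), s: float):
        i = wp.tid()
        a[i] = a[i] * s

    a = wp.zeros(1_000_000, dtype=float, device="cuda:0")
    with ScopedTimer("scale", synchronize=True, cuda_filter=wp.TIMING_KERNEL) as timer:
        wp.launch(scale, dim=a.size, inputs=[a, 2.0], device="cuda:0")
    print(timer.elapsed)   # elapsed milliseconds; timer.timing_results lists the recorded kernel activity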
889
+ # Allow temporarily enabling/disabling mempool allocators
890
+ class ScopedMempool:
891
+ def __init__(self, device: Devicelike, enable: bool):
892
+ self.device = wp.get_device(device)
893
+ self.enable = enable
894
+
895
+ def __enter__(self):
896
+ self.saved_setting = wp.is_mempool_enabled(self.device)
897
+ wp.set_mempool_enabled(self.device, self.enable)
898
+
899
+ def __exit__(self, exc_type, exc_value, traceback):
900
+ wp.set_mempool_enabled(self.device, self.saved_setting)
901
+
902
+
903
+ # Allow temporarily enabling/disabling mempool access
904
+ class ScopedMempoolAccess:
905
+ def __init__(self, target_device: Devicelike, peer_device: Devicelike, enable: bool):
906
+ self.target_device = target_device
907
+ self.peer_device = peer_device
908
+ self.enable = enable
909
+
910
+ def __enter__(self):
911
+ self.saved_setting = wp.is_mempool_access_enabled(self.target_device, self.peer_device)
912
+ wp.set_mempool_access_enabled(self.target_device, self.peer_device, self.enable)
913
+
914
+ def __exit__(self, exc_type, exc_value, traceback):
915
+ wp.set_mempool_access_enabled(self.target_device, self.peer_device, self.saved_setting)
916
+
917
+
918
+ # Allow temporarily enabling/disabling peer access
919
+ class ScopedPeerAccess:
920
+ def __init__(self, target_device: Devicelike, peer_device: Devicelike, enable: bool):
921
+ self.target_device = target_device
922
+ self.peer_device = peer_device
923
+ self.enable = enable
924
+
925
+ def __enter__(self):
926
+ self.saved_setting = wp.is_peer_access_enabled(self.target_device, self.peer_device)
927
+ wp.set_peer_access_enabled(self.target_device, self.peer_device, self.enable)
928
+
929
+ def __exit__(self, exc_type, exc_value, traceback):
930
+ wp.set_peer_access_enabled(self.target_device, self.peer_device, self.saved_setting)
931
+
932
+
933
+ class ScopedCapture:
934
+ def __init__(self, device: Devicelike = None, stream=None, force_module_load=None, external=False):
935
+ self.device = device
936
+ self.stream = stream
937
+ self.force_module_load = force_module_load
938
+ self.external = external
939
+ self.active = False
940
+ self.graph = None
941
+
942
+ def __enter__(self):
943
+ try:
944
+ wp.capture_begin(
945
+ device=self.device, stream=self.stream, force_module_load=self.force_module_load, external=self.external
946
+ )
947
+ self.active = True
948
+ return self
949
+ except:
950
+ raise
951
+
952
+ def __exit__(self, exc_type, exc_value, traceback):
953
+ if self.active:
954
+ try:
955
+ self.graph = wp.capture_end(device=self.device, stream=self.stream)
956
+ finally:
957
+ self.active = False
958
+
959
+
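A minimal sketch of ScopedCapture recording a CUDA graph and replaying it; a CUDA device is assumed, and wp.capture_launch is assumed available for replay as in Warp's graph-capture API:

    import warp as wp
    from warp.utils import ScopedCapture

    wp.init()
    a = wp.zeros(1024, dtype=float, device="cuda:0")

    with ScopedCapture(device="cuda:0") as capture:
        a.fill_(1.0)                       # operations issued here are recorded, not executed

    for _ in range(10):
        wp.capture_launch(capture.graph)   # replay the captured graph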
960
+ def check_p2p():
961
+ """Check if the machine is configured properly for peer-to-peer transfers.
962
+
963
+ Returns:
964
+ A Boolean indicating whether the machine is configured properly for peer-to-peer transfers.
965
+ On Linux, this function attempts to determine if IOMMU is enabled and will return `False` if IOMMU is detected.
966
+ On other operating systems, it always return `True`.
967
+ """
968
+
969
+ # HACK: allow disabling P2P tests using an environment variable
970
+ disable_p2p_tests = os.getenv("WARP_DISABLE_P2P_TESTS", default="0")
971
+ if int(disable_p2p_tests):
972
+ return False
973
+
974
+ if sys.platform == "linux":
975
+ # IOMMU enablement can affect peer-to-peer transfers.
976
+ # On modern Linux, there should be IOMMU-related entries in the /sys file system.
977
+ # This should be more reliable than checking kernel logs like dmesg.
978
+ if os.path.isdir("/sys/class/iommu") and os.listdir("/sys/class/iommu"):
979
+ return False
980
+ if os.path.isdir("/sys/kernel/iommu_groups") and os.listdir("/sys/kernel/iommu_groups"):
981
+ return False
982
+
983
+ return True
984
+
985
+
986
+ class timing_result_t(ctypes.Structure):
987
+ """CUDA timing struct for fetching values from C++"""
988
+
989
+ _fields_ = [
990
+ ("context", ctypes.c_void_p),
991
+ ("name", ctypes.c_char_p),
992
+ ("filter", ctypes.c_int),
993
+ ("elapsed", ctypes.c_float),
994
+ ]
995
+
996
+
997
+ class TimingResult:
998
+ """Timing result for a single activity."""
999
+
1000
+ def __init__(self, device, name, filter, elapsed):
1001
+ self.device: warp.context.Device = device
1002
+ """The device where the activity was recorded."""
1003
+
1004
+ self.name: str = name
1005
+ """The activity name."""
1006
+
1007
+ self.filter: int = filter
1008
+ """The type of activity (e.g., ``warp.TIMING_KERNEL``)."""
1009
+
1010
+ self.elapsed: float = elapsed
1011
+ """The elapsed time in milliseconds."""
1012
+
1013
+
1014
+ def timing_begin(cuda_filter: int = TIMING_ALL, synchronize: bool = True) -> None:
1015
+ """Begin detailed activity timing.
1016
+
1017
+ Parameters:
1018
+ cuda_filter: Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL``
1019
+ synchronize: Whether to synchronize all CUDA devices before timing starts
1020
+ """
1021
+
1022
+ if synchronize:
1023
+ warp.synchronize()
1024
+
1025
+ warp.context.runtime.core.cuda_timing_begin(cuda_filter)
1026
+
1027
+
1028
+ def timing_end(synchronize: bool = True) -> List[TimingResult]:
1029
+ """End detailed activity timing.
1030
+
1031
+ Parameters:
1032
+ synchronize: Whether to synchronize all CUDA devices before timing ends
1033
+
1034
+ Returns:
1035
+ A list of :class:`TimingResult` objects for all recorded activities.
1036
+ """
1037
+
1038
+ if synchronize:
1039
+ warp.synchronize()
1040
+
1041
+ # get result count
1042
+ count = warp.context.runtime.core.cuda_timing_get_result_count()
1043
+
1044
+ # get result array from C++
1045
+ result_buffer = (timing_result_t * count)()
1046
+ warp.context.runtime.core.cuda_timing_end(ctypes.byref(result_buffer), count)
1047
+
1048
+ # prepare Python result list
1049
+ results = []
1050
+ for r in result_buffer:
1051
+ device = warp.context.runtime.context_map.get(r.context)
1052
+ filter = r.filter
1053
+ elapsed = r.elapsed
1054
+
1055
+ name = r.name.decode()
1056
+ if filter == TIMING_KERNEL:
1057
+ if name.endswith("forward"):
1058
+ # strip trailing "_cuda_kernel_forward"
1059
+ name = f"forward kernel {name[:-20]}"
1060
+ else:
1061
+ # strip trailing "_cuda_kernel_backward"
1062
+ name = f"backward kernel {name[:-21]}"
1063
+ elif filter == TIMING_KERNEL_BUILTIN:
1064
+ if name.startswith("wp::"):
1065
+ name = f"builtin kernel {name[4:]}"
1066
+ else:
1067
+ name = f"builtin kernel {name}"
1068
+
1069
+ results.append(TimingResult(device, name, filter, elapsed))
1070
+
1071
+ return results
1072
+
1073
+
1074
+ def timing_print(results: List[TimingResult], indent: str = "") -> None:
1075
+ """Print timing results.
1076
+
1077
+ Parameters:
1078
+ results: List of :class:`TimingResult` objects to print.
1079
+ indent: Optional indentation to prepend to all output lines.
1080
+ """
1081
+
1082
+ if not results:
1083
+ print("No activity")
1084
+ return
1085
+
1086
+ class Aggregate:
1087
+ def __init__(self, count=0, elapsed=0):
1088
+ self.count = count
1089
+ self.elapsed = elapsed
1090
+
1091
+ device_totals = {}
1092
+ activity_totals = {}
1093
+
1094
+ max_name_len = len("Activity")
1095
+ for r in results:
1096
+ name_len = len(r.name)
1097
+ max_name_len = max(max_name_len, name_len)
1098
+
1099
+ activity_width = max_name_len + 1
1100
+ activity_dashes = "-" * activity_width
1101
+
1102
+ print(f"{indent}CUDA timeline:")
1103
+ print(f"{indent}----------------+---------+{activity_dashes}")
1104
+ print(f"{indent}Time | Device | Activity")
1105
+ print(f"{indent}----------------+---------+{activity_dashes}")
1106
+ for r in results:
1107
+ device_agg = device_totals.get(r.device.alias)
1108
+ if device_agg is None:
1109
+ device_totals[r.device.alias] = Aggregate(count=1, elapsed=r.elapsed)
1110
+ else:
1111
+ device_agg.count += 1
1112
+ device_agg.elapsed += r.elapsed
1113
+
1114
+ activity_agg = activity_totals.get(r.name)
1115
+ if activity_agg is None:
1116
+ activity_totals[r.name] = Aggregate(count=1, elapsed=r.elapsed)
1117
+ else:
1118
+ activity_agg.count += 1
1119
+ activity_agg.elapsed += r.elapsed
1120
+
1121
+ print(f"{indent}{r.elapsed:12.6f} ms | {r.device.alias:7s} | {r.name}")
1122
+
1123
+ print()
1124
+ print(f"{indent}CUDA activity summary:")
1125
+ print(f"{indent}----------------+---------+{activity_dashes}")
1126
+ print(f"{indent}Total time | Count | Activity")
1127
+ print(f"{indent}----------------+---------+{activity_dashes}")
1128
+ for name, agg in activity_totals.items():
1129
+ print(f"{indent}{agg.elapsed:12.6f} ms | {agg.count:7d} | {name}")
1130
+
1131
+ print()
1132
+ print(f"{indent}CUDA device summary:")
1133
+ print(f"{indent}----------------+---------+{activity_dashes}")
1134
+ print(f"{indent}Total time | Count | Device")
1135
+ print(f"{indent}----------------+---------+{activity_dashes}")
1136
+ for device, agg in device_totals.items():
1137
+ print(f"{indent}{agg.elapsed:12.6f} ms | {agg.count:7d} | {device}")