warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/build_dll.py ADDED
@@ -0,0 +1,405 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import platform
18
+ import subprocess
19
+ import sys
20
+
21
+ from warp.utils import ScopedTimer
22
+
23
+ verbose_cmd = True # print command lines before executing them
24
+
25
+
26
+ # returns a canonical machine architecture string
27
+ # - "x86_64" for x86-64, aka. AMD64, aka. x64
28
+ # - "aarch64" for AArch64, aka. ARM64
29
def machine_architecture() -> str:
    """Return the canonical architecture name of the host machine.

    ``platform.machine()`` reports the same CPU family under different
    spellings, so the known aliases are folded onto one name each:

    - ``"x86_64"`` for ``"x86_64"`` and ``"AMD64"``
    - ``"aarch64"`` for ``"aarch64"`` and ``"arm64"``

    Raises:
        RuntimeError: If the reported machine string is none of the aliases
            listed above.
    """
    reported = platform.machine()

    # alias -> canonical name
    canonical_names = {
        "x86_64": "x86_64",
        "AMD64": "x86_64",
        "aarch64": "aarch64",
        "arm64": "aarch64",
    }

    canonical = canonical_names.get(reported)
    if canonical is None:
        raise RuntimeError(f"Unrecognized machine architecture {reported}")
    return canonical
36
+
37
+
38
def run_cmd(cmd):
    """Run ``cmd`` through the shell and return its combined output.

    The command line is echoed first when the module-level ``verbose_cmd``
    flag is set.

    Args:
        cmd: The complete command line as a single string.

    Returns:
        The raw bytes written by the command, with stderr merged into stdout.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status. The exit code and captured output are printed before the
            exception is re-raised.
    """
    if verbose_cmd:
        print(cmd)

    try:
        # NOTE(review): shell=True is relied upon by callers in this file that
        # pass shell syntax (e.g. '"...vcvars64.bat" && set'). Command strings
        # must therefore never be assembled from untrusted input.
        return subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        print("Command failed with exit code:", e.returncode)
        print("Command output was:")
        # Tool output is not guaranteed to be valid UTF-8; decode leniently so
        # a UnicodeDecodeError cannot replace the real failure being reported.
        print(e.output.decode(errors="replace"))
        raise
49
+
50
+
51
+ # cut-down version of vcvars64.bat that allows using
52
+ # custom toolchain locations, returns the compiler program path
53
def set_msvc_env(msvc_path, sdk_path):
    """Add a custom MSVC toolchain and Windows SDK to the process environment.

    This is a cut-down stand-in for ``vcvars64.bat``: it appends the x64
    include, library and binary directories of the given toolchain and SDK to
    the ``INCLUDE``, ``LIB`` and ``PATH`` environment variables of the current
    process.

    Args:
        msvc_path: Root directory of the MSVC toolchain (made absolute).
        sdk_path: Root directory of the Windows SDK (made absolute).

    Returns:
        The path of ``cl.exe`` inside ``msvc_path``
        (``<msvc_path>/bin/HostX64/x64/cl.exe``). The file is not checked for
        existence.
    """
    # Every variable is appended to below, so each one must exist first.
    # PATH is included so an environment without it does not raise KeyError.
    for name in ("INCLUDE", "LIB", "PATH"):
        os.environ.setdefault(name, "")

    msvc_path = os.path.abspath(msvc_path)
    sdk_path = os.path.abspath(sdk_path)

    additions = {
        "INCLUDE": [
            os.path.join(msvc_path, "include"),
            os.path.join(sdk_path, "include/winrt"),
            os.path.join(sdk_path, "include/um"),
            os.path.join(sdk_path, "include/ucrt"),
            os.path.join(sdk_path, "include/shared"),
        ],
        "LIB": [
            os.path.join(msvc_path, "lib/x64"),
            os.path.join(sdk_path, "lib/ucrt/x64"),
            os.path.join(sdk_path, "lib/um/x64"),
        ],
        "PATH": [
            os.path.join(msvc_path, "bin/HostX64/x64"),
            os.path.join(sdk_path, "bin/x64"),
        ],
    }

    # Each directory is appended as "<pathsep><dir>", after whatever the
    # variable already held.
    for name, directories in additions.items():
        os.environ[name] += "".join(os.pathsep + directory for directory in directories)

    return os.path.join(msvc_path, "bin", "HostX64", "x64", "cl.exe")
77
+
78
+
79
def find_host_compiler():
    """Locate the host C++ compiler and return its path.

    On Windows (``os.name == "nt"``) the latest Visual Studio installation is
    looked up with ``vswhere.exe``, its ``vcvars64.bat`` is run, and the
    resulting environment variables are copied into ``os.environ`` before
    ``cl.exe`` is looked up with ``where``. On every other platform the path
    printed by ``which g++`` is returned.

    Returns:
        The compiler path as a string with surrounding line breaks removed.
        On Windows an empty string is returned when ``vswhere.exe`` does not
        exist or when the MSVC toolset is older than 14.29.

    Raises:
        subprocess.CalledProcessError: If one of the helper commands exits
            with a non-zero status (propagated from ``run_cmd``).
        KeyError: On Windows, if ``VCToolsVersion`` is not in the environment
            after ``vcvars64.bat`` has run.
    """
    if os.name == "nt":
        # try and find an installed host compiler (msvc)
        # runs vcvars and copies back the build environment

        vswhere_path = r"%ProgramFiles(x86)%/Microsoft Visual Studio/Installer/vswhere.exe"
        vswhere_path = os.path.expandvars(vswhere_path)
        if not os.path.exists(vswhere_path):
            return ""

        vs_path = run_cmd(f'"{vswhere_path}" -latest -property installationPath').decode().rstrip()
        vsvars_path = os.path.join(vs_path, "VC\\Auxiliary\\Build\\vcvars64.bat")

        # "&& set" dumps the environment produced by the batch file as
        # NAME=value lines, which are copied into this process below
        output = run_cmd(f'"{vsvars_path}" && set').decode()

        for line in output.splitlines():
            name, separator, value = line.partition("=")
            if separator:
                os.environ[name] = value

        # "where" prints one line per match; keep only the first match so the
        # returned value is always a single usable path
        where_lines = run_cmd("where cl.exe").decode("utf-8").splitlines()
        cl_path = where_lines[0].rstrip() if where_lines else ""
        cl_version = os.environ["VCToolsVersion"].split(".")

        # ensure at least VS2019 version, see list of MSVC versions here https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B
        cl_required_major = 14
        cl_required_minor = 29

        cl_major = int(cl_version[0])
        too_old = cl_major < cl_required_major or (
            cl_major == cl_required_major and int(cl_version[1]) < cl_required_minor
        )
        if too_old:
            print(
                f"Warp: MSVC found but compiler version too old, found {cl_version[0]}.{cl_version[1]}, but must be {cl_required_major}.{cl_required_minor} or higher, kernel host compilation will be disabled."
            )
            return ""

        return cl_path

    # try and find g++; strip the trailing newline printed by "which" so the
    # result is a clean path, matching what the Windows branch returns
    return run_cmd("which g++").decode().rstrip()
121
+
122
+
123
def get_cuda_toolkit_version(cuda_home):
    """Return the CUDA Toolkit version found under ``cuda_home``.

    The version is read from the output of ``<cuda_home>/bin/nvcc --version``
    by looking for the ``release <major>.<minor>`` substring
    (e.g. "release 11.5").

    Args:
        cuda_home: Root directory of the CUDA Toolkit installation.

    Returns:
        A ``(major, minor)`` tuple of ints, or ``None`` when the version could
        not be determined. In that case the reason is printed instead of an
        exception being raised.
    """
    try:
        import re

        nvcc_executable = os.path.join(cuda_home, "bin", "nvcc")
        version_banner = subprocess.check_output([nvcc_executable, "--version"]).decode("utf-8")

        release = re.search(r"(?<=release )\d+\.\d+", version_banner)
        if release is None:
            raise Exception("Failed to parse NVCC output")

        # the pattern guarantees exactly two dot-separated numeric fields
        major, minor = release.group(0).split(".")
        return (int(major), int(minor))

    except Exception as e:
        print(f"Failed to determine CUDA Toolkit version: {e}")

    return None
139
+
140
+
141
def quote(path):
    """Return ``path`` wrapped in double quotes.

    No escaping is performed on the contents of ``path``.
    """
    double_quote = '"'
    return "".join((double_quote, path, double_quote))
143
+
144
+
145
+ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None):
146
+ mode = args.mode if (mode is None) else mode
147
+ cuda_home = args.cuda_path
148
+ cuda_cmd = None
149
+
150
+ if args.quick or cu_path is None:
151
+ cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
152
+ else:
153
+ cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"
154
+
155
+ import pathlib
156
+
157
+ warp_home_path = pathlib.Path(__file__).parent
158
+ warp_home = warp_home_path.resolve()
159
+
160
+ # output stale, rebuild
161
+ if args.verbose:
162
+ print(f"Building {dll_path}")
163
+
164
+ native_dir = os.path.join(warp_home, "native")
165
+
166
+ if cu_path:
167
+ # check CUDA Toolkit version
168
+ min_ctk_version = (11, 5)
169
+ ctk_version = get_cuda_toolkit_version(cuda_home) or min_ctk_version
170
+ if ctk_version < min_ctk_version:
171
+ raise Exception(
172
+ f"CUDA Toolkit version {min_ctk_version[0]}.{min_ctk_version[1]}+ is required (found {ctk_version[0]}.{ctk_version[1]} in {cuda_home})"
173
+ )
174
+
175
+ if ctk_version[0] < 12 and args.libmathdx_path:
176
+ print("MathDx support requires at least CUDA 12, skipping")
177
+ args.libmathdx_path = None
178
+
179
+ gencode_opts = []
180
+
181
+ if args.quick:
182
+ # minimum supported architectures (PTX)
183
+ gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
184
+ else:
185
+ # generate code for all supported architectures
186
+ gencode_opts += [
187
+ # SASS for supported desktop/datacenter architectures
188
+ "-gencode=arch=compute_52,code=sm_52", # Maxwell
189
+ "-gencode=arch=compute_60,code=sm_60", # Pascal
190
+ "-gencode=arch=compute_61,code=sm_61",
191
+ "-gencode=arch=compute_70,code=sm_70", # Volta
192
+ "-gencode=arch=compute_75,code=sm_75", # Turing
193
+ "-gencode=arch=compute_80,code=sm_80", # Ampere
194
+ "-gencode=arch=compute_86,code=sm_86",
195
+ ]
196
+ if arch == "aarch64" and sys.platform == "linux":
197
+ gencode_opts += [
198
+ # SASS for supported mobile architectures (e.g. Tegra/Jetson)
199
+ "-gencode=arch=compute_53,code=sm_53", # X1
200
+ "-gencode=arch=compute_62,code=sm_62", # X2
201
+ "-gencode=arch=compute_72,code=sm_72", # Xavier
202
+ "-gencode=arch=compute_87,code=sm_87", # Orin
203
+ ]
204
+
205
+ if ctk_version >= (12, 8):
206
+ # Support for Blackwell is available with CUDA Toolkit 12.8+
207
+ gencode_opts += [
208
+ "-gencode=arch=compute_89,code=sm_89", # Ada
209
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
210
+ "-gencode=arch=compute_100,code=sm_100", # Blackwell
211
+ "-gencode=arch=compute_120,code=sm_120", # Blackwell
212
+ "-gencode=arch=compute_120,code=compute_120", # PTX for future hardware
213
+ ]
214
+ elif ctk_version >= (11, 8):
215
+ # Support for Ada and Hopper is available with CUDA Toolkit 11.8+
216
+ gencode_opts += [
217
+ "-gencode=arch=compute_89,code=sm_89", # Ada
218
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
219
+ "-gencode=arch=compute_90,code=compute_90", # PTX for future hardware
220
+ ]
221
+ else:
222
+ gencode_opts += [
223
+ "-gencode=arch=compute_86,code=compute_86", # PTX for future hardware
224
+ ]
225
+
226
+ nvcc_opts = gencode_opts + [
227
+ "-t0", # multithreaded compilation
228
+ "--extended-lambda",
229
+ ]
230
+
231
+ if args.fast_math:
232
+ nvcc_opts.append("--use_fast_math")
233
+
234
+ # is the library being built with CUDA enabled?
235
+ cuda_enabled = "WP_ENABLE_CUDA=1" if (cu_path is not None) else "WP_ENABLE_CUDA=0"
236
+
237
+ if args.libmathdx_path:
238
+ libmathdx_includes = f' -I"{args.libmathdx_path}/include"'
239
+ mathdx_enabled = "WP_ENABLE_MATHDX=1"
240
+ else:
241
+ libmathdx_includes = ""
242
+ mathdx_enabled = "WP_ENABLE_MATHDX=0"
243
+
244
+ if os.name == "nt":
245
+ if args.host_compiler:
246
+ host_linker = os.path.join(os.path.dirname(args.host_compiler), "link.exe")
247
+ else:
248
+ raise RuntimeError("Warp build error: No host compiler was found")
249
+
250
+ cpp_includes = f' /I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"'
251
+ cpp_includes += f' /I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"'
252
+ cuda_includes = f' /I"{cuda_home}/include"' if cu_path else ""
253
+ includes = cpp_includes + cuda_includes
254
+
255
+ # nvrtc_static.lib is built with /MT and _ITERATOR_DEBUG_LEVEL=0 so if we link it in we must match these options
256
+ if cu_path or mode != "debug":
257
+ runtime = "/MT"
258
+ iter_dbg = "_ITERATOR_DEBUG_LEVEL=0"
259
+ debug = "NDEBUG"
260
+ else:
261
+ runtime = "/MTd"
262
+ iter_dbg = "_ITERATOR_DEBUG_LEVEL=2"
263
+ debug = "_DEBUG"
264
+
265
+ cpp_flags = f'/nologo /std:c++17 /GR- {runtime} /D "{debug}" /D "{cuda_enabled}" /D "{mathdx_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" {includes} '
266
+
267
+ if args.mode == "debug":
268
+ cpp_flags += "/Zi /Od /D WP_ENABLE_DEBUG=1"
269
+ linkopts = ["/DLL", "/DEBUG"]
270
+ elif args.mode == "release":
271
+ cpp_flags += "/Ox /D WP_ENABLE_DEBUG=0"
272
+ linkopts = ["/DLL"]
273
+ else:
274
+ raise RuntimeError(f"Unrecognized build configuration (debug, release), got: {args.mode}")
275
+
276
+ if args.verify_fp:
277
+ cpp_flags += ' /D "WP_VERIFY_FP"'
278
+
279
+ if args.fast_math:
280
+ cpp_flags += " /fp:fast"
281
+
282
+ with ScopedTimer("build", active=args.verbose):
283
+ for cpp_path in cpp_paths:
284
+ cpp_out = cpp_path + ".obj"
285
+ linkopts.append(quote(cpp_out))
286
+
287
+ cpp_cmd = f'"{args.host_compiler}" {cpp_flags} -c "{cpp_path}" /Fo"{cpp_out}"'
288
+ run_cmd(cpp_cmd)
289
+
290
+ if cu_path:
291
+ cu_out = cu_path + ".o"
292
+
293
+ if mode == "debug":
294
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 --compiler-options=/MT,/Zi,/Od -g -G -O0 -DNDEBUG -D_ITERATOR_DEBUG_LEVEL=0 -I"{native_dir}" -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
295
+
296
+ elif mode == "release":
297
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 {" ".join(nvcc_opts)} -I"{native_dir}" -DNDEBUG -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
298
+
299
+ with ScopedTimer("build_cuda", active=args.verbose):
300
+ run_cmd(cuda_cmd)
301
+ linkopts.append(quote(cu_out))
302
+ linkopts.append(
303
+ f'cudart_static.lib nvrtc_static.lib nvrtc-builtins_static.lib nvptxcompiler_static.lib ws2_32.lib user32.lib /LIBPATH:"{cuda_home}/lib/x64"'
304
+ )
305
+
306
+ if args.libmathdx_path:
307
+ linkopts.append(f'nvJitLink_static.lib /LIBPATH:"{args.libmathdx_path}/lib" mathdx_static.lib')
308
+
309
+ with ScopedTimer("link", active=args.verbose):
310
+ link_cmd = f'"{host_linker}" {" ".join(linkopts + libs)} /out:"{dll_path}"'
311
+ run_cmd(link_cmd)
312
+
313
+ else:
314
+ cpp_includes = f' -I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"'
315
+ cpp_includes += f' -I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"'
316
+ cuda_includes = f' -I"{cuda_home}/include"' if cu_path else ""
317
+ includes = cpp_includes + cuda_includes
318
+
319
+ if sys.platform == "darwin":
320
+ version = f"--target={arch}-apple-macos11"
321
+ else:
322
+ version = "-fabi-version=13" # GCC 8.2+
323
+
324
+ cpp_flags = f'{version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
325
+
326
+ if mode == "debug":
327
+ cpp_flags += "-O0 -g -D_DEBUG -DWP_ENABLE_DEBUG=1 -fkeep-inline-functions"
328
+
329
+ if mode == "release":
330
+ cpp_flags += "-O3 -DNDEBUG -DWP_ENABLE_DEBUG=0"
331
+
332
+ if args.verify_fp:
333
+ cpp_flags += " -DWP_VERIFY_FP"
334
+
335
+ if args.fast_math:
336
+ cpp_flags += " -ffast-math"
337
+
338
+ ld_inputs = []
339
+
340
+ with ScopedTimer("build", active=args.verbose):
341
+ for cpp_path in cpp_paths:
342
+ cpp_out = cpp_path + ".o"
343
+ ld_inputs.append(quote(cpp_out))
344
+
345
+ build_cmd = f'g++ {cpp_flags} -c "{cpp_path}" -o "{cpp_out}"'
346
+ run_cmd(build_cmd)
347
+
348
+ if cu_path:
349
+ cu_out = cu_path + ".o"
350
+
351
+ if mode == "debug":
352
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
353
+
354
+ elif mode == "release":
355
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
356
+
357
+ with ScopedTimer("build_cuda", active=args.verbose):
358
+ run_cmd(cuda_cmd)
359
+
360
+ ld_inputs.append(quote(cu_out))
361
+ ld_inputs.append(
362
+ f'-L"{cuda_home}/lib64" -lcudart_static -lnvrtc_static -lnvrtc-builtins_static -lnvptxcompiler_static -lpthread -ldl -lrt'
363
+ )
364
+
365
+ if args.libmathdx_path:
366
+ ld_inputs.append(f"-lnvJitLink_static -L{args.libmathdx_path}/lib -lmathdx_static")
367
+
368
+ if sys.platform == "darwin":
369
+ opt_no_undefined = "-Wl,-undefined,error"
370
+ opt_exclude_libs = ""
371
+ else:
372
+ opt_no_undefined = "-Wl,--no-undefined"
373
+ opt_exclude_libs = "-Wl,--exclude-libs,ALL"
374
+
375
+ with ScopedTimer("link", active=args.verbose):
376
+ origin = "@loader_path" if (sys.platform == "darwin") else "$ORIGIN"
377
+ link_cmd = f"g++ {version} -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
378
+ run_cmd(link_cmd)
379
+
380
+ # Strip symbols to reduce the binary size
381
+ if mode == "release":
382
+ if sys.platform == "darwin":
383
+ run_cmd(f"strip -x {dll_path}") # Strip all local symbols
384
+ else: # Linux
385
+ # Strip all symbols except for those needed to support debugging JIT-compiled code
386
+ run_cmd(
387
+ f"strip --strip-all --keep-symbol=__jit_debug_register_code --keep-symbol=__jit_debug_descriptor {dll_path}"
388
+ )
389
+
390
+
391
def build_dll(args, dll_path, cpp_paths, cu_path, libs=None):
    """Build the shared library at ``dll_path`` for the host platform.

    On macOS the library is built twice (x86_64 and aarch64) and the two
    results are merged into one universal binary with ``lipo``; the per-arch
    intermediates are deleted afterwards. On every other platform a single
    build is made for the architecture reported by ``machine_architecture()``.

    Args:
        args: Build options, forwarded unchanged to ``build_dll_for_arch``.
        dll_path: Output path of the final shared library (string).
        cpp_paths: C++ source paths, forwarded to ``build_dll_for_arch``.
        cu_path: CUDA source path (may be falsy), forwarded to
            ``build_dll_for_arch``.
        libs: Optional list of extra linker inputs; defaults to an empty list.
    """
    if libs is None:
        libs = []

    if sys.platform == "darwin":
        # Local import: only this branch needs it, and it keeps the block
        # self-contained with respect to the file's import header.
        import shlex

        # create a universal binary by combining x86-64 and AArch64 builds
        x86_64_path = dll_path + "-x86_64"
        aarch64_path = dll_path + "-aarch64"

        build_dll_for_arch(args, x86_64_path, cpp_paths, cu_path, libs, "x86_64")
        build_dll_for_arch(args, aarch64_path, cpp_paths, cu_path, libs, "aarch64")

        # Quote every path so that a dll_path containing spaces or shell
        # metacharacters is not split into several arguments. shlex.quote()
        # returns ordinary paths unchanged, so the command text is identical
        # to the unquoted form in the common case.
        # NOTE(review): assumes run_cmd hands the string to a POSIX shell (the
        # link step already relies on shell-style single quotes) - confirm.
        run_cmd(
            f"lipo -create -output {shlex.quote(dll_path)} "
            f"{shlex.quote(x86_64_path)} {shlex.quote(aarch64_path)}"
        )

        # The per-arch slices are only intermediates of the universal build.
        os.remove(x86_64_path)
        os.remove(aarch64_path)

    else:
        build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, machine_architecture())