warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/bin/warp-clang.so ADDED
Binary file
warp/bin/warp.so ADDED
Binary file
warp/build.py ADDED
@@ -0,0 +1,557 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+ import errno
18
+ import hashlib
19
+ import json
20
+ import os
21
+ import time
22
+ from pathlib import Path
23
+
24
+ import warp.config
25
+ from warp.thirdparty import appdirs
26
+ from warp.types import *
27
+
28
# Input-type enumerator values mirrored from nvJitLink.h.
# build_cuda() uses them to tag each blob handed to the linker as LTO-IR or fatbin.
nvJitLink_input_type = dict(
    cubin=1,
    ptx=2,
    ltoir=3,
    fatbin=4,
    object=5,
    library=6,
)
30
+
31
+
32
# builds cuda source to PTX or CUBIN using NVRTC (output type determined by output_path extension)
def build_cuda(
    cu_path,
    arch,
    output_path,
    config="release",
    verify_fp=False,
    fast_math=False,
    fuse_fp=True,
    lineinfo=False,
    ltoirs=None,
    fatbins=None,
) -> None:
    """Compile a CUDA source file to PTX or CUBIN.

    The output type is determined by the extension of ``output_path``. When
    ``warp.config.llvm_cuda`` is set, the source is compiled by the bundled
    LLVM toolchain (``compile_cuda``). Otherwise it is compiled with NVRTC through
    ``cuda_compile_program`` in Warp's native core, which can also link in
    LTO-IR and fatbin blobs.

    Args:
        cu_path: Path to the ``.cu`` source file.
        arch: Target architecture forwarded to the native compiler.
        output_path: Destination path of the compiled module.
        config: Build configuration; ``"debug"`` requests a debug build
            (NVRTC path only).
        verify_fp: Forwarded to the NVRTC path.
        fast_math: Forwarded to the NVRTC path.
        fuse_fp: Forwarded to the NVRTC path.
        lineinfo: Forwarded to the NVRTC path.
        ltoirs: Iterable of LTO-IR byte buffers to link, or ``None``.
            Ignored by the LLVM path.
        fatbins: Iterable of fatbin byte buffers to link, or ``None``.
            Ignored by the LLVM path.

    Raises:
        Exception: If the compiler reports a non-zero error status.
    """
    with open(cu_path, "rb") as src_file:
        src = src_file.read()
    cu_path_bytes = cu_path.encode("utf-8")
    program_name_bytes = os.path.basename(cu_path).encode("utf-8")
    inc_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native").encode("utf-8")
    output_path = output_path.encode("utf-8")

    if warp.config.llvm_cuda:
        # NOTE(review): the status is assumed to follow compile_cpp's convention
        # (0 on success). It used to be discarded, hiding build failures here.
        err = warp.context.runtime.llvm.compile_cuda(src, cu_path_bytes, inc_path, output_path, False)
    else:
        # Materialize each input exactly once so one-shot iterables
        # (e.g. generators) work and the counts match the data.
        ltoir_blobs = [] if ltoirs is None else list(ltoirs)
        fatbin_blobs = [] if fatbins is None else list(fatbins)

        link_data = ltoir_blobs + fatbin_blobs
        num_link = len(link_data)
        arr_link = (ctypes.c_char_p * num_link)(*link_data)
        arr_link_sizes = (ctypes.c_size_t * num_link)(*[len(blob) for blob in link_data])

        # One nvJitLink input-type tag per blob, in the same order as link_data.
        link_input_types = [nvJitLink_input_type["ltoir"]] * len(ltoir_blobs) + [
            nvJitLink_input_type["fatbin"]
        ] * len(fatbin_blobs)
        arr_link_input_types = (ctypes.c_int * num_link)(*link_input_types)

        err = warp.context.runtime.core.cuda_compile_program(
            src,
            program_name_bytes,
            arch,
            inc_path,
            0,
            None,
            config == "debug",
            warp.config.verbose,
            verify_fp,
            fast_math,
            fuse_fp,
            lineinfo,
            output_path,
            num_link,
            arr_link,
            arr_link_sizes,
            arr_link_input_types,
        )

    # Non-zero status means failure; a missing (None) status is not treated as an error.
    if err:
        raise Exception(f"CUDA kernel build failed with error code {err}")
90
+
91
+
92
def load_cuda(input_path, device):
    """Load a compiled PTX or CUBIN file as a CUDA module on ``device``.

    The input type is determined by the extension of ``input_path``.

    Args:
        input_path: Path (``str``) to the PTX or CUBIN file.
        device: Target device; must be a CUDA device.

    Returns:
        The value returned by the native ``cuda_load_module`` call.

    Raises:
        RuntimeError: If ``device`` is not a CUDA device.
    """
    if device.is_cuda:
        core = warp.context.runtime.core
        return core.cuda_load_module(device.context, input_path.encode("utf-8"))

    raise RuntimeError("Not a CUDA device")
98
+
99
+
100
def build_cpu(obj_path, cpp_path, mode="release", verify_fp=False, fast_math=False, fuse_fp=True):
    """Compile a C++ source file to an object file with Warp's bundled LLVM toolchain.

    Args:
        obj_path: Destination path of the object file.
        cpp_path: Path to the C++ source file.
        mode: Build mode; ``"debug"`` requests a debug build.
        verify_fp: Forwarded to ``compile_cpp``.
        fast_math: Accepted for signature parity with ``build_cuda`` but not
            forwarded to ``compile_cpp``.
        fuse_fp: Forwarded to ``compile_cpp``.

    Raises:
        Exception: If ``compile_cpp`` returns a non-zero error code.
    """
    with open(cpp_path, "rb") as source_file:
        source = source_file.read()

    encoded_cpp_path = cpp_path.encode("utf-8")
    native_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native")
    encoded_include_dir = native_dir.encode("utf-8")
    encoded_obj_path = obj_path.encode("utf-8")

    status = warp.context.runtime.llvm.compile_cpp(
        source,
        encoded_cpp_path,
        encoded_include_dir,
        encoded_obj_path,
        mode == "debug",
        verify_fp,
        fuse_fp,
    )
    if status != 0:
        raise Exception(f"CPU kernel build failed with error code {status}")
112
+
113
+
114
def init_kernel_cache(path=None):
    """Initialize the kernel cache directory and create it if needed.

    Warp calls this during initialization, and it may also be called directly
    to relocate the cache. The location is chosen in this order:

    1. ``path``, if given (resolved with ``os.path.realpath``);
    2. the ``WARP_CACHE_PATH`` environment variable, if set (also resolved);
    3. an OS-conventional per-user cache directory for the current Warp version.

    The chosen directory is stored in ``warp.config.kernel_cache_dir``.

    To change the default cache location, set warp.config.kernel_cache_dir before calling warp.init().
    """
    if path is not None:
        chosen_root = os.path.realpath(path)
    else:
        env_override = os.environ.get("WARP_CACHE_PATH")
        if env_override is not None:
            chosen_root = os.path.realpath(env_override)
        else:
            chosen_root = appdirs.user_cache_dir(appname="warp", appauthor="NVIDIA", version=warp.config.version)

    warp.config.kernel_cache_dir = chosen_root

    os.makedirs(warp.config.kernel_cache_dir, exist_ok=True)
133
+
134
+
135
def clear_kernel_cache() -> None:
    """Clear the kernel cache directory of previously generated source code and compiler artifacts.

    Only directories beginning with ``wp_`` will be deleted.
    This function only clears the cache for the current Warp version.
    LTO artifacts are not affected.

    Raises:
        RuntimeError: If the Warp runtime is still unavailable after
            ``warp.context.init()``, i.e. the cache directory is not configured.
    """
    import shutil

    warp.context.init()

    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would otherwise let a None cache path reach os.listdir().
    if warp.context.runtime is None:
        raise RuntimeError(
            "The kernel cache directory is not configured; wp.init() has not been called yet or failed."
        )

    cache_dir = warp.config.kernel_cache_dir
    if not os.path.isdir(cache_dir):
        # Nothing to clear, e.g. the directory was removed externally.
        return

    for entry in os.listdir(cache_dir):
        entry_path = os.path.join(cache_dir, entry)
        if entry.startswith("wp_") and os.path.isdir(entry_path):
            # Best-effort removal of the directory and its contents.
            shutil.rmtree(entry_path, ignore_errors=True)
155
+
156
+
157
def clear_lto_cache() -> None:
    """Clear the LTO cache directory of previously generated LTO code.

    The LTO cache is stored in the ``lto`` subdirectory of the kernel cache directory.
    This function only clears the cache for the current Warp version.

    Raises:
        RuntimeError: If the Warp runtime is still unavailable after
            ``warp.context.init()``, i.e. the cache directory is not configured.
    """
    import shutil

    warp.context.init()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if warp.context.runtime is None:
        raise RuntimeError(
            "The kernel cache directory is not configured; wp.init() has not been called yet or failed."
        )

    lto_path = os.path.join(warp.config.kernel_cache_dir, "lto")
    if os.path.isdir(lto_path):
        # Best-effort removal of the lto directory and its contents.
        shutil.rmtree(lto_path, ignore_errors=True)
175
+
176
+
177
def safe_rename(src, dst, attempts=5, delay=0.1):
    """Rename ``src`` to ``dst``, tolerating a destination populated by another process.

    Returns without renaming if ``dst`` already exists (``FileExistsError``) or
    is a non-empty directory (``ENOTEMPTY``). The caller is then expected to
    move its outputs into ``dst`` manually.

    Any other ``OSError`` is retried: up to ``attempts`` tries in total, with
    ``delay`` seconds between them. This works around transient locks, for
    example a process on Windows holding a file inside ``src``. After the last
    failed try, a message is printed and the error is re-raised. With
    ``attempts <= 0`` nothing is attempted.
    """
    for attempt in range(1, attempts + 1):
        try:
            os.rename(src, dst)
        except FileExistsError:
            return
        except OSError as err:
            if err.errno == errno.ENOTEMPTY:
                # Another process most likely created dst first.
                return
            if attempt == attempts:
                print(
                    f"Could not update Warp cache with compiled binaries, trying to rename {src} to {dst}, error {err}"
                )
                raise err
            time.sleep(delay)
        else:
            return
202
+
203
+
204
def hash_symbol(symbol):
    """Return the hexadecimal SHA-256 digest of ``symbol`` encoded as UTF-8."""
    return hashlib.sha256(symbol.encode("utf-8")).hexdigest()
208
+
209
+
210
def get_lto_cache_dir():
    """Return the path of the ``lto`` subdirectory of the kernel cache directory.

    Only the path is computed; the directory is not created here.
    """
    return os.path.join(warp.config.kernel_cache_dir, "lto")
213
+
214
+
215
def get_cached_lto(path):
    """Return the raw bytes of a cached LTO artifact, or ``None`` if it does not exist.

    Uses try/except instead of an ``exists()`` pre-check. A file removed
    concurrently (e.g. by another process clearing the cache) between the check
    and the open is therefore reported as a cache miss instead of raising.
    Other errors (permissions, ``path`` being a directory) still propagate.
    """
    try:
        with open(path, "rb") as f:
            return f.read()
    except FileNotFoundError:
        return None
222
+
223
+
224
def get_cached_lto_meta(path, symbol):
    """Return the metadata value stored for ``symbol`` in a cached JSON meta file.

    Returns ``None`` (a cache miss, so the caller can rebuild) when:

    * the file does not exist, including when it is removed concurrently
      between a check and the open;
    * the file is not valid JSON, e.g. it was truncated or only partially
      written by another process;
    * ``symbol`` is not a key in the file, e.g. a different symbol whose
      truncated hash produced the same file name.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            keys = json.load(f)
    except FileNotFoundError:
        return None
    except json.JSONDecodeError:
        return None

    try:
        return keys[symbol]
    except KeyError:
        return None
232
+
233
+
234
def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
    """Build, or fetch from cache, the LTO-IR for a ``tile_matmul`` product.

    The result is looked up in three places, in order:

    1. ``builder.ltoirs``, the cache for the current module;
    2. the on-disk LTO cache directory;
    3. a fresh compile via the native ``cuda_compile_dot``.

    A fresh compile is written into a process-unique temporary directory and
    then moved into the shared LTO cache. That temporary directory is removed
    afterwards, on success and on failure.

    Args:
        M, N, K: Matrix-product dimensions.
        adtype, bdtype, cdtype: Warp element types of A, B and C.
        alayout, blayout, clayout: ``"colmajor"`` or ``"rowmajor"``.
        arch: Target architecture; clamped to at most 90.
        num_threads: Thread count forwarded to ``cuda_compile_dot``.
        builder: Module builder whose ``ltoirs`` / ``ltoirs_decl`` dicts are updated.

    Returns:
        ``(lto_symbol, lto_code_data)``: the generated symbol name and the LTO-IR bytes.

    Raises:
        TypeError: Unsupported element type, or real and complex operands are mixed.
        ValueError: Unsupported layout.
        RuntimeError: ``cuda_compile_dot`` reports failure.
    """
    import shutil

    # TODO: MathDx doesn't yet have heuristics for Blackwell
    arch = min(arch, 90)

    # Maps Python/Warp types to (C++ type name, precision enum, element-type enum).
    # The element-type enum is 0 for real scalars and 1 for vec2 types, which are
    # treated as complex; see the real/complex check below.
    def cublasdx_type_map(dtype):
        if dtype == float16:
            return ("wp::float16", 3, 0)
        if dtype == float32:
            return ("wp::float32", 5, 0)
        if dtype == float64:
            return ("wp::float64", 6, 0)
        if dtype == vec2h:
            return ("wp::vec2h", 3, 1)
        if dtype == vec2f:
            return ("wp::vec2f", 5, 1)
        if dtype == vec2d:
            return ("wp::vec2d", 6, 1)
        raise TypeError("Unsupported input type in tile_matmul")

    def cublasdx_arrangement_map(layout):
        if layout == "colmajor":
            return 0  # CUBLASDX_ARRANGEMENT_COL_MAJOR
        if layout == "rowmajor":
            return 1  # CUBLASDX_ARRANGEMENT_ROW_MAJOR
        raise ValueError("Unsupported layout in tile_matmul")

    (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
    (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
    (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
    a_arrangement = cublasdx_arrangement_map(alayout)
    b_arrangement = cublasdx_arrangement_map(blayout)
    c_arrangement = cublasdx_arrangement_map(clayout)

    if a_type != b_type or a_type != c_type:
        raise TypeError("tile_matmul(A, B, C) requires all inputs to be real or complex")

    element_type = a_type

    lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
    ltoir_decl = f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"

    # early out if LTO for this symbol is already cached in current module
    if lto_symbol in builder.ltoirs:
        return lto_symbol, builder.ltoirs[lto_symbol]

    # hash symbol and determine output path
    h = hash_symbol(lto_symbol)

    lto_dir = get_lto_cache_dir()
    lto_name = f"{h[:7]}.lto"
    lto_path = os.path.join(lto_dir, lto_name)

    # early out if LTO for this symbol is already built but not cached in current module
    lto_code_data = get_cached_lto(lto_path)

    if lto_code_data is not None:
        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl

        return lto_symbol, lto_code_data

    # create a temporary (process unique) dir for build outputs before moving to the binary dir
    build_dir = f"{lto_dir}_p{os.getpid()}"

    # dir may exist from previous attempts / runs / archs
    Path(build_dir).mkdir(parents=True, exist_ok=True)

    # temporary path to compile to in build_dir
    temp_lto_path = os.path.join(build_dir, lto_name)

    try:
        # compile LTO
        result = warp.context.runtime.core.cuda_compile_dot(
            temp_lto_path.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            M,
            N,
            K,
            a_prec,
            b_prec,
            c_prec,
            element_type,
            a_arrangement,
            b_arrangement,
            c_arrangement,
            num_threads,
        )

        if not result:
            # Partial output is discarded together with build_dir in `finally`.
            raise RuntimeError("Failed to compile tile_matmul")

        with open(temp_lto_path, "rb") as f:
            lto_code_data = f.read()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl

        # try to move process outputs to cache
        safe_rename(build_dir, lto_dir)

        # If lto_dir already existed, the rename above was skipped, so move
        # this process's output file into it individually.
        if os.path.exists(lto_dir) and not os.path.exists(lto_path):
            try:
                os.rename(temp_lto_path, lto_path)
            except OSError:  # includes FileExistsError
                # another process likely updated the lto dir first
                pass
    finally:
        # Clean up build_dir used for this process, also when compilation or
        # the cache move failed; a no-op if it was renamed into lto_dir.
        shutil.rmtree(build_dir, ignore_errors=True)

    return lto_symbol, lto_code_data
357
+
358
+
359
def build_lto_solver(M, N, solver, solver_enum, fill_mode, arch, precision_enum, num_threads, parameter_list, builder):
    """Return the LTO-IR for a MathDx solver routine, compiling it only if no cached copy exists.

    Lookup order:

    1. the in-module cache ``builder.ltoirs``;
    2. the on-disk LTO cache directory (both the ``.lto`` file and the universal
       fatbin must be present);
    3. a fresh compilation via ``cuda_compile_solver``.

    Fresh outputs are compiled into a process-unique build directory and then
    moved into the shared LTO cache. The build directory is always removed
    afterwards, including when compilation fails.

    Args:
        M, N: Dimensions forwarded to the solver compiler and encoded in the symbol.
        solver: Solver name; used as the prefix of the generated LTO symbol.
        solver_enum: Solver identifier forwarded to ``cuda_compile_solver``.
        fill_mode: Fill mode forwarded to ``cuda_compile_solver``.
        arch: Target architecture; clamped to at most 90 (see TODO below).
        precision_enum: Precision identifier forwarded to the compiler and encoded in the symbol.
        num_threads: Thread count forwarded to ``cuda_compile_solver``.
        parameter_list: Parenthesized C parameter list appended to the symbol to form its declaration.
        builder: Module builder; its ``ltoirs``, ``ltoirs_decl`` and ``fatbins`` dicts are populated.

    Returns:
        tuple: ``(lto_symbol, lto_code_data)``, where ``lto_code_data`` holds the LTO-IR bytes.

    Raises:
        RuntimeError: If ``cuda_compile_solver`` reports failure.
    """
    import shutil

    # TODO: MathDx doesn't yet have heuristics for Blackwell
    arch = min(arch, 90)

    lto_symbol = f"{solver}_{M}_{N}_{arch}_{precision_enum}"
    ltoir_decl = f"void {lto_symbol}{parameter_list};"

    # early out if LTO for this symbol is already cached in current module
    if lto_symbol in builder.ltoirs:
        return lto_symbol, builder.ltoirs[lto_symbol]

    # hash symbol and determine output paths
    h = hash_symbol(lto_symbol)

    lto_dir = get_lto_cache_dir()
    lto_name = f"{h[:7]}.lto"
    lto_path = os.path.join(lto_dir, lto_name)

    # we also cache a universal fatbin binary for this symbol
    universal_fatbin_name = f"{h[:7]}_fatbin.lto"
    universal_fatbin_path = os.path.join(lto_dir, universal_fatbin_name)

    lto_code_data = get_cached_lto(lto_path)
    universal_fatbin_code_data = get_cached_lto(universal_fatbin_path)

    # early out if LTO for this symbol is already built but not cached in current module;
    # both artifacts are required, otherwise everything is rebuilt
    if lto_code_data is not None and universal_fatbin_code_data is not None:
        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl
        builder.fatbins[lto_symbol] = universal_fatbin_code_data
        return lto_symbol, lto_code_data

    # create a temporary (process unique) dir for build outputs before moving to the binary dir
    build_dir = f"{lto_dir}_p{os.getpid()}"

    # dir may exist from previous attempts / runs / archs
    Path(build_dir).mkdir(parents=True, exist_ok=True)

    # temporary paths to compile to in build_dir
    temp_lto_path = os.path.join(build_dir, lto_name)
    temp_universal_fatbin_path = os.path.join(build_dir, universal_fatbin_name)

    try:
        # compile LTO
        result = warp.context.runtime.core.cuda_compile_solver(
            temp_universal_fatbin_path.encode("utf-8"),
            temp_lto_path.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            M,
            N,
            solver_enum,
            precision_enum,
            fill_mode,
            num_threads,
        )

        if not result:
            # partial outputs are discarded together with build_dir in the finally clause
            raise RuntimeError(f"Failed to compile solver '{solver}' (symbol {lto_symbol})")

        with open(temp_lto_path, "rb") as f:
            lto_code_data = f.read()
        with open(temp_universal_fatbin_path, "rb") as f:
            universal_fatbin_code_data = f.read()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl
        builder.fatbins[lto_symbol] = universal_fatbin_code_data

        # try to move process outputs to lto cache
        # NOTE(review): presumably safe_rename only succeeds when lto_dir does not exist yet;
        # otherwise the individual files are moved below — confirm against safe_rename
        safe_rename(build_dir, lto_dir)

        if os.path.exists(lto_dir):
            outputs = ((lto_path, temp_lto_path), (universal_fatbin_path, temp_universal_fatbin_path))
            for path, temp_path in outputs:
                if not os.path.exists(path):
                    # move output file to the destination lto dir
                    try:
                        os.rename(temp_path, path)
                    except OSError:
                        # another process likely updated the lto dir first (FileExistsError is an OSError)
                        pass
    finally:
        # clean up build_dir used for this process, on success and on failure
        shutil.rmtree(build_dir, ignore_errors=True)

    return lto_symbol, lto_code_data
456
+
457
+
458
def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
    """Return the LTO-IR and shared-memory requirement for an FFT routine, compiling it only if not cached.

    Lookup order:

    1. the in-module caches ``builder.ltoirs`` / ``builder.shared_memory_bytes``;
    2. the on-disk LTO cache (a ``.lto`` file plus a ``.meta`` JSON file that
       records the shared-memory requirement);
    3. a fresh compilation via ``cuda_compile_fft``.

    Fresh outputs are compiled into a process-unique build directory and then
    moved into the shared LTO cache. The build directory is always removed
    afterwards, including when compilation fails.

    Args:
        arch: Target architecture; clamped to at most 90 (see TODO below).
        size: FFT size forwarded to the compiler and encoded in the symbol.
        ept: Forwarded to the compiler and encoded in the symbol
            (presumably elements per thread — confirm against cuda_compile_fft).
        direction: Direction label encoded in the LTO symbol name.
        dir: Direction value forwarded to ``cuda_compile_fft``. The parameter name
            shadows the builtin but is kept for call compatibility.
        precision: Precision forwarded to the compiler and encoded in the symbol.
        builder: Module builder; its ``ltoirs`` and ``shared_memory_bytes`` dicts are populated.

    Returns:
        tuple: ``(lto_symbol, lto_code_data, shared_memory_bytes)``, where
        ``shared_memory_bytes`` is the compiler-reported shared memory size
        rounded up with ``Tile.round_up``.

    Raises:
        RuntimeError: If ``cuda_compile_fft`` reports failure.
    """
    import shutil

    # TODO: MathDx doesn't yet have heuristics for Blackwell
    arch = min(arch, 90)

    lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"

    # early out if LTO for this symbol is already cached in current module
    if lto_symbol in builder.ltoirs:
        return lto_symbol, builder.ltoirs[lto_symbol], builder.shared_memory_bytes[lto_symbol]

    # hash symbol and determine output paths
    h = hash_symbol(lto_symbol)

    lto_dir = get_lto_cache_dir()
    lto_name = f"{h[:7]}.lto"
    lto_path = os.path.join(lto_dir, lto_name)

    # we also cache shared memory requirements for this kernel in a .meta file
    meta_name = f"{h[:7]}.meta"
    meta_path = os.path.join(lto_dir, meta_name)

    # early out if LTO for this symbol is already built but not cached in current module
    lto_code_data = get_cached_lto(lto_path)
    shared_memory_bytes = get_cached_lto_meta(meta_path, lto_symbol)

    if lto_code_data is not None and shared_memory_bytes is not None:
        builder.ltoirs[lto_symbol] = lto_code_data
        builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
        return lto_symbol, lto_code_data, shared_memory_bytes

    # create a temporary (process unique) dir for build outputs before moving to the binary dir
    build_dir = f"{lto_dir}_p{os.getpid()}"

    # dir may exist from previous attempts / runs / archs
    Path(build_dir).mkdir(parents=True, exist_ok=True)

    # temporary paths to compile to in build_dir
    temp_lto_path = os.path.join(build_dir, lto_name)
    temp_meta_path = os.path.join(build_dir, meta_name)

    try:
        # compile LTO; the compiler reports its shared memory requirement through this out-parameter
        shared_memory_size = ctypes.c_int(0)

        result = warp.context.runtime.core.cuda_compile_fft(
            temp_lto_path.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            size,
            ept,
            dir,
            precision,
            ctypes.byref(shared_memory_size),
        )

        if not result:
            # partial outputs are discarded together with build_dir in the finally clause
            raise RuntimeError("Failed to compile tile_fft")

        shared_memory_bytes = Tile.round_up(shared_memory_size.value)

        with open(temp_lto_path, "rb") as f:
            lto_code_data = f.read()

        # output meta file with shared memory requirements for this lto_symbol
        with open(temp_meta_path, "w", encoding="utf-8") as meta_file:
            json.dump({lto_symbol: shared_memory_bytes}, meta_file)

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes

        # try to move process outputs to cache
        # NOTE(review): presumably safe_rename only succeeds when lto_dir does not exist yet;
        # otherwise the individual files are moved below — confirm against safe_rename
        safe_rename(build_dir, lto_dir)

        if os.path.exists(lto_dir):
            for path, temp_path in ((lto_path, temp_lto_path), (meta_path, temp_meta_path)):
                if not os.path.exists(path):
                    # move output file to the destination lto dir
                    try:
                        os.rename(temp_path, path)
                    except OSError:
                        # another process likely updated the lto dir first (FileExistsError is an OSError)
                        pass
    finally:
        # clean up build_dir used for this process, on success and on failure
        shutil.rmtree(build_dir, ignore_errors=True)

    return lto_symbol, lto_code_data, shared_memory_bytes