warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,371 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import unittest
18
+ from typing import Any
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ from warp.tests.unittest_utils import *
24
+
25
+
26
# basic kernel with one input and output
@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    # Each launched thread handles one index: output[tid] = 3 * input[tid].
    # NOTE(review): `input` shadows the Python builtin; it is kept because it is
    # the kernel's parameter name.
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]
31
+
32
+
33
# generic kernel with one scalar input and output
@wp.kernel
def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
    # `Any` makes this a generic kernel; concrete per-dtype overloads are
    # requested with wp.overload / add_overload elsewhere in this module.
    # `input.dtype(3)` constructs the constant 3 in the array's element type,
    # so both multiply operands share that type.
    tid = wp.tid()
    output[tid] = input.dtype(3) * input[tid]
38
+
39
+
40
# generic kernel with one vector/matrix input and output
@wp.kernel
def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
    # For vector/matrix element types, `input.dtype.dtype` is used as the
    # component scalar type, so the constant 3 is built as a scalar and
    # multiplies the whole vector/matrix element (every component is scaled,
    # as the corresponding test's expected values show).
    tid = wp.tid()
    output[tid] = input.dtype.dtype(3) * input[tid]
45
+
46
+
47
# kernel with multiple inputs and outputs
@wp.kernel
def multiarg_kernel(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    c: wp.array(dtype=float),
    # outputs
    ab: wp.array(dtype=float),
    bc: wp.array(dtype=float),
):
    # Per index: ab = a + b and bc = b + c. Used by test_jax_kernel_multiarg
    # to check that a jax_kernel wrapper returns more than one output.
    tid = wp.tid()
    ab[tid] = a[tid] + b[tid]
    bc[tid] = b[tid] + c[tid]
61
+
62
+
63
# Types exercised by the tests: every Warp scalar type, plus 2-, 3- and
# 4-dimensional vectors and square matrices built over each scalar type.
scalar_types = wp.types.scalar_types
vector_types = [wp.vec(dim, T) for dim in (2, 3, 4) for T in scalar_types]
matrix_types = [wp.mat((dim, dim), T) for dim in (2, 3, 4) for T in scalar_types]

# explicitly overload generic kernels to avoid module reloading during tests
for T in scalar_types:
    wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
for T in vector_types + matrix_types:
    wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
77
+
78
+
79
+ def _jax_version():
80
+ try:
81
+ import jax
82
+
83
+ return jax.__version_info__
84
+ except ImportError:
85
+ return (0, 0, 0)
86
+
87
+
88
def test_dtype_from_jax(test, device):
    """Check ``wp.dtype_from_jax`` for every supported scalar type.

    Each JAX scalar type is checked both as the type object itself and as the
    ``jp.dtype(...)`` instance built from it. ``device`` is not used here; it is
    accepted to match the signature used for registered test functions.
    """
    import jax.numpy as jp

    pairs = [
        (jp.float16, wp.float16),
        (jp.float32, wp.float32),
        (jp.float64, wp.float64),
        (jp.int8, wp.int8),
        (jp.int16, wp.int16),
        (jp.int32, wp.int32),
        (jp.int64, wp.int64),
        (jp.uint8, wp.uint8),
        (jp.uint16, wp.uint16),
        (jp.uint32, wp.uint32),
        (jp.uint64, wp.uint64),
        (jp.bool_, wp.bool),
    ]

    for jax_type, warp_type in pairs:
        test.assertEqual(wp.dtype_from_jax(jax_type), warp_type)
        test.assertEqual(wp.dtype_from_jax(jp.dtype(jax_type)), warp_type)
107
+
108
+
109
def test_dtype_to_jax(test, device):
    """Check ``wp.dtype_to_jax`` for every supported scalar type.

    ``device`` is not used here; it is accepted to match the signature used
    for registered test functions.
    """
    import jax.numpy as jp

    pairs = [
        (wp.float16, jp.float16),
        (wp.float32, jp.float32),
        (wp.float64, jp.float64),
        (wp.int8, jp.int8),
        (wp.int16, jp.int16),
        (wp.int32, jp.int32),
        (wp.int64, jp.int64),
        (wp.uint8, jp.uint8),
        (wp.uint16, jp.uint16),
        (wp.uint32, jp.uint32),
        (wp.uint64, jp.uint64),
        (wp.bool, jp.bool_),
    ]

    for warp_type, jax_type in pairs:
        test.assertEqual(wp.dtype_to_jax(warp_type), jax_type)
127
+
128
+
129
def test_device_conversion(test, device):
    """Round-trip ``device`` through its JAX counterpart and expect the same Warp device back."""
    round_tripped = wp.device_from_jax(wp.device_to_jax(device))
    test.assertEqual(round_tripped, device)
133
+
134
+
135
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_basic(test, device):
    """Run ``triple_kernel`` through ``jax_kernel`` inside ``jax.jit`` on ``device``.

    The input is ``arange(64)`` as float32; the result must be three times it.
    """
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    count = 64
    jax_triple = jax_kernel(triple_kernel)

    @jax.jit
    def compute():
        return jax_triple(jp.arange(count, dtype=jp.float32))

    # execute with the requested device as JAX's default
    with jax.default_device(wp.device_to_jax(device)):
        output = compute()

    assert_np_equal(
        np.asarray(output).reshape((count,)),
        3 * np.arange(count, dtype=np.float32),
    )
158
+
159
+
160
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_scalar(test, device):
    """Triple a 64-element ``arange`` of every Warp scalar type via ``jax_kernel``.

    Each scalar type runs in its own ``subTest`` using the matching concrete
    overload of ``triple_kernel_scalar``.
    """
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    count = 64

    for scalar_type in scalar_types:
        jp_dtype = wp.dtype_to_jax(scalar_type)
        np_dtype = wp.dtype_to_numpy(scalar_type)

        with test.subTest(msg=scalar_type.__name__):
            # concrete overload of the generic kernel for this dtype
            kernel_instance = triple_kernel_scalar.add_overload(
                [wp.array(dtype=scalar_type), wp.array(dtype=scalar_type)]
            )
            jax_triple = jax_kernel(kernel_instance)

            # per-iteration values are bound as defaults (presumably to avoid
            # late-binding closures over loop variables)
            @jax.jit
            def compute(jax_triple=jax_triple, jp_dtype=jp_dtype):
                return jax_triple(jp.arange(count, dtype=jp_dtype))

            with jax.default_device(wp.device_to_jax(device)):
                output = compute()

            assert_np_equal(
                np.asarray(output).reshape((count,)),
                3 * np.arange(count, dtype=np_dtype),
            )
191
+
192
+
193
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_vecmat(test, device):
    """Triple arrays of every vector/matrix type via ``jax_kernel``.

    For element type ``T`` the input is ``n = 64 // T._length_`` elements, passed
    to JAX as a scalar array of shape ``(n, *T._shape_)`` filled with
    ``arange``; every component of the result must be tripled.
    """
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    for vm_type in [*vector_types, *matrix_types]:
        component_type = vm_type._wp_scalar_type_
        jp_dtype = wp.dtype_to_jax(component_type)
        np_dtype = wp.dtype_to_numpy(component_type)

        n = 64 // vm_type._length_
        scalar_shape = (n, *vm_type._shape_)
        scalar_len = n * vm_type._length_

        with test.subTest(msg=vm_type.__name__):
            # concrete overload of the generic kernel for this element type
            kernel_instance = triple_kernel_vecmat.add_overload(
                [wp.array(dtype=vm_type), wp.array(dtype=vm_type)]
            )
            jax_triple = jax_kernel(kernel_instance)

            # per-iteration values are bound as defaults (presumably to avoid
            # late-binding closures over loop variables)
            @jax.jit
            def compute(jax_triple=jax_triple, jp_dtype=jp_dtype, scalar_len=scalar_len, scalar_shape=scalar_shape):
                components = jp.arange(scalar_len, dtype=jp_dtype).reshape(scalar_shape)
                return jax_triple(components)

            with jax.default_device(wp.device_to_jax(device)):
                output = compute()

            actual = np.asarray(output).reshape(scalar_shape)
            assert_np_equal(actual, 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape))
226
+
227
+
228
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_multiarg(test, device):
    """Check a ``jax_kernel`` wrapper with three inputs and two outputs.

    With inputs filled with 1, 2 and 3, ``multiarg_kernel`` must return
    arrays filled with 1 + 2 = 3 and 2 + 3 = 5.
    """
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    count = 64
    jax_multiarg = jax_kernel(multiarg_kernel)

    @jax.jit
    def compute():
        inputs = [jp.full(count, fill, dtype=jp.float32) for fill in (1, 2, 3)]
        return jax_multiarg(*inputs)

    with jax.default_device(wp.device_to_jax(device)):
        sum_ab, sum_bc = compute()

    assert_np_equal(np.asarray(sum_ab), np.full(count, 3, dtype=np.float32))
    assert_np_equal(np.asarray(sum_bc), np.full(count, 5, dtype=np.float32))
255
+
256
+
257
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
def test_jax_kernel_launch_dims(test, device):
    """Check that ``jax_kernel(..., launch_dims=...)`` controls the launch size.

    The launch dims are deliberately smaller than the input shape (1D and 2D
    cases). The ``reshape`` calls below only succeed if each output has exactly
    as many elements as the launch domain, and only the launched indices are
    checked against ``input + 1``.
    """
    import jax.numpy as jp

    from warp.jax_experimental import jax_kernel

    n = 64
    m = 32

    # Test with 1D launch dims
    @wp.kernel
    def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = x[tid] + 1.0

    # Intentionally not the same as the first dimension of the input
    jax_add_one = jax_kernel(add_one_kernel, launch_dims=(n - 2,))

    @jax.jit
    def f_1d():
        x = jp.arange(n, dtype=jp.float32)
        return jax_add_one(x)

    # Test with 2D launch dims
    @wp.kernel
    def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
        i, j = wp.tid()
        y[i, j] = x[i, j] + 1.0

    # Intentionally smaller than BOTH dimensions of the (n, m) input
    jax_add_one_2d = jax_kernel(add_one_2d_kernel, launch_dims=(n - 2, m - 2))

    @jax.jit
    def f_2d():
        x = jp.zeros((n, m), dtype=jp.float32) + 3.0
        return jax_add_one_2d(x)

    # run on the given device
    with jax.default_device(wp.device_to_jax(device)):
        y_1d = f_1d()
        y_2d = f_2d()

    result_1d = np.asarray(y_1d).reshape((n - 2,))
    expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0

    result_2d = np.asarray(y_2d).reshape((n - 2, m - 2))
    expected_2d = np.full((n - 2, m - 2), 4.0, dtype=np.float32)

    assert_np_equal(result_1d, expected_1d)
    assert_np_equal(result_2d, expected_2d)
309
+
310
+
311
class TestJax(unittest.TestCase):
    """Test case for the Jax interop tests; test methods are attached via ``add_function_test``."""
313
+
314
+
315
# try adding Jax tests if Jax is installed correctly
try:
    # prevent Jax from gobbling up GPU memory
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    import jax
    import jax.dlpack

    # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
    jax.config.update("jax_enable_x64", True)

    # check which Warp devices work with Jax
    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
    test_devices = get_test_devices()
    jax_compatible_devices = []
    jax_compatible_cuda_devices = []
    for d in test_devices:
        try:
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
                # JAX dispatches work asynchronously; wait for the result so
                # execution failures surface here rather than inside a test.
                j.block_until_ready()
            jax_compatible_devices.append(d)
            if d.is_cuda:
                jax_compatible_cuda_devices.append(d)
        except Exception as e:
            # best effort: an unusable device is reported and skipped
            print(f"Skipping Jax tests on device '{d}' due to exception: {e}")

    # dtype mapping tests do not depend on a device
    add_function_test(TestJax, "test_dtype_from_jax", test_dtype_from_jax, devices=None)
    add_function_test(TestJax, "test_dtype_to_jax", test_dtype_to_jax, devices=None)

    if jax_compatible_devices:
        add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)

    # jax_kernel tests are only registered for CUDA devices
    if jax_compatible_cuda_devices:
        add_function_test(TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices)
        add_function_test(
            TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
        )
        add_function_test(
            TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
        )
        add_function_test(
            TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
        )

        add_function_test(
            TestJax, "test_jax_kernel_launch_dims", test_jax_kernel_launch_dims, devices=jax_compatible_cuda_devices
        )

except Exception as e:
    # best effort: a missing or broken Jax install skips the whole suite
    print(f"Skipping Jax tests due to exception: {e}")
367
+
368
+
369
if __name__ == "__main__":
    # Clear Warp's kernel cache before running (presumably to force fresh
    # kernel compilation for a standalone run).
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)