warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,729 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+ import os
18
+ import unittest
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ from warp.tests.unittest_utils import *
24
+
25
# Number of elements in the large 1-D test arrays (1024 * 1024).
N = 1 << 20
26
+
27
+
28
+ def _jax_version():
29
+ try:
30
+ import jax
31
+
32
+ return jax.__version_info__
33
+ except (ImportError, AttributeError):
34
+ return (0, 0, 0)
35
+
36
+
37
# Kernel that adds 1.0 to the element of a float array addressed by the
# launching thread's index.
@wp.kernel
def inc(a: wp.array(dtype=float)):
    idx = wp.tid()
    a[idx] = a[idx] + 1.0
41
+
42
+
43
def test_dlpack_warp_to_warp(test, device):
    """Round-trip a Warp array through DLPack and check the result aliases it.

    The re-imported array must report the same pointer, device, dtype, shape
    and strides as the source, and both arrays must hold equal contents before
    and after the ``inc`` kernel is launched on the re-imported one.
    """
    source = wp.array(data=np.arange(N, dtype=np.float32), device=device)
    alias = wp.from_dlpack(wp.to_dlpack(source))

    # metadata must match attribute by attribute
    for attr in ("ptr", "device", "dtype", "shape", "strides"):
        test.assertEqual(getattr(source, attr), getattr(alias, attr))

    assert_np_equal(source.numpy(), alias.numpy())

    # modify through the re-imported array, then compare contents again
    wp.launch(inc, dim=alias.size, inputs=[alias], device=device)
    assert_np_equal(source.numpy(), alias.numpy())
59
+
60
+
61
def test_dlpack_dtypes_and_shapes(test, device):
    """Exercise Warp-to-Warp DLPack conversion over many dtypes and shapes.

    Covers scalar arrays (implicit and explicit target dtype, including
    signed/unsigned reinterpretation), vector <-> scalar conversions and
    matrix <-> scalar conversions. Every case checks that the converted array
    shares the pointer and device of the source and has the expected dtype,
    shape and strides.
    """

    def roundtrip(src, **kwargs):
        # export to a DLPack capsule and import it straight back
        return wp.from_dlpack(wp.to_dlpack(src), **kwargs)

    def assert_aliased(src, dst):
        # both arrays must refer to the same memory on the same device
        test.assertEqual(src.ptr, dst.ptr)
        test.assertEqual(src.device, dst.device)

    def check_scalar_implicit(dtype):
        # scalar dtype is determined automatically on import
        src = wp.zeros(N, dtype=dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src))

        assert_aliased(src, dst)
        test.assertEqual(src.dtype, dst.dtype)
        test.assertEqual(src.shape, dst.shape)
        test.assertEqual(src.strides, dst.strides)

    def check_scalar_explicit(dtype, target_dtype):
        # scalar dtype is requested explicitly on import
        src = wp.zeros(N, dtype=dtype, device=device)
        dst = roundtrip(src, dtype=target_dtype)

        assert_aliased(src, dst)
        test.assertEqual(src.dtype, dtype)
        test.assertEqual(dst.dtype, target_dtype)
        test.assertEqual(src.shape, dst.shape)
        test.assertEqual(src.strides, dst.strides)

    def check_vector_to_scalar(vec_dtype):
        # a vector array imported as scalars gains one trailing dimension
        scalar_type = vec_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(vec_dtype._type_)

        src = wp.zeros(N, dtype=vec_dtype, device=device)
        dst = roundtrip(src, dtype=scalar_type)

        assert_aliased(src, dst)
        test.assertEqual(dst.ndim, src.ndim + 1)
        test.assertEqual(src.dtype, vec_dtype)
        test.assertEqual(dst.dtype, scalar_type)
        test.assertEqual(dst.shape, (*src.shape, vec_dtype._length_))
        test.assertEqual(dst.strides, (*src.strides, scalar_size))

    def check_scalar_to_vector(vec_dtype):
        # a scalar array imported as vectors loses its trailing dimension
        scalar_type = vec_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(vec_dtype._type_)

        src = wp.zeros((N, vec_dtype._length_), dtype=scalar_type, device=device)
        dst = roundtrip(src, dtype=vec_dtype)

        assert_aliased(src, dst)
        test.assertEqual(dst.ndim, src.ndim - 1)
        test.assertEqual(src.dtype, scalar_type)
        test.assertEqual(dst.dtype, vec_dtype)
        test.assertEqual(src.shape, (*dst.shape, vec_dtype._length_))
        test.assertEqual(src.strides, (*dst.strides, scalar_size))

    def check_matrix_to_scalar(mat_dtype):
        # a matrix array imported as scalars gains two trailing dimensions
        scalar_type = mat_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(mat_dtype._type_)
        row_stride = scalar_size * mat_dtype._shape_[1]

        src = wp.zeros(N, dtype=mat_dtype, device=device)
        dst = roundtrip(src, dtype=scalar_type)

        assert_aliased(src, dst)
        test.assertEqual(dst.ndim, src.ndim + 2)
        test.assertEqual(src.dtype, mat_dtype)
        test.assertEqual(dst.dtype, scalar_type)
        test.assertEqual(dst.shape, (*src.shape, *mat_dtype._shape_))
        test.assertEqual(dst.strides, (*src.strides, row_stride, scalar_size))

    def check_scalar_to_matrix(mat_dtype):
        # a scalar array imported as matrices loses two trailing dimensions
        scalar_type = mat_dtype._wp_scalar_type_
        scalar_size = ctypes.sizeof(mat_dtype._type_)
        row_stride = scalar_size * mat_dtype._shape_[1]

        src = wp.zeros((N, *mat_dtype._shape_), dtype=scalar_type, device=device)
        dst = roundtrip(src, dtype=mat_dtype)

        assert_aliased(src, dst)
        test.assertEqual(dst.ndim, src.ndim - 2)
        test.assertEqual(src.dtype, scalar_type)
        test.assertEqual(dst.dtype, mat_dtype)
        test.assertEqual(src.shape, (*dst.shape, *mat_dtype._shape_))
        test.assertEqual(src.strides, (*dst.strides, row_stride, scalar_size))

    scalar_types = wp.types.scalar_types

    for scalar in scalar_types:
        check_scalar_implicit(scalar)

    for scalar in scalar_types:
        check_scalar_explicit(scalar, scalar)

    # test signed/unsigned conversions, in both directions for each width
    signed_unsigned_pairs = (
        (wp.int8, wp.uint8),
        (wp.int16, wp.uint16),
        (wp.int32, wp.uint32),
        (wp.int64, wp.uint64),
    )
    for signed, unsigned in signed_unsigned_pairs:
        check_scalar_explicit(signed, unsigned)
        check_scalar_explicit(unsigned, signed)

    # generic vectors of length 2..5 for every scalar type, plus named types
    vec_types = [wp.types.vector(length, scalar) for scalar in scalar_types for length in (2, 3, 4, 5)]
    vec_types.extend(
        [
            wp.quath,
            wp.quatf,
            wp.quatd,
            wp.transformh,
            wp.transformf,
            wp.transformd,
            wp.spatial_vectorh,
            wp.spatial_vectorf,
            wp.spatial_vectord,
        ]
    )

    for vec_type in vec_types:
        check_vector_to_scalar(vec_type)
        check_scalar_to_vector(vec_type)

    # square and non-square matrices for every scalar type, plus named types
    mat_shapes = [(2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3)]
    mat_types = [wp.types.matrix(shape, scalar) for scalar in scalar_types for shape in mat_shapes]
    mat_types.extend([wp.spatial_matrixh, wp.spatial_matrixf, wp.spatial_matrixd])

    for mat_type in mat_types:
        check_matrix_to_scalar(mat_type)
        check_scalar_to_matrix(mat_type)
197
+
198
+
199
def test_dlpack_stream_arg(test, device):
    """Test the valid range for the ``stream`` argument to ``array.__dlpack__()``.

    Accepted values must yield a capsule whose contents match the source data;
    rejected values must raise ``TypeError`` with the expected message.
    """
    data = np.arange(10)

    def check_result(capsule):
        # import the capsule and compare against the source data
        imported = wp.dlpack._from_dlpack(capsule)
        assert_np_equal(imported.numpy(), data)

    def expect_type_error(pattern, bad_streams):
        # each listed stream value must be rejected with a matching TypeError
        for bad in bad_streams:
            with test.assertRaisesRegex(TypeError, pattern):
                check_result(a.__dlpack__(stream=bad))

    with wp.ScopedDevice(device):
        a = wp.array(data=data)

        # stream arguments supported for all devices
        check_result(a.__dlpack__())
        for common in (None, -1):
            check_result(a.__dlpack__(stream=common))

        # device-specific stream arguments
        if not device.is_cuda:
            expect_type_error(
                r"DLPack stream must be None or -1 for CPU device",
                (0, 1, 2, 1742, -2, "nope"),
            )
        else:
            # 0: default stream, 1: legacy default stream, 2: per thread default stream
            for builtin_stream in (0, 1, 2):
                check_result(a.__dlpack__(stream=builtin_stream))

            # custom stream
            stream = wp.Stream(device)
            check_result(a.__dlpack__(stream=stream.cuda_stream))

            # unsupported stream arguments
            expect_type_error(
                r"DLPack stream must None or an integer >= -1",
                (-2, "nope"),
            )
248
+
249
+
250
def test_dlpack_warp_to_torch(test, device):
    """Export a Warp array to PyTorch via ``wp.to_dlpack`` and check it aliases.

    The tensor must match the array's pointer, device, dtype, shape and
    strides, and contents must stay equal after modifying either side.
    """
    import torch.utils.dlpack

    warp_array = wp.array(data=np.arange(N, dtype=np.float32), device=device)
    tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))

    def assert_same_contents():
        assert_np_equal(warp_array.numpy(), tensor.cpu().numpy())

    # the strides are compared in bytes, so scale the tensor's by the item size
    item_size = wp.types.type_size_in_bytes(warp_array.dtype)

    test.assertEqual(warp_array.ptr, tensor.data_ptr())
    test.assertEqual(warp_array.device, wp.device_from_torch(tensor.device))
    test.assertEqual(warp_array.dtype, wp.dtype_from_torch(tensor.dtype))
    test.assertEqual(warp_array.shape, tuple(tensor.shape))
    test.assertEqual(warp_array.strides, tuple(step * item_size for step in tensor.stride()))

    assert_same_contents()

    # modify through Warp
    wp.launch(inc, dim=warp_array.size, inputs=[warp_array], device=device)
    assert_same_contents()

    # modify through PyTorch, in place
    tensor += 1
    assert_same_contents()
274
+
275
+
276
def test_dlpack_warp_to_torch_v2(test, device):
    """Same as ``test_dlpack_warp_to_torch``, but uses the newer ``__dlpack__()`` method.

    The Warp array is handed directly to ``torch.utils.dlpack.from_dlpack``
    instead of being converted with ``wp.to_dlpack`` first.
    """
    import torch.utils.dlpack

    warp_array = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # pass the array directly
    tensor = torch.utils.dlpack.from_dlpack(warp_array)

    def assert_same_contents():
        assert_np_equal(warp_array.numpy(), tensor.cpu().numpy())

    # the strides are compared in bytes, so scale the tensor's by the item size
    item_size = wp.types.type_size_in_bytes(warp_array.dtype)

    test.assertEqual(warp_array.ptr, tensor.data_ptr())
    test.assertEqual(warp_array.device, wp.device_from_torch(tensor.device))
    test.assertEqual(warp_array.dtype, wp.dtype_from_torch(tensor.dtype))
    test.assertEqual(warp_array.shape, tuple(tensor.shape))
    test.assertEqual(warp_array.strides, tuple(step * item_size for step in tensor.stride()))

    assert_same_contents()

    # modify through Warp
    wp.launch(inc, dim=warp_array.size, inputs=[warp_array], device=device)
    assert_same_contents()

    # modify through PyTorch, in place
    tensor += 1
    assert_same_contents()
303
+
304
+
305
def test_dlpack_torch_to_warp(test, device):
    """Import a PyTorch tensor into Warp via ``torch.utils.dlpack.to_dlpack``.

    The Warp array must match the tensor's pointer, device, dtype, shape and
    strides, and contents must stay equal after modifying either side.
    """
    import torch
    import torch.utils.dlpack

    tensor = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))
    warp_array = wp.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))

    def assert_same_contents():
        assert_np_equal(warp_array.numpy(), tensor.cpu().numpy())

    # the strides are compared in bytes, so scale the tensor's by the item size
    item_size = wp.types.type_size_in_bytes(warp_array.dtype)

    test.assertEqual(warp_array.ptr, tensor.data_ptr())
    test.assertEqual(warp_array.device, wp.device_from_torch(tensor.device))
    test.assertEqual(warp_array.dtype, wp.dtype_from_torch(tensor.dtype))
    test.assertEqual(warp_array.shape, tuple(tensor.shape))
    test.assertEqual(warp_array.strides, tuple(step * item_size for step in tensor.stride()))

    assert_same_contents()

    # modify through Warp
    wp.launch(inc, dim=warp_array.size, inputs=[warp_array], device=device)
    assert_same_contents()

    # modify through PyTorch, in place
    tensor += 1
    assert_same_contents()
330
+
331
+
332
def test_dlpack_torch_to_warp_v2(test, device):
    # Same scenario as the original Torch -> Warp test, but exercising the
    # newer __dlpack__() protocol: the tensor is handed to Warp as-is.

    import torch

    tensor = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))

    # pass tensor directly
    warp_arr = wp.from_dlpack(tensor)

    # Torch strides are scaled by the element size before being compared
    # against the strides reported by the Warp array.
    elem_bytes = wp.types.type_size_in_bytes(warp_arr.dtype)
    scaled_strides = tuple(step * elem_bytes for step in tensor.stride())

    # metadata must agree on both sides of the exchange
    test.assertEqual(warp_arr.ptr, tensor.data_ptr())
    test.assertEqual(warp_arr.device, wp.device_from_torch(tensor.device))
    test.assertEqual(warp_arr.dtype, wp.dtype_from_torch(tensor.dtype))
    test.assertEqual(warp_arr.shape, tuple(tensor.shape))
    test.assertEqual(warp_arr.strides, scaled_strides)

    # contents match right after the exchange ...
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # ... after Warp runs the `inc` kernel on the array ...
    wp.launch(inc, dim=warp_arr.size, inputs=[warp_arr], device=device)
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # ... and after Torch updates the tensor in place
    tensor += 1
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())
359
+
360
+
361
def test_dlpack_paddle_to_warp(test, device):
    # Export a Paddle tensor through an explicit DLPack capsule, import it into
    # Warp, and check that metadata and contents stay in agreement.
    import paddle
    import paddle.utils.dlpack

    tensor = paddle.arange(N, dtype=paddle.float32).to(device=wp.device_to_paddle(device))

    # Paddle does not implement __dlpack__ yet, so only to_dlpack is tested here
    capsule = paddle.utils.dlpack.to_dlpack(tensor)
    warp_arr = wp.from_dlpack(capsule)

    # Paddle strides are scaled by the element size before being compared
    # against the strides reported by the Warp array.
    elem_bytes = wp.types.type_size_in_bytes(warp_arr.dtype)
    scaled_strides = tuple(step * elem_bytes for step in tensor.strides)

    # metadata must agree on both sides of the exchange
    test.assertEqual(warp_arr.ptr, tensor.data_ptr())
    test.assertEqual(warp_arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(warp_arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(warp_arr.shape, tuple(tensor.shape))
    test.assertEqual(warp_arr.strides, scaled_strides)

    # contents match right after the exchange ...
    assert_np_equal(warp_arr.numpy(), tensor.numpy())

    # ... after Warp runs the `inc` kernel on the array ...
    wp.launch(inc, dim=warp_arr.size, inputs=[warp_arr], device=device)
    assert_np_equal(warp_arr.numpy(), tensor.numpy())

    # ... and after Paddle writes `tensor + 1` back into the tensor
    paddle.assign(tensor + 1, tensor)
    assert_np_equal(warp_arr.numpy(), tensor.numpy())
387
+
388
+
389
def test_dlpack_warp_to_jax(test, device):
    # Wrap a Warp array as JAX arrays (generic DLPack path and the wp.to_jax
    # helper) and check that pointer, device, shape and contents agree.
    import jax
    import jax.dlpack
    import jax.numpy as jnp

    host = jax.devices("cpu")[0]

    # Create a numpy array from a JAX array to respect XLA alignment needs
    with jax.default_device(host):
        x_jax = jnp.arange(N, dtype=jnp.float32)
        x_numpy = np.asarray(x_jax)
        test.assertEqual(x_jax.unsafe_buffer_pointer(), np.lib.array_utils.byte_bounds(x_numpy)[0])

    a = wp.array(x_numpy, device=device, dtype=wp.float32, copy=False)

    # the pointer is only compared against the NumPy buffer on the CPU device
    if device.is_cpu:
        test.assertEqual(a.ptr, np.lib.array_utils.byte_bounds(x_numpy)[0])

    # use generic dlpack conversion
    j1 = jax.dlpack.from_dlpack(a, copy=False)

    # use jax wrapper
    j2 = wp.to_jax(a)

    # each property is checked for the generic result first, then the wrapper
    for jarr in (j1, j2):
        test.assertEqual(a.ptr, jarr.unsafe_buffer_pointer())
    for jarr in (j1, j2):
        test.assertEqual(a.device, wp.device_from_jax(list(jarr.devices())[0]))
    for jarr in (j1, j2):
        test.assertEqual(a.shape, jarr.shape)
    for jarr in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(jarr))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the arrays as dirty
    # and gets the latest values, which were modified by Warp.
    j1 += 0
    j2 += 0

    for jarr in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(jarr))
433
+
434
+
435
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_warp_to_jax_v2(test, device):
    # same as original test, but uses newer __dlpack__() method
    import jax
    import jax.dlpack
    import jax.numpy as jnp

    host = jax.devices("cpu")[0]

    # Create a numpy array from a JAX array to respect XLA alignment needs
    with jax.default_device(host):
        x_jax = jnp.arange(N, dtype=jnp.float32)
        x_numpy = np.asarray(x_jax)
        test.assertEqual(x_jax.unsafe_buffer_pointer(), np.lib.array_utils.byte_bounds(x_numpy)[0])

    a = wp.array(x_numpy, device=device, dtype=wp.float32, copy=False)

    # the pointer is only compared against the NumPy buffer on the CPU device
    if device.is_cpu:
        test.assertEqual(a.ptr, np.lib.array_utils.byte_bounds(x_numpy)[0])

    # pass warp array directly
    j1 = jax.dlpack.from_dlpack(a, copy=False)

    # use jax wrapper
    j2 = wp.to_jax(a)

    # each property is checked for the direct result first, then the wrapper
    for jarr in (j1, j2):
        test.assertEqual(a.ptr, jarr.unsafe_buffer_pointer())
    for jarr in (j1, j2):
        test.assertEqual(a.device, wp.device_from_jax(list(jarr.devices())[0]))
    for jarr in (j1, j2):
        test.assertEqual(a.shape, jarr.shape)
    for jarr in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(jarr))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the arrays as dirty
    # and gets the latest values, which were modified by Warp.
    j1 += 0
    j2 += 0

    for jarr in (j1, j2):
        assert_np_equal(a.numpy(), np.asarray(jarr))
481
+
482
+
483
def test_dlpack_warp_to_paddle(test, device):
    # Export a Warp array through an explicit DLPack capsule, import it into
    # Paddle, and check that metadata and contents stay in agreement.
    import paddle.utils.dlpack

    warp_arr = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    capsule = wp.to_dlpack(warp_arr)
    tensor = paddle.utils.dlpack.from_dlpack(capsule)

    # Paddle strides are scaled by the element size before being compared
    # against the strides reported by the Warp array.
    elem_bytes = wp.types.type_size_in_bytes(warp_arr.dtype)
    scaled_strides = tuple(step * elem_bytes for step in tensor.strides)

    # metadata must agree on both sides of the exchange
    test.assertEqual(warp_arr.ptr, tensor.data_ptr())
    test.assertEqual(warp_arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(warp_arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(warp_arr.shape, tuple(tensor.shape))
    test.assertEqual(warp_arr.strides, scaled_strides)

    # contents match right after the exchange ...
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # ... after Warp runs the `inc` kernel on the array ...
    wp.launch(inc, dim=warp_arr.size, inputs=[warp_arr], device=device)
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())

    # ... and after Paddle writes `tensor + 1` back into the tensor
    paddle.assign(tensor + 1, tensor)
    assert_np_equal(warp_arr.numpy(), tensor.cpu().numpy())
507
+
508
+
509
def test_dlpack_warp_to_paddle_v2(test, device):
    # Same scenario as the original Warp -> Paddle test, but exercising the
    # newer __dlpack__() protocol: the Warp array is handed to Paddle as-is.

    import paddle.utils.dlpack

    warp_arr = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # pass the array directly
    tensor = paddle.utils.dlpack.from_dlpack(warp_arr)

    # Paddle strides are scaled by the element size before being compared
    # against the strides reported by the Warp array.
    elem_bytes = wp.types.type_size_in_bytes(warp_arr.dtype)
    scaled_strides = tuple(step * elem_bytes for step in tensor.strides)

    # metadata must agree on both sides of the exchange
    test.assertEqual(warp_arr.ptr, tensor.data_ptr())
    test.assertEqual(warp_arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(warp_arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(warp_arr.shape, tuple(tensor.shape))
    test.assertEqual(warp_arr.strides, scaled_strides)

    # contents match right after the exchange ...
    assert_np_equal(warp_arr.numpy(), tensor.numpy())

    # ... after Warp runs the `inc` kernel on the array ...
    wp.launch(inc, dim=warp_arr.size, inputs=[warp_arr], device=device)
    assert_np_equal(warp_arr.numpy(), tensor.numpy())

    # ... and after Paddle writes `tensor + 1` back into the tensor
    paddle.assign(tensor + 1, tensor)
    assert_np_equal(warp_arr.numpy(), tensor.numpy())
536
+
537
+
538
def test_dlpack_jax_to_warp(test, device):
    # Wrap a JAX array as Warp arrays (generic DLPack capsule and the
    # wp.from_jax helper) and check pointer, device, shape and contents.
    import jax
    import jax.dlpack

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(N, dtype=jax.numpy.float32)

    # use generic dlpack conversion
    a1 = wp.from_dlpack(jax.dlpack.to_dlpack(j))

    # use jax wrapper
    a2 = wp.from_jax(j)

    # each property is checked for the generic result first, then the wrapper
    for warp_arr in (a1, a2):
        test.assertEqual(warp_arr.ptr, j.unsafe_buffer_pointer())
    for warp_arr in (a1, a2):
        test.assertEqual(warp_arr.device, wp.device_from_jax(list(j.devices())[0]))
    for warp_arr in (a1, a2):
        test.assertEqual(warp_arr.shape, j.shape)
    for warp_arr in (a1, a2):
        assert_np_equal(warp_arr.numpy(), np.asarray(j))

    # only a1 is passed to the kernel; a2 is still compared afterwards
    wp.launch(inc, dim=a1.size, inputs=[a1], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the array as dirty
    # and gets the latest values, which were modified by Warp.
    j += 0

    for warp_arr in (a1, a2):
        assert_np_equal(warp_arr.numpy(), np.asarray(j))
570
+
571
+
572
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_jax_to_warp_v2(test, device):
    # same as original test, but uses newer __dlpack__() method

    import jax

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(N, dtype=jax.numpy.float32)

    # pass jax array directly
    a1 = wp.from_dlpack(j)

    # use jax wrapper
    a2 = wp.from_jax(j)

    # each property is checked for the direct result first, then the wrapper
    for warp_arr in (a1, a2):
        test.assertEqual(warp_arr.ptr, j.unsafe_buffer_pointer())
    for warp_arr in (a1, a2):
        test.assertEqual(warp_arr.device, wp.device_from_jax(list(j.devices())[0]))
    for warp_arr in (a1, a2):
        test.assertEqual(warp_arr.shape, j.shape)
    for warp_arr in (a1, a2):
        assert_np_equal(warp_arr.numpy(), np.asarray(j))

    # only a1 is passed to the kernel; a2 is still compared afterwards
    wp.launch(inc, dim=a1.size, inputs=[a1], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the array as dirty
    # and gets the latest values, which were modified by Warp.
    j += 0

    for warp_arr in (a1, a2):
        assert_np_equal(warp_arr.numpy(), np.asarray(j))
606
+
607
+
608
class TestDLPack(unittest.TestCase):
    """Empty test case; its test methods are attached below via add_function_test."""


devices = get_test_devices()

# framework-independent tests are registered for every test device
add_function_test(
    TestDLPack,
    "test_dlpack_warp_to_warp",
    test_dlpack_warp_to_warp,
    devices=devices,
)
add_function_test(
    TestDLPack,
    "test_dlpack_dtypes_and_shapes",
    test_dlpack_dtypes_and_shapes,
    devices=devices,
)
add_function_test(
    TestDLPack,
    "test_dlpack_stream_arg",
    test_dlpack_stream_arg,
    devices=devices,
)
617
+
618
# torch interop via dlpack
try:
    import torch
    import torch.utils.dlpack

    # check which Warp devices work with Torch
    # CUDA devices may fail if Torch was not compiled with CUDA support
    test_devices = get_test_devices()
    torch_compatible_devices = []
    for d in test_devices:
        # smoke test: create a small tensor on the device and update it in place
        try:
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
        except Exception as e:
            print(f"Skipping Torch DLPack tests on device '{d}' due to exception: {e}")
        else:
            torch_compatible_devices.append(d)

    # register the Torch tests only when at least one device passed the probe
    if torch_compatible_devices:
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_torch",
            test_dlpack_warp_to_torch,
            devices=torch_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_torch_v2",
            test_dlpack_warp_to_torch_v2,
            devices=torch_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_torch_to_warp",
            test_dlpack_torch_to_warp,
            devices=torch_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_torch_to_warp_v2",
            test_dlpack_torch_to_warp_v2,
            devices=torch_compatible_devices,
        )

except Exception as e:
    # best effort: any failure here (e.g. the import) only skips the Torch tests
    print(f"Skipping Torch DLPack tests due to exception: {e}")
651
+
652
# jax interop via dlpack
try:
    # prevent Jax from gobbling up GPU memory
    # (both variables are set before jax is imported)
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    import jax
    import jax.dlpack

    # check which Warp devices work with Jax
    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
    test_devices = get_test_devices()
    jax_compatible_devices = []
    for d in test_devices:
        # smoke test: create a small array on the device and add to it
        try:
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
        except Exception as e:
            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
        else:
            jax_compatible_devices.append(d)

    # register the Jax tests only when at least one device passed the probe
    if jax_compatible_devices:
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_jax",
            test_dlpack_warp_to_jax,
            devices=jax_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_jax_v2",
            test_dlpack_warp_to_jax_v2,
            devices=jax_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_jax_to_warp",
            test_dlpack_jax_to_warp,
            devices=jax_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_jax_to_warp_v2",
            test_dlpack_jax_to_warp_v2,
            devices=jax_compatible_devices,
        )

except Exception as e:
    # best effort: any failure here (e.g. the import) only skips the Jax tests
    print(f"Skipping Jax DLPack tests due to exception: {e}")
690
+
691
+
692
# paddle interop via dlpack
try:
    import paddle
    import paddle.utils.dlpack

    # check which Warp devices work with paddle
    # CUDA devices may fail if paddle was not compiled with CUDA support
    test_devices = get_test_devices()
    paddle_compatible_devices = []
    for d in test_devices:
        # smoke test: create a small tensor on the device and write t + 1 back into it
        try:
            t = paddle.arange(10).to(device=wp.device_to_paddle(d))
            paddle.assign(t + 1, t)
        except Exception as e:
            print(f"Skipping paddle DLPack tests on device '{d}' due to exception: {e}")
        else:
            paddle_compatible_devices.append(d)

    # register the Paddle tests only when at least one device passed the probe
    if paddle_compatible_devices:
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_paddle",
            test_dlpack_warp_to_paddle,
            devices=paddle_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_warp_to_paddle_v2",
            test_dlpack_warp_to_paddle_v2,
            devices=paddle_compatible_devices,
        )
        add_function_test(
            TestDLPack,
            "test_dlpack_paddle_to_warp",
            test_dlpack_paddle_to_warp,
            devices=paddle_compatible_devices,
        )

except Exception as e:
    # best effort: any failure here (e.g. the import) only skips the Paddle tests
    print(f"Skipping Paddle DLPack tests due to exception: {e}")
725
+
726
+
727
if __name__ == "__main__":
    # Script entry point: clear Warp's kernel cache, then run the tests
    # with verbose unittest output.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)