warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,124 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import multiprocessing as mp
17
+ import unittest
18
+
19
+ import warp as wp
20
+ from warp.tests.unittest_utils import *
21
+
22
+
23
def test_ipc_get_memory_handle(test, device):
    """Check that the IPC handle of a freshly created device array is not all zeros.

    The test is skipped when ``device.is_ipc_supported`` is explicitly ``False``.
    """
    ipc_unsupported = device.is_ipc_supported is False
    if ipc_unsupported:
        test.skipTest(f"IPC is not supported on {device}")

    # The array is created inside a ``wp.ScopedMempool(device, False)`` scope.
    with wp.ScopedMempool(device, False):
        shared_array = wp.full(10, value=42.0, dtype=wp.float32, device=device)
        handle = shared_array.ipc_handle()

    # ``bytes(64)`` is 64 zero bytes, which is what an unset handle would compare equal to.
    zeroed_handle = bytes(64)
    test.assertNotEqual(handle, zeroed_handle, "IPC memory handle appears to be invalid")
32
+
33
+
34
def test_ipc_get_event_handle(test, device):
    """Check that an event created with ``interprocess=True`` yields a non-zero IPC handle.

    The test is skipped when ``device.is_ipc_supported`` is explicitly ``False``.
    """
    ipc_unsupported = device.is_ipc_supported is False
    if ipc_unsupported:
        test.skipTest(f"IPC is not supported on {device}")

    shared_event = wp.Event(device, interprocess=True)
    handle = shared_event.ipc_handle()

    # ``bytes(64)`` is 64 zero bytes, which is what an unset handle would compare equal to.
    zeroed_handle = bytes(64)
    test.assertNotEqual(handle, zeroed_handle, "IPC event handle appears to be invalid")
43
+
44
+
45
def test_ipc_event_missing_interprocess_flag(test, device):
    """Check that requesting an IPC handle from a non-interprocess event prints a warning.

    An event is created with ``interprocess=False`` and its ``ipc_handle()`` is
    requested while stdout is captured; the captured text must contain the
    Warp warning about an invalid IPC event handle (not checked on Windows).

    The test is skipped when ``device.is_ipc_supported`` is explicitly ``False``.
    """
    if device.is_ipc_supported is False:
        test.skipTest(f"IPC is not supported on {device}")

    e1 = wp.Event(device, interprocess=False)

    # Start capturing *before* entering the try block: the finally clause may
    # only call end() on a capture object that exists and has actually begun.
    capture = StdOutCapture()
    capture.begin()
    try:
        # Only the printed warning matters here; the returned handle is not used.
        e1.ipc_handle()
    finally:
        output = capture.end()

    # Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
    if sys.platform != "win32":
        test.assertRegex(output, r"Warp UserWarning: IPC event handle appears to be invalid.")
61
+
62
+
63
# Kernel: each thread doubles, in place, the element of ``a`` at its own thread index.
@wp.kernel
def multiply_by_two(a: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    a[tid] = a[tid] * 2.0
67
+
68
+
69
def child_task(array_handle, dtype, shape, device, event_handle):
    """Process target: open the shared array and event, then double the array in place.

    Args:
        array_handle: IPC handle passed to ``wp.from_ipc_handle``.
        dtype: Element type passed to ``wp.from_ipc_handle``.
        shape: Array shape passed to ``wp.from_ipc_handle``.
        device: Device used for the ``wp.ScopedDevice`` scope and for opening both handles.
        event_handle: IPC handle passed to ``wp.event_from_ipc_handle``.
    """
    with wp.ScopedDevice(device):
        shared_array = wp.from_ipc_handle(array_handle, dtype, shape, device=device)
        shared_event = wp.event_from_ipc_handle(event_handle, device=device)
        current_stream = wp.get_stream()

        wp.launch(multiply_by_two, shared_array.shape, inputs=[shared_array])

        # Record the shared event after the launch and wait on it, then
        # synchronize the device before the process returns.
        current_stream.record_event(shared_event)
        current_stream.wait_event(shared_event)
        wp.synchronize_device()
78
+
79
+
80
def test_ipc_multiprocess_write(test, device):
    """Check that a spawned process can write to an array shared through an IPC handle.

    The array is filled with 42.0 and doubled once by this process; the child
    process (see ``child_task``) opens the array from its IPC handle and
    doubles it again, so every element is expected to equal 168.0 afterwards.

    The test is skipped when ``device.is_ipc_supported`` is explicitly ``False``.
    """
    if device.is_ipc_supported is False:
        test.skipTest(f"IPC is not supported on {device}")

    e1 = wp.Event(device, interprocess=True)

    with wp.ScopedMempool(device, False):
        test_array = wp.full(1024, value=42.0, dtype=wp.float32, device=device)
        ipc_handle = test_array.ipc_handle()

    wp.launch(multiply_by_two, test_array.shape, inputs=[test_array], device=device)

    # Make sure the first doubling has finished before another process touches
    # the same memory; nothing else orders the two kernel launches.
    wp.synchronize_device(device)

    ctx = mp.get_context("spawn")

    process = ctx.Process(
        target=child_task, args=(ipc_handle, test_array.dtype, test_array.shape, str(device), e1.ipc_handle())
    )

    process.start()
    process.join()

    # Report a crashed child directly instead of only as a value mismatch below.
    test.assertEqual(process.exitcode, 0, "IPC child process did not exit cleanly")

    assert_np_equal(test_array.numpy(), np.full(test_array.shape, 168.0, dtype=np.float32))
103
+
104
+
105
# Every IPC test below is registered for these devices only.
cuda_devices = get_cuda_test_devices()


class TestIpc(unittest.TestCase):
    # Left empty on purpose; the add_function_test() calls below are given this class to register against.
    pass


add_function_test(
    TestIpc,
    "test_ipc_get_memory_handle",
    test_ipc_get_memory_handle,
    devices=cuda_devices,
)
add_function_test(
    TestIpc,
    "test_ipc_get_event_handle",
    test_ipc_get_event_handle,
    devices=cuda_devices,
)
add_function_test(
    TestIpc,
    "test_ipc_event_missing_interprocess_flag",
    test_ipc_event_missing_interprocess_flag,
    devices=cuda_devices,
)
add_function_test(
    TestIpc,
    "test_ipc_multiprocess_write",
    test_ipc_multiprocess_write,
    devices=cuda_devices,
    check_output=False,
)


# Running the file directly clears the kernel cache before starting the unittest runner.
if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,233 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.tests.unittest_utils import *
20
+
21
+
22
def get_device_pair_with_mempool_access_support():
    """Return the first ``(target_device, peer_device)`` pair with mempool access support.

    Pairs of distinct CUDA devices are examined in the order produced by
    ``wp.get_cuda_devices()``; ``None`` is returned when no pair qualifies.
    """
    cuda_devices = wp.get_cuda_devices()
    supported_pairs = (
        (target, peer)
        for target in cuda_devices
        for peer in cuda_devices
        if target != peer and wp.is_mempool_access_supported(target, peer)
    )
    return next(supported_pairs, None)
30
+
31
+
32
def get_device_pair_without_mempool_access_support():
    """Return the first ``(target_device, peer_device)`` pair lacking mempool access support.

    Pairs of distinct CUDA devices are examined in the order produced by
    ``wp.get_cuda_devices()``; ``None`` is returned when every pair is supported
    (or there are no pairs at all).
    """
    cuda_devices = wp.get_cuda_devices()
    unsupported_pairs = (
        (target, peer)
        for target in cuda_devices
        for peer in cuda_devices
        if target != peer and not wp.is_mempool_access_supported(target, peer)
    )
    return next(unsupported_pairs, None)
40
+
41
+
42
def test_mempool_release_threshold(test, device):
    """Exercise toggling the mempool and setting its release threshold on a supporting device.

    The enabled flag is flipped and restored, then the release threshold is set
    to an absolute value (42000) and to a fraction (0.5, expected to read back
    as ``int(0.5 * device.total_memory)``) before being restored. Both settings
    are restored in ``finally`` clauses so that a failing assertion does not
    leave the device in a modified state.
    """
    device = wp.get_device(device)

    # This test is only meaningful for devices that report mempool support.
    test.assertTrue(device.is_mempool_supported)

    test.assertEqual(wp.is_mempool_supported(device), device.is_mempool_supported)

    was_enabled = wp.is_mempool_enabled(device)

    try:
        # toggle
        wp.set_mempool_enabled(device, not was_enabled)
        test.assertEqual(wp.is_mempool_enabled(device), not was_enabled)
    finally:
        # restore
        wp.set_mempool_enabled(device, was_enabled)
    test.assertEqual(wp.is_mempool_enabled(device), was_enabled)

    saved_threshold = wp.get_mempool_release_threshold(device)

    try:
        # set new absolute threshold
        wp.set_mempool_release_threshold(device, 42000)
        test.assertEqual(wp.get_mempool_release_threshold(device), 42000)

        # set new fractional threshold
        wp.set_mempool_release_threshold(device, 0.5)
        test.assertEqual(wp.get_mempool_release_threshold(device), int(0.5 * device.total_memory))
    finally:
        # restore threshold
        wp.set_mempool_release_threshold(device, saved_threshold)
    test.assertEqual(wp.get_mempool_release_threshold(device), saved_threshold)
72
+
73
+
74
def test_mempool_usage_queries(test, device):
    """Exercise the mempool usage queries around one allocation and its release."""

    device = wp.get_device(device)

    # Baseline readings taken before anything is allocated.
    used_before = wp.get_mempool_used_mem_current(device)
    high_before = wp.get_mempool_used_mem_high(device)

    # 262144 float32 elements occupy 1048576 bytes (1 MiB).
    buffer = wp.empty(262144, dtype=wp.float32, device=device)
    wp.synchronize_device(device)

    # Readings while the allocation is alive.
    used_after_alloc = wp.get_mempool_used_mem_current(device)
    high_after_alloc = wp.get_mempool_used_mem_high(device)

    test.assertEqual(used_after_alloc, used_before + 1048576, "Memory usage did not increase by 1 MiB")
    test.assertGreaterEqual(high_after_alloc, 1048576, "High-water mark is not at least 1 MiB")

    # Drop the only reference to the array and synchronize before re-reading.
    del buffer
    wp.synchronize_device(device)

    # Readings after the allocation has been released.
    used_after_free = wp.get_mempool_used_mem_current(device)
    high_after_free = wp.get_mempool_used_mem_high(device)

    test.assertEqual(
        used_after_free,
        used_before,
        "Test didn't end with the same amount of used memory as the test started with.",
    )
    test.assertEqual(high_after_free, high_after_alloc, "High-water mark should not change after free")
110
+
111
+
112
def test_mempool_exceptions(test, device):
    """Check that release-threshold calls raise on a device without mempool support.

    ``RuntimeError`` is expected for CUDA devices and ``ValueError`` otherwise.
    """
    device = wp.get_device(device)

    assert not device.is_mempool_supported

    expected_error = RuntimeError if device.is_cuda else ValueError

    # Reading the threshold must fail ...
    with test.assertRaises(expected_error):
        wp.get_mempool_release_threshold(device)

    # ... and so must writing it.
    with test.assertRaises(expected_error):
        wp.set_mempool_release_threshold(device, 42000)
127
+
128
+
129
def test_mempool_access_self(test, device):
    """Check that a device always reports mempool access to itself as enabled."""
    device = wp.get_device(device)

    assert device.is_mempool_supported

    # Per the original test, setting access from a device to itself is a no-op,
    # so neither call should change what the query below reports.
    for requested_state in (True, False):
        wp.set_mempool_access_enabled(device, device, requested_state)

    # Self-access is expected to be reported as enabled regardless.
    test.assertTrue(wp.is_mempool_access_enabled(device, device))
141
+
142
+
143
@unittest.skipUnless(get_device_pair_with_mempool_access_support(), "Requires devices with mempool access support")
def test_mempool_access(test, _):
    """Flip mempool access for a supported device pair away from its initial state and back."""
    target_device, peer_device = get_device_pair_with_mempool_access_support()

    was_enabled = wp.is_mempool_access_enabled(target_device, peer_device)

    # First move to the opposite of the starting state, then return to it.
    requested_states = (False, True) if was_enabled else (True, False)

    for requested in requested_states:
        wp.set_mempool_access_enabled(target_device, peer_device, requested)
        is_enabled = wp.is_mempool_access_enabled(target_device, peer_device)
        if requested:
            test.assertTrue(is_enabled)
        else:
            test.assertFalse(is_enabled)
169
+
170
+
171
@unittest.skipUnless(
    get_device_pair_without_mempool_access_support(), "Requires devices without mempool access support"
)
def test_mempool_access_exceptions_unsupported(test, _):
    """Check mempool access calls for a CUDA device pair without mempool access support."""
    target_device, peer_device = get_device_pair_without_mempool_access_support()

    # A query is allowed and has to report that access is not enabled.
    access_enabled = wp.is_mempool_access_enabled(target_device, peer_device)
    test.assertFalse(access_enabled)

    # Turning access on is expected to fail with RuntimeError.
    with test.assertRaises(RuntimeError):
        wp.set_mempool_access_enabled(target_device, peer_device, True)

    # Turning access off is expected to go through without an exception.
    wp.set_mempool_access_enabled(target_device, peer_device, False)
187
+
188
+
189
@unittest.skipUnless(wp.is_cpu_available() and wp.is_cuda_available(), "Requires both CUDA and CPU devices")
def test_mempool_access_exceptions_cpu(test, _):
    """Check mempool access calls between ``"cuda:0"`` and ``"cpu"`` in both directions."""
    query_order = (("cuda:0", "cpu"), ("cpu", "cuda:0"))
    change_order = (("cpu", "cuda:0"), ("cuda:0", "cpu"))

    # A query is allowed and has to report that access is not enabled.
    for target, peer in query_order:
        test.assertFalse(wp.is_mempool_access_enabled(target, peer))

    # Turning access on is expected to fail with ValueError.
    for target, peer in change_order:
        with test.assertRaises(ValueError):
            wp.set_mempool_access_enabled(target, peer, True)

    # Turning access off is expected to go through without an exception.
    for target, peer in change_order:
        wp.set_mempool_access_enabled(target, peer, False)
204
+
205
+
206
class TestMempool(unittest.TestCase):
    # Left empty on purpose; the add_function_test() calls below are given this class to register against.
    pass


# Split the test devices by whether they report mempool support.
devices_with_mempools = [dev for dev in get_test_devices() if dev.is_mempool_supported]
devices_without_mempools = [dev for dev in get_test_devices() if not dev.is_mempool_supported]

# Tests registered for devices that report mempool support.
add_function_test(
    TestMempool,
    "test_mempool_release_threshold",
    test_mempool_release_threshold,
    devices=devices_with_mempools,
)
add_function_test(
    TestMempool,
    "test_mempool_usage_queries",
    test_mempool_usage_queries,
    devices=devices_with_mempools,
)
add_function_test(
    TestMempool,
    "test_mempool_access_self",
    test_mempool_access_self,
    devices=devices_with_mempools,
)

# Tests registered for devices that do not report mempool support.
add_function_test(
    TestMempool,
    "test_mempool_exceptions",
    test_mempool_exceptions,
    devices=devices_without_mempools,
)

# Mempool access between a device pair (registered without a device list).
add_function_test(TestMempool, "test_mempool_access", test_mempool_access)

# Error behaviour of the mempool access calls (registered without a device list).
add_function_test(
    TestMempool,
    "test_mempool_access_exceptions_unsupported",
    test_mempool_access_exceptions_unsupported,
)
add_function_test(
    TestMempool,
    "test_mempool_access_exceptions_cpu",
    test_mempool_access_exceptions_cpu,
)


# Running the file directly clears the kernel cache before starting the unittest runner.
if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,169 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+ from warp.utils import check_p2p
23
+
24
+
25
@wp.kernel
def inc(a: wp.array(dtype=float)):
    # Each thread adds 1.0 to the element at its own thread index.
    i = wp.tid()
    a[i] = a[i] + 1.0
29
+
30
+
31
@wp.kernel
def arange(start: int, step: int, a: wp.array(dtype=int)):
    # Each thread writes start + step * index into the element at its own thread index.
    i = wp.tid()
    a[i] = start + step * i
35
+
36
+
37
class TestMultiGPU(unittest.TestCase):
    """Tests that use two CUDA devices, addressed as "cuda:0" and "cuda:1".

    Every test is skipped unless more than one CUDA device is reported by
    ``wp.get_cuda_devices()``. The ping-pong tests additionally require
    ``check_p2p()`` to succeed.
    """

    @unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
    def test_multigpu_set_device(self):
        """Arrays created after ``wp.set_device`` belong to the selected device."""
        # save default device
        saved_device = wp.get_device()

        n = 32

        try:
            wp.set_device("cuda:0")
            a0 = wp.empty(n, dtype=int)
            wp.launch(arange, dim=a0.size, inputs=[0, 1, a0])

            wp.set_device("cuda:1")
            a1 = wp.empty(n, dtype=int)
            wp.launch(arange, dim=a1.size, inputs=[0, 1, a1])
        finally:
            # restore default device even when an allocation or launch fails,
            # so a failure here cannot change the device seen by later tests
            wp.set_device(saved_device)

        self.assertEqual(a0.device, "cuda:0")
        self.assertEqual(a1.device, "cuda:1")

        expected = np.arange(n, dtype=int)

        assert_np_equal(a0.numpy(), expected)
        assert_np_equal(a1.numpy(), expected)

    @unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
    def test_multigpu_scoped_device(self):
        """Arrays created inside ``wp.ScopedDevice`` belong to the scoped device."""
        n = 32

        with wp.ScopedDevice("cuda:0"):
            a0 = wp.empty(n, dtype=int)
            wp.launch(arange, dim=a0.size, inputs=[0, 1, a0])

        with wp.ScopedDevice("cuda:1"):
            a1 = wp.empty(n, dtype=int)
            wp.launch(arange, dim=a1.size, inputs=[0, 1, a1])

        self.assertEqual(a0.device, "cuda:0")
        self.assertEqual(a1.device, "cuda:1")

        expected = np.arange(n, dtype=int)

        assert_np_equal(a0.numpy(), expected)
        assert_np_equal(a1.numpy(), expected)

    @unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
    def test_multigpu_nesting(self):
        """Nested ``wp.ScopedDevice`` blocks restore the previous device on exit.

        Also checks that a ``wp.set_device`` call made inside a scope is
        undone when that scope exits.
        """
        initial_device = wp.get_device()
        initial_cuda_device = wp.get_cuda_device()

        with wp.ScopedDevice("cuda:1"):
            self.assertEqual(wp.get_device(), "cuda:1")
            self.assertEqual(wp.get_cuda_device(), "cuda:1")

            with wp.ScopedDevice("cuda:0"):
                self.assertEqual(wp.get_device(), "cuda:0")
                self.assertEqual(wp.get_cuda_device(), "cuda:0")

                with wp.ScopedDevice("cpu"):
                    # entering a CPU scope leaves the current CUDA device unchanged
                    self.assertEqual(wp.get_device(), "cpu")
                    self.assertEqual(wp.get_cuda_device(), "cuda:0")

                    wp.set_device("cuda:1")

                    self.assertEqual(wp.get_device(), "cuda:1")
                    self.assertEqual(wp.get_cuda_device(), "cuda:1")

                self.assertEqual(wp.get_device(), "cuda:0")
                self.assertEqual(wp.get_cuda_device(), "cuda:0")

            self.assertEqual(wp.get_device(), "cuda:1")
            self.assertEqual(wp.get_cuda_device(), "cuda:1")

        self.assertEqual(wp.get_device(), initial_device)
        self.assertEqual(wp.get_cuda_device(), initial_cuda_device)

    @unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
    @unittest.skipUnless(check_p2p(), "Peer-to-Peer transfers not supported")
    def test_multigpu_pingpong(self):
        """Bounce an array between two devices, incrementing it on each side."""
        n = 1024 * 1024

        a0 = wp.zeros(n, dtype=float, device="cuda:0")
        a1 = wp.zeros(n, dtype=float, device="cuda:1")

        iters = 10

        for _ in range(iters):
            wp.launch(inc, dim=a0.size, inputs=[a0], device=a0.device)
            wp.synchronize_device(a0.device)
            wp.copy(a1, a0)

            wp.launch(inc, dim=a1.size, inputs=[a1], device=a1.device)
            wp.synchronize_device(a1.device)
            wp.copy(a0, a1)

        # two increments per iteration: one on each device
        expected = np.full(n, iters * 2, dtype=np.float32)

        assert_np_equal(a0.numpy(), expected)
        assert_np_equal(a1.numpy(), expected)

    @unittest.skipUnless(len(wp.get_cuda_devices()) > 1, "Requires at least two CUDA devices")
    @unittest.skipUnless(check_p2p(), "Peer-to-Peer transfers not supported")
    def test_multigpu_pingpong_streams(self):
        """Same ping-pong as above, ordered with stream waits instead of device syncs."""
        n = 1024 * 1024

        a0 = wp.zeros(n, dtype=float, device="cuda:0")
        a1 = wp.zeros(n, dtype=float, device="cuda:1")

        stream0 = wp.get_stream("cuda:0")
        stream1 = wp.get_stream("cuda:1")

        iters = 10

        for _ in range(iters):
            wp.launch(inc, dim=a0.size, inputs=[a0], stream=stream0)
            stream1.wait_stream(stream0)
            wp.copy(a1, a0, stream=stream1)

            wp.launch(inc, dim=a1.size, inputs=[a1], stream=stream1)
            stream0.wait_stream(stream1)
            wp.copy(a0, a1, stream=stream0)

        # two increments per iteration: one on each device
        expected = np.full(n, iters * 2, dtype=np.float32)

        assert_np_equal(a0.numpy(), expected)
        assert_np_equal(a1.numpy(), expected)
165
+
166
+
167
# Script entry point: clear the kernel cache, then run this module's tests verbosely.
if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)
@@ -0,0 +1,139 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.tests.unittest_utils import *
20
+
21
+
22
def get_device_pair_with_peer_access_support():
    """Return the first ``(target, peer)`` pair of distinct CUDA devices with peer access support.

    Pairs are examined in nested-loop order over ``wp.get_cuda_devices()``.
    Returns ``None`` when no such pair exists.
    """
    cuda_devices = wp.get_cuda_devices()
    # Lazy generator: the support query stops at the first matching pair.
    supported_pairs = (
        (target, peer)
        for target in cuda_devices
        for peer in cuda_devices
        if target != peer and wp.is_peer_access_supported(target, peer)
    )
    return next(supported_pairs, None)
30
+
31
+
32
def get_device_pair_without_peer_access_support():
    """Return the first ``(target, peer)`` pair of distinct CUDA devices lacking peer access support.

    Pairs are examined in nested-loop order over ``wp.get_cuda_devices()``.
    Returns ``None`` when every distinct pair supports peer access or when
    there are fewer than two devices.
    """
    cuda_devices = wp.get_cuda_devices()
    # Lazy generator: the support query stops at the first matching pair.
    unsupported_pairs = (
        (target, peer)
        for target in cuda_devices
        for peer in cuda_devices
        if target != peer and not wp.is_peer_access_supported(target, peer)
    )
    return next(unsupported_pairs, None)
40
+
41
+
42
def test_peer_access_self(test, device):
    """Check that a CUDA device reports peer access to itself as supported and enabled."""
    device = wp.get_device(device)

    assert device.is_cuda

    # device can access self
    test.assertTrue(wp.is_peer_access_supported(device, device))

    # setting peer access to self is a no-op: enable, then disable
    for flag in (True, False):
        wp.set_peer_access_enabled(device, device, flag)

    # should always be enabled
    test.assertTrue(wp.is_peer_access_enabled(device, device))
58
+
59
+
60
@unittest.skipUnless(get_device_pair_with_peer_access_support(), "Requires devices with peer access support")
def test_peer_access(test, _):
    """Toggle peer access on a supported device pair and verify each change is reported.

    The setting is first flipped away from its initial state and then flipped
    back, so a passing run leaves peer access as it found it.
    """
    target_device, peer_device = get_device_pair_with_peer_access_support()

    initially_enabled = bool(wp.is_peer_access_enabled(target_device, peer_device))

    # first the opposite of the initial state, then the initial state again
    for desired in (not initially_enabled, initially_enabled):
        wp.set_peer_access_enabled(target_device, peer_device, desired)
        now_enabled = wp.is_peer_access_enabled(target_device, peer_device)
        if desired:
            test.assertTrue(now_enabled)
        else:
            test.assertFalse(now_enabled)
86
+
87
+
88
@unittest.skipUnless(get_device_pair_without_peer_access_support(), "Requires devices without peer access support")
def test_peer_access_exceptions_unsupported(test, _):
    """Check query/enable/disable behavior on a CUDA device pair without peer access support."""
    # get a CUDA device pair without peer access support
    pair = get_device_pair_without_peer_access_support()
    target_device, peer_device = pair

    # querying is ok, but must return False
    currently_enabled = wp.is_peer_access_enabled(target_device, peer_device)
    test.assertFalse(currently_enabled)

    # enabling should raise RuntimeError
    with test.assertRaises(RuntimeError):
        wp.set_peer_access_enabled(target_device, peer_device, True)

    # disabling should not raise an error
    wp.set_peer_access_enabled(target_device, peer_device, False)
102
+
103
+
104
@unittest.skipUnless(wp.is_cpu_available() and wp.is_cuda_available(), "Requires both CUDA and CPU devices")
def test_peer_access_exceptions_cpu(test, _):
    """Check peer-access queries and toggles for CPU/"cuda:0" device pairs.

    Queries must report ``False``, enabling must raise ``ValueError`` and
    disabling must be accepted silently.
    """
    gpu, cpu = "cuda:0", "cpu"

    # querying is ok, but must return False
    for target, peer in ((gpu, cpu), (cpu, gpu)):
        test.assertFalse(wp.is_peer_access_enabled(target, peer))

    # enabling should raise ValueError
    for target, peer in ((cpu, gpu), (gpu, cpu)):
        with test.assertRaises(ValueError):
            wp.set_peer_access_enabled(target, peer, True)

    # disabling should not raise an error
    for target, peer in ((cpu, gpu), (gpu, cpu)):
        wp.set_peer_access_enabled(target, peer, False)
119
+
120
+
121
class TestPeer(unittest.TestCase):
    """Empty test case; its test methods are attached below via ``add_function_test``."""
123
+
124
+
125
cuda_test_devices = get_cuda_test_devices()

# the self-access test is registered with the CUDA test devices
add_function_test(TestPeer, "test_peer_access_self", test_peer_access_self, devices=cuda_test_devices)

# peer access tests, followed by the peer access exception tests;
# these are registered without an explicit device list
for _name, _func in (
    ("test_peer_access", test_peer_access),
    ("test_peer_access_exceptions_unsupported", test_peer_access_exceptions_unsupported),
    ("test_peer_access_exceptions_cpu", test_peer_access_exceptions_cpu),
):
    add_function_test(TestPeer, _name, _func)
135
+
136
+
137
# Script entry point: clear the kernel cache, then run this module's tests verbosely.
if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)