warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,730 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+ import os
18
+ import unittest
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ from warp.tests.unittest_utils import *
24
+
25
# Element count used for the large test arrays: 2**20 == 1,048,576.
N = 2**20
26
+
27
+
28
+ def _jax_version():
29
+ try:
30
+ import jax
31
+
32
+ return jax.__version_info__
33
+ except (ImportError, AttributeError):
34
+ return (0, 0, 0)
35
+
36
+
37
@wp.kernel
def inc(a: wp.array(dtype=float)):
    # Each thread adds 1.0 to the element at its own thread index
    # (a plain read-modify-write, not an atomic add).
    i = wp.tid()
    a[i] = 1.0 + a[i]
41
+
42
+
43
def test_dlpack_warp_to_warp(test, device):
    """Round-trip a Warp array through DLPack and verify the result aliases it.

    The imported array must share pointer, device, dtype, shape and strides
    with the original, and a kernel launched on the imported array must be
    visible through the original.
    """
    a1 = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    a2 = wp.from_dlpack(wp.to_dlpack(a1))

    test.assertEqual(a1.ptr, a2.ptr)
    test.assertEqual(a1.device, a2.device)
    test.assertEqual(a1.dtype, a2.dtype)
    test.assertEqual(a1.shape, a2.shape)
    test.assertEqual(a1.strides, a2.strides)

    assert_np_equal(a1.numpy(), a2.numpy())

    wp.launch(inc, dim=a2.size, inputs=[a2], device=device)

    # a1 and a2 share memory, so comparing them to each other can never fail.
    # Compare against the expected values instead, to confirm the kernel really
    # wrote through the imported view and the write is visible via the original.
    # (Values up to 2**20 + 1 are exactly representable in float32.)
    expected = np.arange(N, dtype=np.float32) + 1.0
    assert_np_equal(a1.numpy(), expected)
    assert_np_equal(a2.numpy(), expected)
59
+
60
+
61
def test_dlpack_dtypes_and_shapes(test, device):
    """Round-trip many dtypes through DLPack and check the imported views.

    Covers implicit and explicit scalar dtypes, signed/unsigned
    reinterpretation of equal-width integers, and conversions between
    vector/matrix arrays and scalar arrays with extra trailing axes.
    """

    def check_same_memory(src, dst):
        # Both arrays must view the same allocation on the same device.
        test.assertEqual(src.ptr, dst.ptr)
        test.assertEqual(src.device, dst.device)

    # Scalar dtype inferred by from_dlpack.
    def wrap_scalar_tensor_implicit(dtype):
        src = wp.zeros(N, dtype=dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src))

        check_same_memory(src, dst)
        test.assertEqual(src.dtype, dst.dtype)
        test.assertEqual(src.shape, dst.shape)
        test.assertEqual(src.strides, dst.strides)

    # Scalar dtype forced on import (may differ from the source dtype).
    def wrap_scalar_tensor_explicit(dtype, target_dtype):
        src = wp.zeros(N, dtype=dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=target_dtype)

        check_same_memory(src, dst)
        test.assertEqual(src.dtype, dtype)
        test.assertEqual(dst.dtype, target_dtype)
        test.assertEqual(src.shape, dst.shape)
        test.assertEqual(src.strides, dst.strides)

    # Vector array viewed as scalars: one extra trailing axis of length
    # `_length_`, whose stride is the size of one component.
    def wrap_vector_to_scalar_tensor(vec_dtype):
        scalar_type = vec_dtype._wp_scalar_type_
        component_bytes = ctypes.sizeof(vec_dtype._type_)

        src = wp.zeros(N, dtype=vec_dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=scalar_type)

        check_same_memory(src, dst)
        test.assertEqual(dst.ndim, src.ndim + 1)
        test.assertEqual(src.dtype, vec_dtype)
        test.assertEqual(dst.dtype, scalar_type)
        test.assertEqual(dst.shape, (*src.shape, vec_dtype._length_))
        test.assertEqual(dst.strides, (*src.strides, component_bytes))

    # (N, _length_) scalar array viewed as an array of vectors.
    def wrap_scalar_to_vector_tensor(vec_dtype):
        scalar_type = vec_dtype._wp_scalar_type_
        component_bytes = ctypes.sizeof(vec_dtype._type_)

        src = wp.zeros((N, vec_dtype._length_), dtype=scalar_type, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=vec_dtype)

        check_same_memory(src, dst)
        test.assertEqual(dst.ndim, src.ndim - 1)
        test.assertEqual(src.dtype, scalar_type)
        test.assertEqual(dst.dtype, vec_dtype)
        test.assertEqual(src.shape, (*dst.shape, vec_dtype._length_))
        test.assertEqual(src.strides, (*dst.strides, component_bytes))

    # Matrix array viewed as scalars: two extra trailing axes (rows, cols),
    # laid out row-major.
    def wrap_matrix_to_scalar_tensor(mat_dtype):
        scalar_type = mat_dtype._wp_scalar_type_
        component_bytes = ctypes.sizeof(mat_dtype._type_)

        src = wp.zeros(N, dtype=mat_dtype, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=scalar_type)

        row_bytes = component_bytes * mat_dtype._shape_[1]
        check_same_memory(src, dst)
        test.assertEqual(dst.ndim, src.ndim + 2)
        test.assertEqual(src.dtype, mat_dtype)
        test.assertEqual(dst.dtype, scalar_type)
        test.assertEqual(dst.shape, (*src.shape, *mat_dtype._shape_))
        test.assertEqual(dst.strides, (*src.strides, row_bytes, component_bytes))

    # (N, rows, cols) scalar array viewed as an array of matrices.
    def wrap_scalar_to_matrix_tensor(mat_dtype):
        scalar_type = mat_dtype._wp_scalar_type_
        component_bytes = ctypes.sizeof(mat_dtype._type_)

        src = wp.zeros((N, *mat_dtype._shape_), dtype=scalar_type, device=device)
        dst = wp.from_dlpack(wp.to_dlpack(src), dtype=mat_dtype)

        row_bytes = component_bytes * mat_dtype._shape_[1]
        check_same_memory(src, dst)
        test.assertEqual(dst.ndim, src.ndim - 2)
        test.assertEqual(src.dtype, scalar_type)
        test.assertEqual(dst.dtype, mat_dtype)
        test.assertEqual(src.shape, (*dst.shape, *mat_dtype._shape_))
        test.assertEqual(src.strides, (*dst.strides, row_bytes, component_bytes))

    all_scalars = wp._src.types.scalar_types

    for scalar in all_scalars:
        wrap_scalar_tensor_implicit(scalar)

    for scalar in all_scalars:
        wrap_scalar_tensor_explicit(scalar, scalar)

    # Reinterpret between signed and unsigned integers of equal width.
    for signed, unsigned in (
        (wp.int8, wp.uint8),
        (wp.int16, wp.uint16),
        (wp.int32, wp.uint32),
        (wp.int64, wp.uint64),
    ):
        wrap_scalar_tensor_explicit(signed, unsigned)
        wrap_scalar_tensor_explicit(unsigned, signed)

    vec_types = [wp._src.types.vector(length, scalar) for scalar in all_scalars for length in (2, 3, 4, 5)]
    vec_types += [
        wp.quath,
        wp.quatf,
        wp.quatd,
        wp.transformh,
        wp.transformf,
        wp.transformd,
        wp.spatial_vectorh,
        wp.spatial_vectorf,
        wp.spatial_vectord,
    ]

    for vec_type in vec_types:
        wrap_vector_to_scalar_tensor(vec_type)
        wrap_scalar_to_vector_tensor(vec_type)

    mat_shapes = ((2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3))
    mat_types = [wp._src.types.matrix(shape, scalar) for scalar in all_scalars for shape in mat_shapes]
    mat_types += [wp.spatial_matrixh, wp.spatial_matrixf, wp.spatial_matrixd]

    for mat_type in mat_types:
        wrap_matrix_to_scalar_tensor(mat_type)
        wrap_scalar_to_matrix_tensor(mat_type)
197
+
198
+
199
def test_dlpack_stream_arg(test, device):
    """Check which ``stream`` values ``array.__dlpack__()`` accepts or rejects.

    ``None`` and ``-1`` (and omitting the argument) are accepted on every
    device. CUDA devices also accept 0, 1, 2 and a real stream handle. CPU
    devices reject other values such as 0, 1, 2 and 1742. Rejections must
    raise ``TypeError`` with a device-specific message.
    """
    data = np.arange(10)

    def check_result(capsule):
        # Re-import the capsule into Warp and make sure the contents are intact.
        imported = wp._src.dlpack._from_dlpack(capsule)
        assert_np_equal(imported.numpy(), data)

    with wp.ScopedDevice(device):
        arr = wp.array(data=data)

        # Accepted on all devices.
        check_result(arr.__dlpack__())
        for stream in (None, -1):
            check_result(arr.__dlpack__(stream=stream))

        if device.is_cuda:
            # 0: default stream, 1: legacy default stream, 2: per-thread default stream.
            for stream in (0, 1, 2):
                check_result(arr.__dlpack__(stream=stream))

            # Handle of a user-created stream.
            custom_stream = wp.Stream(device)
            check_result(arr.__dlpack__(stream=custom_stream.cuda_stream))

            rejected_streams = (-2, "nope")
            # NOTE(review): this pattern must match the message raised by
            # __dlpack__ (it reads "must None", without "be"); confirm against
            # the library before changing the wording.
            expected_error = r"DLPack stream must None or an integer >= -1"
        else:
            rejected_streams = (0, 1, 2, 1742, -2, "nope")
            expected_error = r"DLPack stream must be None or -1 for CPU device"

        for stream in rejected_streams:
            with test.assertRaisesRegex(TypeError, expected_error):
                check_result(arr.__dlpack__(stream=stream))
248
+
249
+
250
def test_dlpack_warp_to_torch(test, device):
    """Export a Warp array to PyTorch via ``wp.to_dlpack`` and verify aliasing.

    The tensor must share pointer, device, dtype, shape and strides with the
    Warp array. Writes made on either side (a Warp kernel launch, then an
    in-place PyTorch add) must be visible through both views and must
    produce the expected values.
    """
    import torch.utils.dlpack

    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    t = torch.utils.dlpack.from_dlpack(wp.to_dlpack(a))

    item_size = wp._src.types.type_size_in_bytes(a.dtype)

    test.assertEqual(a.ptr, t.data_ptr())
    test.assertEqual(a.device, wp.device_from_torch(t.device))
    test.assertEqual(a.dtype, wp.dtype_from_torch(t.dtype))
    test.assertEqual(a.shape, tuple(t.shape))
    # Warp strides are in bytes; PyTorch strides are in elements.
    test.assertEqual(a.strides, tuple(s * item_size for s in t.stride()))

    def check_values(expected):
        # The two views alias the same memory, so comparing them to each other
        # alone cannot detect a write that never happened; also pin the values.
        assert_np_equal(a.numpy(), t.cpu().numpy())
        assert_np_equal(a.numpy(), expected)

    expected = np.arange(N, dtype=np.float32)
    check_values(expected)

    # Write through Warp.
    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    expected = expected + 1.0
    check_values(expected)

    # Write through PyTorch.
    t += 1
    expected = expected + 1.0
    check_values(expected)
274
+
275
+
276
def test_dlpack_warp_to_torch_v2(test, device):
    """Variant of ``test_dlpack_warp_to_torch`` using the ``__dlpack__()`` protocol.

    The Warp array object itself is handed to ``torch.utils.dlpack.from_dlpack``
    instead of a pre-built capsule.
    """
    import torch.utils.dlpack

    source = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # no explicit capsule: the array is passed as-is
    view = torch.utils.dlpack.from_dlpack(source)

    itemsize = wp._src.types.type_size_in_bytes(source.dtype)
    expected_strides = tuple(itemsize * s for s in view.stride())

    test.assertEqual(source.ptr, view.data_ptr())
    test.assertEqual(source.device, wp.device_from_torch(view.device))
    test.assertEqual(source.dtype, wp.dtype_from_torch(view.dtype))
    test.assertEqual(source.shape, tuple(view.shape))
    test.assertEqual(source.strides, expected_strides)

    # initial contents agree
    assert_np_equal(source.numpy(), view.cpu().numpy())

    # Warp-side increment is observed by Torch
    wp.launch(inc, dim=source.size, inputs=[source], device=device)
    assert_np_equal(source.numpy(), view.cpu().numpy())

    # Torch-side in-place increment is observed by Warp
    view += 1
    assert_np_equal(source.numpy(), view.cpu().numpy())
303
+
304
+
305
def test_dlpack_torch_to_warp(test, device):
    """Import a Torch tensor into Warp via an explicit ``torch.utils.dlpack.to_dlpack`` capsule.

    Checks that the Warp array aliases the tensor's memory and that updates from
    either side are visible to the other.
    """
    import torch
    import torch.utils.dlpack

    tensor = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))
    imported = wp.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))

    nbytes = wp._src.types.type_size_in_bytes(imported.dtype)

    test.assertEqual(imported.ptr, tensor.data_ptr())
    test.assertEqual(imported.device, wp.device_from_torch(tensor.device))
    test.assertEqual(imported.dtype, wp.dtype_from_torch(tensor.dtype))
    test.assertEqual(imported.shape, tuple(tensor.shape))
    test.assertEqual(imported.strides, tuple(s * nbytes for s in tensor.stride()))

    def compare():
        assert_np_equal(imported.numpy(), tensor.cpu().numpy())

    compare()

    wp.launch(inc, dim=imported.size, inputs=[imported], device=device)
    compare()

    tensor += 1
    compare()
330
+
331
+
332
def test_dlpack_torch_to_warp_v2(test, device):
    """Variant of ``test_dlpack_torch_to_warp`` that passes the tensor object directly.

    ``wp.from_dlpack`` receives the Torch tensor itself rather than a capsule,
    so the newer ``__dlpack__()`` protocol is exercised.
    """
    import torch

    torch_device = wp.device_to_torch(device)
    with torch.device(torch_device):
        tensor = torch.arange(N, dtype=torch.float32)

    # hand the tensor over without building a capsule first
    arr = wp.from_dlpack(tensor)

    bytes_per_item = wp._src.types.type_size_in_bytes(arr.dtype)
    torch_byte_strides = tuple(step * bytes_per_item for step in tensor.stride())

    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_torch(tensor.device))
    test.assertEqual(arr.dtype, wp.dtype_from_torch(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    test.assertEqual(arr.strides, torch_byte_strides)

    assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    # modification through Warp
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    # modification through Torch (in place)
    tensor += 1
    assert_np_equal(arr.numpy(), tensor.cpu().numpy())
360
+
361
+
362
def test_dlpack_paddle_to_warp(test, device):
    """Import a Paddle tensor into Warp through ``paddle.utils.dlpack.to_dlpack``.

    Checks that the Warp array aliases the Paddle tensor and that writes from
    either framework are visible in the other.
    """
    import paddle
    import paddle.utils.dlpack

    paddle_place = wp.device_to_paddle(device)
    tensor = paddle.arange(N, dtype=paddle.float32).to(device=paddle_place)

    # paddle do not implement __dlpack__ yet, so only test to_dlpack here
    capsule_arr = wp.from_dlpack(paddle.utils.dlpack.to_dlpack(tensor))

    elem_size = wp._src.types.type_size_in_bytes(capsule_arr.dtype)

    test.assertEqual(capsule_arr.ptr, tensor.data_ptr())
    test.assertEqual(capsule_arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(capsule_arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(capsule_arr.shape, tuple(tensor.shape))
    test.assertEqual(capsule_arr.strides, tuple(st * elem_size for st in tensor.strides))

    assert_np_equal(capsule_arr.numpy(), tensor.numpy())

    wp.launch(inc, dim=capsule_arr.size, inputs=[capsule_arr], device=device)
    assert_np_equal(capsule_arr.numpy(), tensor.numpy())

    # write back into the same Paddle tensor (no rebinding)
    paddle.assign(tensor + 1, tensor)
    assert_np_equal(capsule_arr.numpy(), tensor.numpy())
388
+
389
+
390
def test_dlpack_warp_to_jax(test, device):
    """Export a Warp array (wrapping JAX-allocated host memory) back to JAX.

    The Warp array is converted in two ways: the generic
    ``jax.dlpack.from_dlpack`` entry point and the ``wp.to_jax`` wrapper. Both
    results must report the Warp pointer, device and shape, and must reflect
    modifications made by a Warp kernel.
    """
    import jax
    import jax.dlpack
    import jax.numpy as jnp

    # ``byte_bounds`` lives in ``np.lib.array_utils`` since NumPy 2.0 (and was
    # removed from the top-level namespace); NumPy 1.x only has ``np.byte_bounds``.
    try:
        byte_bounds = np.lib.array_utils.byte_bounds
    except AttributeError:
        byte_bounds = np.byte_bounds

    cpu_device = jax.devices("cpu")[0]

    # Create a numpy array from a JAX array to respect XLA alignment needs
    with jax.default_device(cpu_device):
        x_jax = jnp.arange(N, dtype=jnp.float32)
        x_numpy = np.asarray(x_jax)
        test.assertEqual(x_jax.unsafe_buffer_pointer(), byte_bounds(x_numpy)[0])

    a = wp.array(x_numpy, device=device, dtype=wp.float32, copy=False)

    if device.is_cpu:
        # with copy=False on the CPU the Warp array must alias the NumPy buffer
        test.assertEqual(a.ptr, byte_bounds(x_numpy)[0])

    # generic jax.dlpack entry point (the array is passed directly)
    j1 = jax.dlpack.from_dlpack(a, copy=False)

    # use jax wrapper
    j2 = wp.to_jax(a)

    test.assertEqual(a.ptr, j1.unsafe_buffer_pointer())
    test.assertEqual(a.ptr, j2.unsafe_buffer_pointer())
    test.assertEqual(a.device, wp.device_from_jax(next(iter(j1.devices()))))
    test.assertEqual(a.device, wp.device_from_jax(next(iter(j2.devices()))))
    test.assertEqual(a.shape, j1.shape)
    test.assertEqual(a.shape, j2.shape)

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the arrays as dirty
    # and gets the latest values, which were modified by Warp.
    j1 += 0
    j2 += 0

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))
434
+
435
+
436
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_warp_to_jax_v2(test, device):
    """Same as ``test_dlpack_warp_to_jax``, exercising the ``__dlpack__()`` protocol.

    The Warp array object is passed straight to ``jax.dlpack.from_dlpack`` and
    also converted with ``wp.to_jax``. Both JAX views must alias the Warp buffer.
    """
    import jax
    import jax.dlpack
    import jax.numpy as jnp

    # ``byte_bounds`` moved to ``np.lib.array_utils`` in NumPy 2.0 and no longer
    # exists at the top level there; NumPy 1.x only provides ``np.byte_bounds``.
    try:
        byte_bounds = np.lib.array_utils.byte_bounds
    except AttributeError:
        byte_bounds = np.byte_bounds

    cpu_device = jax.devices("cpu")[0]

    # Create a numpy array from a JAX array to respect XLA alignment needs
    with jax.default_device(cpu_device):
        x_jax = jnp.arange(N, dtype=jnp.float32)
        x_numpy = np.asarray(x_jax)
        test.assertEqual(x_jax.unsafe_buffer_pointer(), byte_bounds(x_numpy)[0])

    a = wp.array(x_numpy, device=device, dtype=wp.float32, copy=False)

    if device.is_cpu:
        # with copy=False on the CPU the Warp array must alias the NumPy buffer
        test.assertEqual(a.ptr, byte_bounds(x_numpy)[0])

    # pass warp array directly
    j1 = jax.dlpack.from_dlpack(a, copy=False)

    # use jax wrapper
    j2 = wp.to_jax(a)

    test.assertEqual(a.ptr, j1.unsafe_buffer_pointer())
    test.assertEqual(a.ptr, j2.unsafe_buffer_pointer())
    test.assertEqual(a.device, wp.device_from_jax(next(iter(j1.devices()))))
    test.assertEqual(a.device, wp.device_from_jax(next(iter(j2.devices()))))
    test.assertEqual(a.shape, j1.shape)
    test.assertEqual(a.shape, j2.shape)

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the arrays as dirty
    # and gets the latest values, which were modified by Warp.
    j1 += 0
    j2 += 0

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))
482
+
483
+
484
def test_dlpack_warp_to_paddle(test, device):
    """Export a Warp array to Paddle through an explicit ``wp.to_dlpack`` capsule.

    The Paddle tensor must alias the Warp allocation. Writes from Warp (kernel
    launch) and Paddle (``paddle.assign``) must be visible on both sides.
    """
    import paddle.utils.dlpack

    warp_arr = wp.array(data=np.arange(N, dtype=np.float32), device=device)
    capsule = wp.to_dlpack(warp_arr)
    pd_tensor = paddle.utils.dlpack.from_dlpack(capsule)

    size_of_item = wp._src.types.type_size_in_bytes(warp_arr.dtype)

    test.assertEqual(warp_arr.ptr, pd_tensor.data_ptr())
    test.assertEqual(warp_arr.device, wp.device_from_paddle(pd_tensor.place))
    test.assertEqual(warp_arr.dtype, wp.dtype_from_paddle(pd_tensor.dtype))
    test.assertEqual(warp_arr.shape, tuple(pd_tensor.shape))
    test.assertEqual(warp_arr.strides, tuple(s * size_of_item for s in pd_tensor.strides))

    def verify():
        assert_np_equal(warp_arr.numpy(), pd_tensor.cpu().numpy())

    verify()

    wp.launch(inc, dim=warp_arr.size, inputs=[warp_arr], device=device)
    verify()

    paddle.assign(pd_tensor + 1, pd_tensor)
    verify()
508
+
509
+
510
def test_dlpack_warp_to_paddle_v2(test, device):
    """Variant of ``test_dlpack_warp_to_paddle`` that hands Paddle the Warp array itself.

    No capsule is created explicitly; the Warp array object is passed to
    ``paddle.utils.dlpack.from_dlpack``.
    """
    import paddle.utils.dlpack

    host_data = np.arange(N, dtype=np.float32)
    arr = wp.array(data=host_data, device=device)

    # pass the array directly
    tensor = paddle.utils.dlpack.from_dlpack(arr)

    width = wp._src.types.type_size_in_bytes(arr.dtype)
    paddle_strides_in_bytes = tuple(s * width for s in tensor.strides)

    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_paddle(tensor.place))
    test.assertEqual(arr.dtype, wp.dtype_from_paddle(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    test.assertEqual(arr.strides, paddle_strides_in_bytes)

    assert_np_equal(arr.numpy(), tensor.numpy())

    # kernel write, observed from Paddle
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    assert_np_equal(arr.numpy(), tensor.numpy())

    # Paddle write into the same storage, observed from Warp
    paddle.assign(tensor + 1, tensor)
    assert_np_equal(arr.numpy(), tensor.numpy())
537
+
538
+
539
def test_dlpack_jax_to_warp(test, device):
    """Import a JAX array into Warp and check that the result aliases the JAX buffer.

    Two import routes are compared: the generic ``wp.from_dlpack`` and the
    JAX-specific ``wp.from_jax`` wrapper.
    """
    import jax
    import jax.dlpack

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(N, dtype=jax.numpy.float32)

    via_dlpack = wp.from_dlpack(j)  # generic dlpack conversion
    via_wrapper = wp.from_jax(j)  # jax wrapper
    imported = (via_dlpack, via_wrapper)

    buffer_ptr = j.unsafe_buffer_pointer()
    expected_device = wp.device_from_jax(next(iter(j.devices())))

    for arr in imported:
        test.assertEqual(arr.ptr, buffer_ptr)
    for arr in imported:
        test.assertEqual(arr.device, expected_device)
    for arr in imported:
        test.assertEqual(arr.shape, j.shape)

    for arr in imported:
        assert_np_equal(arr.numpy(), np.asarray(j))

    wp.launch(inc, dim=via_dlpack.size, inputs=[via_dlpack], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the array as dirty
    # and gets the latest values, which were modified by Warp.
    j = j + 0

    for arr in imported:
        assert_np_equal(arr.numpy(), np.asarray(j))
571
+
572
+
573
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_jax_to_warp_v2(test, device):
    """Variant of ``test_dlpack_jax_to_warp`` relying on the ``__dlpack__()`` protocol.

    The JAX array object is passed directly to ``wp.from_dlpack`` and also to
    ``wp.from_jax``. Both Warp arrays must share the JAX buffer.
    """
    import jax

    jax_device = wp.device_to_jax(device)
    with jax.default_device(jax_device):
        src = jax.numpy.arange(N, dtype=jax.numpy.float32)

    # pass jax array directly
    first = wp.from_dlpack(src)

    # use jax wrapper
    second = wp.from_jax(src)

    src_device = wp.device_from_jax(next(iter(src.devices())))

    test.assertEqual(first.ptr, src.unsafe_buffer_pointer())
    test.assertEqual(second.ptr, src.unsafe_buffer_pointer())
    test.assertEqual(first.device, src_device)
    test.assertEqual(second.device, src_device)
    test.assertEqual(first.shape, src.shape)
    test.assertEqual(second.shape, src.shape)

    assert_np_equal(first.numpy(), np.asarray(src))
    assert_np_equal(second.numpy(), np.asarray(src))

    wp.launch(inc, dim=first.size, inputs=[first], device=device)
    wp.synchronize_device(device)

    # HACK? Run a no-op operation so that Jax flags the array as dirty
    # and gets the latest values, which were modified by Warp.
    src = src + 0

    assert_np_equal(first.numpy(), np.asarray(src))
    assert_np_equal(second.numpy(), np.asarray(src))
607
+
608
+
609
class TestDLPack(unittest.TestCase):
    """Container for the DLPack interop tests; methods are attached via ``add_function_test``."""
611
+
612
+
613
devices = get_test_devices()

# Framework-independent DLPack tests run on every available Warp test device.
for _name, _func in (
    ("test_dlpack_warp_to_warp", test_dlpack_warp_to_warp),
    ("test_dlpack_dtypes_and_shapes", test_dlpack_dtypes_and_shapes),
    ("test_dlpack_stream_arg", test_dlpack_stream_arg),
):
    add_function_test(TestDLPack, _name, _func, devices=devices)
del _name, _func
618
+
619
# torch interop via dlpack
# The whole section is guarded so that a missing or broken Torch installation
# only skips these tests (with a printed notice) instead of failing the import.
try:
    import torch
    import torch.utils.dlpack

    # Probe every Warp test device with a tiny Torch operation; CUDA devices may
    # fail here if Torch was not compiled with CUDA support.
    test_devices = get_test_devices()
    torch_compatible_devices = []
    for d in test_devices:
        try:
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
        except Exception as e:
            print(f"Skipping Torch DLPack tests on device '{d}' due to exception: {e}")
        else:
            torch_compatible_devices.append(d)

    if torch_compatible_devices:
        for _name, _func in (
            ("test_dlpack_warp_to_torch", test_dlpack_warp_to_torch),
            ("test_dlpack_warp_to_torch_v2", test_dlpack_warp_to_torch_v2),
            ("test_dlpack_torch_to_warp", test_dlpack_torch_to_warp),
            ("test_dlpack_torch_to_warp_v2", test_dlpack_torch_to_warp_v2),
        ):
            add_function_test(TestDLPack, _name, _func, devices=torch_compatible_devices)
        del _name, _func

except Exception as e:
    print(f"Skipping Torch DLPack tests due to exception: {e}")
652
+
653
# jax interop via dlpack
# Guarded like the Torch section: any failure here just skips the JAX tests.
try:
    # prevent Jax from gobbling up GPU memory (must be set before importing jax)
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    import jax
    import jax.dlpack

    # Probe each Warp test device with a small JAX computation; CUDA devices may
    # fail if JAX cannot find a CUDA Toolkit.
    test_devices = get_test_devices()
    jax_compatible_devices = []
    for d in test_devices:
        try:
            with jax.default_device(wp.device_to_jax(d)):
                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
                j += 1
        except Exception as e:
            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
        else:
            jax_compatible_devices.append(d)

    if jax_compatible_devices:
        for _name, _func in (
            ("test_dlpack_warp_to_jax", test_dlpack_warp_to_jax),
            ("test_dlpack_warp_to_jax_v2", test_dlpack_warp_to_jax_v2),
            ("test_dlpack_jax_to_warp", test_dlpack_jax_to_warp),
            ("test_dlpack_jax_to_warp_v2", test_dlpack_jax_to_warp_v2),
        ):
            add_function_test(TestDLPack, _name, _func, devices=jax_compatible_devices)
        del _name, _func

except Exception as e:
    print(f"Skipping Jax DLPack tests due to exception: {e}")
691
+
692
+
693
# paddle interop via dlpack
# Guarded so that a missing or CPU-only Paddle build only skips these tests.
try:
    import paddle
    import paddle.utils.dlpack

    # Probe each Warp test device; CUDA devices may fail if Paddle was not
    # compiled with CUDA support.
    test_devices = get_test_devices()
    paddle_compatible_devices = []
    for d in test_devices:
        try:
            t = paddle.arange(10).to(device=wp.device_to_paddle(d))
            paddle.assign(t + 1, t)
        except Exception as e:
            print(f"Skipping paddle DLPack tests on device '{d}' due to exception: {e}")
        else:
            paddle_compatible_devices.append(d)

    if paddle_compatible_devices:
        for _name, _func in (
            ("test_dlpack_warp_to_paddle", test_dlpack_warp_to_paddle),
            ("test_dlpack_warp_to_paddle_v2", test_dlpack_warp_to_paddle_v2),
            ("test_dlpack_paddle_to_warp", test_dlpack_paddle_to_warp),
        ):
            add_function_test(TestDLPack, _name, _func, devices=paddle_compatible_devices)
        del _name, _func

except Exception as e:
    print(f"Skipping Paddle DLPack tests due to exception: {e}")
726
+
727
+
728
if __name__ == "__main__":
    # Script entry point: clear Warp's kernel cache before running, then hand
    # control to the unittest runner with per-test (verbose) output.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)