warp_lang-1.10.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
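The largest addition in this release is the new JAX interop test suite: the hunk below (+1673 lines, matching entry 340 above) is the start of warp/tests/interop/test_jax.py. As a quick orientation to the API it exercises, here is a minimal sketch of wrapping a Warp kernel with jax_kernel from warp.jax_experimental.ffi so it can be called under jax.jit. The kernel and array size are illustrative, not taken from the release, and a working JAX + CUDA setup is assumed; as the tests below show, the wrapped callable returns its outputs as a tuple, and num_outputs defaults to a single output.

import jax
import jax.numpy as jp
import numpy as np

import warp as wp
from warp.jax_experimental.ffi import jax_kernel


@wp.kernel
def triple(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 3.0 * x[tid]


# wrap the Warp kernel as a JAX-compatible callable
jax_triple = jax_kernel(triple)


@jax.jit
def f(x):
    return jax_triple(x)  # outputs come back as a tuple, hence (y,) below


x = jp.arange(16, dtype=jp.float32)
(y,) = f(x)
np.testing.assert_allclose(np.asarray(y), 3.0 * np.arange(16, dtype=np.float32))

The same pattern runs throughout the file, alongside jax_callable for wrapping Python functions that launch several kernels.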
@@ -0,0 +1,1673 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import unittest
18
+ from functools import partial
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+
23
+ import warp as wp
24
+ from warp._src.jax import get_jax_device
25
+ from warp.tests.unittest_utils import *
26
+
27
+ # default array size for tests
28
+ ARRAY_SIZE = 1024 * 1024
29
+
30
+
31
+ # basic kernel with one input and output
32
+ @wp.kernel
33
+ def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
34
+ tid = wp.tid()
35
+ output[tid] = 3.0 * input[tid]
36
+
37
+
38
+ # generic kernel with one scalar input and output
39
+ @wp.kernel
40
+ def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
41
+ tid = wp.tid()
42
+ output[tid] = input.dtype(3) * input[tid]
43
+
44
+
45
+ # generic kernel with one vector/matrix input and output
46
+ @wp.kernel
47
+ def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
48
+ tid = wp.tid()
49
+ output[tid] = input.dtype.dtype(3) * input[tid]
50
+
51
+
52
+ @wp.kernel
53
+ def inc_1d_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
54
+ tid = wp.tid()
55
+ y[tid] = x[tid] + 1.0
56
+
57
+
58
+ @wp.kernel
59
+ def inc_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
60
+ i, j = wp.tid()
61
+ y[i, j] = x[i, j] + 1.0
62
+
63
+
64
+ # kernel with multiple inputs and outputs
65
+ @wp.kernel
66
+ def multiarg_kernel(
67
+ # inputs
68
+ a: wp.array(dtype=float),
69
+ b: wp.array(dtype=float),
70
+ c: wp.array(dtype=float),
71
+ # outputs
72
+ ab: wp.array(dtype=float),
73
+ bc: wp.array(dtype=float),
74
+ ):
75
+ tid = wp.tid()
76
+ ab[tid] = a[tid] + b[tid]
77
+ bc[tid] = b[tid] + c[tid]
78
+
79
+
80
+ # various types for testing
81
+ scalar_types = wp._src.types.scalar_types
82
+ vector_types = []
83
+ matrix_types = []
84
+ for dim in [2, 3, 4]:
85
+ for T in scalar_types:
86
+ vector_types.append(wp.vec(dim, T))
87
+ matrix_types.append(wp.mat((dim, dim), T))
88
+
89
+ # explicitly overload generic kernels to avoid module reloading during tests
90
+ for T in scalar_types:
91
+ wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
92
+ for T in [*vector_types, *matrix_types]:
93
+ wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
94
+
95
+
96
+ def _jax_version():
97
+ try:
98
+ import jax
99
+
100
+ return jax.__version_info__
101
+ except ImportError:
102
+ return (0, 0, 0)
103
+
104
+
105
+ def test_dtype_from_jax(test, device):
106
+ import jax.numpy as jp
107
+
108
+ def test_conversions(jax_type, warp_type):
109
+ test.assertEqual(wp.dtype_from_jax(jax_type), warp_type)
110
+ test.assertEqual(wp.dtype_from_jax(jp.dtype(jax_type)), warp_type)
111
+
112
+ test_conversions(jp.float16, wp.float16)
113
+ test_conversions(jp.float32, wp.float32)
114
+ test_conversions(jp.float64, wp.float64)
115
+ test_conversions(jp.int8, wp.int8)
116
+ test_conversions(jp.int16, wp.int16)
117
+ test_conversions(jp.int32, wp.int32)
118
+ test_conversions(jp.int64, wp.int64)
119
+ test_conversions(jp.uint8, wp.uint8)
120
+ test_conversions(jp.uint16, wp.uint16)
121
+ test_conversions(jp.uint32, wp.uint32)
122
+ test_conversions(jp.uint64, wp.uint64)
123
+ test_conversions(jp.bool_, wp.bool)
124
+
125
+
126
+ def test_dtype_to_jax(test, device):
127
+ import jax.numpy as jp
128
+
129
+ def test_conversions(warp_type, jax_type):
130
+ test.assertEqual(wp.dtype_to_jax(warp_type), jax_type)
131
+
132
+ test_conversions(wp.float16, jp.float16)
133
+ test_conversions(wp.float32, jp.float32)
134
+ test_conversions(wp.float64, jp.float64)
135
+ test_conversions(wp.int8, jp.int8)
136
+ test_conversions(wp.int16, jp.int16)
137
+ test_conversions(wp.int32, jp.int32)
138
+ test_conversions(wp.int64, jp.int64)
139
+ test_conversions(wp.uint8, jp.uint8)
140
+ test_conversions(wp.uint16, jp.uint16)
141
+ test_conversions(wp.uint32, jp.uint32)
142
+ test_conversions(wp.uint64, jp.uint64)
143
+ test_conversions(wp.bool, jp.bool_)
144
+
145
+
146
+ def test_device_conversion(test, device):
147
+ jax_device = wp.device_to_jax(device)
148
+ warp_device = wp.device_from_jax(jax_device)
149
+ test.assertEqual(warp_device, device)
150
+
151
+
152
+ def test_jax_kernel_basic(test, device, use_ffi=False):
153
+ import jax.numpy as jp
154
+
155
+ if use_ffi:
156
+ from warp.jax_experimental.ffi import jax_kernel
157
+
158
+ jax_triple = jax_kernel(triple_kernel)
159
+ else:
160
+ from warp.jax_experimental.custom_call import jax_kernel
161
+
162
+ jax_triple = jax_kernel(triple_kernel, quiet=True) # suppress deprecation warnings
163
+
164
+ n = ARRAY_SIZE
165
+
166
+ @jax.jit
167
+ def f():
168
+ x = jp.arange(n, dtype=jp.float32)
169
+ return jax_triple(x)
170
+
171
+ # run on the given device
172
+ with jax.default_device(wp.device_to_jax(device)):
173
+ y = f()
174
+
175
+ wp.synchronize_device(device)
176
+
177
+ result = np.asarray(y).reshape((n,))
178
+ expected = 3 * np.arange(n, dtype=np.float32)
179
+
180
+ assert_np_equal(result, expected)
181
+
182
+
183
+ def test_jax_kernel_scalar(test, device, use_ffi=False):
184
+ import jax.numpy as jp
185
+
186
+ if use_ffi:
187
+ from warp.jax_experimental.ffi import jax_kernel
188
+
189
+ kwargs = {}
190
+ else:
191
+ from warp.jax_experimental.custom_call import jax_kernel
192
+
193
+ kwargs = {"quiet": True}
194
+
195
+ # use a smallish size to ensure arange * 3 doesn't overflow
196
+ n = 64
197
+
198
+ for T in scalar_types:
199
+ jp_dtype = wp.dtype_to_jax(T)
200
+ np_dtype = wp.dtype_to_numpy(T)
201
+
202
+ with test.subTest(msg=T.__name__):
203
+ # get the concrete overload
204
+ kernel_instance = triple_kernel_scalar.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
205
+
206
+ jax_triple = jax_kernel(kernel_instance, **kwargs)
207
+
208
+ @jax.jit
209
+ def f(jax_triple=jax_triple, jp_dtype=jp_dtype):
210
+ x = jp.arange(n, dtype=jp_dtype)
211
+ return jax_triple(x)
212
+
213
+ # run on the given device
214
+ with jax.default_device(wp.device_to_jax(device)):
215
+ y = f()
216
+
217
+ wp.synchronize_device(device)
218
+
219
+ result = np.asarray(y).reshape((n,))
220
+ expected = 3 * np.arange(n, dtype=np_dtype)
221
+
222
+ assert_np_equal(result, expected)
223
+
224
+
225
+ def test_jax_kernel_vecmat(test, device, use_ffi=False):
226
+ import jax.numpy as jp
227
+
228
+ if use_ffi:
229
+ from warp.jax_experimental.ffi import jax_kernel
230
+
231
+ kwargs = {}
232
+ else:
233
+ from warp.jax_experimental.custom_call import jax_kernel
234
+
235
+ kwargs = {"quiet": True}
236
+
237
+ for T in [*vector_types, *matrix_types]:
238
+ jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
239
+ np_dtype = wp.dtype_to_numpy(T._wp_scalar_type_)
240
+
241
+ # use a smallish size to ensure arange * 3 doesn't overflow
242
+ n = 64 // T._length_
243
+ scalar_shape = (n, *T._shape_)
244
+ scalar_len = n * T._length_
245
+
246
+ with test.subTest(msg=T.__name__):
247
+ # get the concrete overload
248
+ kernel_instance = triple_kernel_vecmat.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
249
+
250
+ jax_triple = jax_kernel(kernel_instance, **kwargs)
251
+
252
+ @jax.jit
253
+ def f(jax_triple=jax_triple, jp_dtype=jp_dtype, scalar_len=scalar_len, scalar_shape=scalar_shape):
254
+ x = jp.arange(scalar_len, dtype=jp_dtype).reshape(scalar_shape)
255
+ return jax_triple(x)
256
+
257
+ # run on the given device
258
+ with jax.default_device(wp.device_to_jax(device)):
259
+ y = f()
260
+
261
+ wp.synchronize_device(device)
262
+
263
+ result = np.asarray(y).reshape(scalar_shape)
264
+ expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
265
+
266
+ assert_np_equal(result, expected)
267
+
268
+
269
+ def test_jax_kernel_multiarg(test, device, use_ffi=False):
270
+ import jax.numpy as jp
271
+
272
+ if use_ffi:
273
+ from warp.jax_experimental.ffi import jax_kernel
274
+
275
+ jax_multiarg = jax_kernel(multiarg_kernel, num_outputs=2)
276
+ else:
277
+ from warp.jax_experimental.custom_call import jax_kernel
278
+
279
+ jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)
280
+
281
+ n = ARRAY_SIZE
282
+
283
+ @jax.jit
284
+ def f():
285
+ a = jp.full(n, 1, dtype=jp.float32)
286
+ b = jp.full(n, 2, dtype=jp.float32)
287
+ c = jp.full(n, 3, dtype=jp.float32)
288
+ return jax_multiarg(a, b, c)
289
+
290
+ # run on the given device
291
+ with jax.default_device(wp.device_to_jax(device)):
292
+ x, y = f()
293
+
294
+ wp.synchronize_device(device)
295
+
296
+ result_x, result_y = np.asarray(x), np.asarray(y)
297
+ expected_x = np.full(n, 3, dtype=np.float32)
298
+ expected_y = np.full(n, 5, dtype=np.float32)
299
+
300
+ assert_np_equal(result_x, expected_x)
301
+ assert_np_equal(result_y, expected_y)
302
+
303
+
304
+ def test_jax_kernel_launch_dims(test, device, use_ffi=False):
305
+ import jax.numpy as jp
306
+
307
+ if use_ffi:
308
+ from warp.jax_experimental.ffi import jax_kernel
309
+
310
+ kwargs = {}
311
+ else:
312
+ from warp.jax_experimental.custom_call import jax_kernel
313
+
314
+ kwargs = {"quiet": True}
315
+
316
+ n = 64
317
+ m = 32
318
+
319
+ # Test with 1D launch dims
320
+ jax_inc_1d = jax_kernel(
321
+ inc_1d_kernel, launch_dims=(n - 2,), **kwargs
322
+ ) # Intentionally not the same as the first dimension of the input
323
+
324
+ @jax.jit
325
+ def f_1d():
326
+ x = jp.arange(n, dtype=jp.float32)
327
+ return jax_inc_1d(x)
328
+
329
+ # Test with 2D launch dims
330
+ jax_inc_2d = jax_kernel(
331
+ inc_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
332
+ ) # Intentionally not the same as the first dimension of the input
333
+
334
+ @jax.jit
335
+ def f_2d():
336
+ x = jp.zeros((n, m), dtype=jp.float32) + 3.0
337
+ return jax_inc_2d(x)
338
+
339
+ # run on the given device
340
+ with jax.default_device(wp.device_to_jax(device)):
341
+ y_1d = f_1d()
342
+ y_2d = f_2d()
343
+
344
+ wp.synchronize_device(device)
345
+
346
+ result_1d = np.asarray(y_1d).reshape((n - 2,))
347
+ expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0
348
+
349
+ result_2d = np.asarray(y_2d).reshape((n - 2, m - 2))
350
+ expected_2d = np.full((n - 2, m - 2), 4.0, dtype=np.float32)
351
+
352
+ assert_np_equal(result_1d, expected_1d)
353
+ assert_np_equal(result_2d, expected_2d)
354
+
355
+
356
+ # =========================================================================================================
357
+ # JAX FFI
358
+ # =========================================================================================================
359
+
360
+
361
+ @wp.kernel
362
+ def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), output: wp.array(dtype=float)):
363
+ tid = wp.tid()
364
+ output[tid] = a[tid] + b[tid]
365
+
366
+
367
+ @wp.kernel
368
+ def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
369
+ tid = wp.tid()
370
+ out[tid] = alpha * x[tid] + y[tid]
371
+
372
+
373
+ @wp.kernel
374
+ def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
375
+ tid = wp.tid()
376
+ sin_out[tid] = wp.sin(angle[tid])
377
+ cos_out[tid] = wp.cos(angle[tid])
378
+
379
+
380
+ @wp.kernel
381
+ def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
382
+ tid = wp.tid()
383
+ d = float(tid + 1)
384
+ output[tid] = wp.mat33(d, 0.0, 0.0, 0.0, d * 2.0, 0.0, 0.0, 0.0, d * 3.0)
385
+
386
+
387
+ @wp.kernel
388
+ def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
389
+ tid = wp.tid()
390
+ output[tid] = a[tid] * s
391
+
392
+
393
+ @wp.kernel
394
+ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
395
+ tid = wp.tid()
396
+ output[tid] = a[tid] * s
397
+
398
+
399
+ @wp.kernel
400
+ def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
401
+ tid = wp.tid()
402
+ b[tid] += a[tid]
403
+
404
+
405
+ @wp.kernel
406
+ def matmul_kernel(
407
+ a: wp.array2d(dtype=float), # NxK
408
+ b: wp.array2d(dtype=float), # KxM
409
+ c: wp.array2d(dtype=float), # NxM
410
+ ):
411
+ # launch dims should be (N, M)
412
+ i, j = wp.tid()
413
+ N = a.shape[0]
414
+ K = a.shape[1]
415
+ M = b.shape[1]
416
+ if i < N and j < M:
417
+ s = wp.float32(0)
418
+ for k in range(K):
419
+ s += a[i, k] * b[k, j]
420
+ c[i, j] = s
421
+
422
+
423
+ @wp.kernel
424
+ def in_out_kernel(
425
+ a: wp.array(dtype=float), # input only
426
+ b: wp.array(dtype=float), # input and output
427
+ c: wp.array(dtype=float), # output only
428
+ ):
429
+ tid = wp.tid()
430
+ b[tid] += a[tid]
431
+ c[tid] = 2.0 * a[tid]
432
+
433
+
434
+ @wp.kernel
435
+ def multi_out_kernel(
436
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
437
+ ):
438
+ tid = wp.tid()
439
+ c[tid] = a[tid] + b[tid]
440
+ d[tid] = s * a[tid]
441
+
442
+
443
+ @wp.kernel
444
+ def multi_out_kernel_v2(
445
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
446
+ ):
447
+ tid = wp.tid()
448
+ c[tid] = a[tid] * a[tid]
449
+ d[tid] = a[tid] * b[tid] * s
450
+
451
+
452
+ @wp.kernel
453
+ def multi_out_kernel_v3(
454
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
455
+ ):
456
+ tid = wp.tid()
457
+ c[tid] = a[tid] ** 2.0
458
+ d[tid] = a[tid] * b[tid] * s
459
+
460
+
461
+ @wp.kernel
462
+ def scale_sum_square_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float)):
463
+ tid = wp.tid()
464
+ c[tid] = (a[tid] * s + b[tid]) ** 2.0
465
+
466
+
467
+ # The Python function to call.
468
+ # Note the argument annotations, just like Warp kernels.
469
+ def scale_func(
470
+ # inputs
471
+ a: wp.array(dtype=float),
472
+ b: wp.array(dtype=wp.vec2),
473
+ s: float,
474
+ # outputs
475
+ c: wp.array(dtype=float),
476
+ d: wp.array(dtype=wp.vec2),
477
+ ):
478
+ wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
479
+ wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])
480
+
481
+
482
+ def in_out_func(
483
+ a: wp.array(dtype=float), # input only
484
+ b: wp.array(dtype=float), # input and output
485
+ c: wp.array(dtype=float), # output only
486
+ ):
487
+ wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
488
+ wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
489
+
490
+
491
+ def double_func(
492
+ # inputs
493
+ a: wp.array(dtype=float),
494
+ # outputs
495
+ b: wp.array(dtype=float),
496
+ ):
497
+ wp.launch(scale_kernel, dim=a.shape, inputs=[a, 2.0], outputs=[b])
498
+
499
+
500
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
501
+ def test_ffi_jax_kernel_add(test, device):
502
+ # two inputs and one output
503
+ import jax.numpy as jp
504
+
505
+ from warp.jax_experimental.ffi import jax_kernel
506
+
507
+ jax_add = jax_kernel(add_kernel)
508
+
509
+ @jax.jit
510
+ def f():
511
+ n = ARRAY_SIZE
512
+ a = jp.arange(n, dtype=jp.float32)
513
+ b = jp.ones(n, dtype=jp.float32)
514
+ return jax_add(a, b)
515
+
516
+ with jax.default_device(wp.device_to_jax(device)):
517
+ (y,) = f()
518
+
519
+ wp.synchronize_device(device)
520
+
521
+ result = np.asarray(y)
522
+ expected = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)
523
+
524
+ assert_np_equal(result, expected)
525
+
526
+
527
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
528
+ def test_ffi_jax_kernel_sincos(test, device):
529
+ # one input and two outputs
530
+ import jax.numpy as jp
531
+
532
+ from warp.jax_experimental.ffi import jax_kernel
533
+
534
+ jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)
535
+
536
+ n = ARRAY_SIZE
537
+
538
+ @jax.jit
539
+ def f():
540
+ a = jp.linspace(0, 2 * jp.pi, n, dtype=jp.float32)
541
+ return jax_sincos(a)
542
+
543
+ with jax.default_device(wp.device_to_jax(device)):
544
+ s, c = f()
545
+
546
+ wp.synchronize_device(device)
547
+
548
+ result_s = np.asarray(s)
549
+ result_c = np.asarray(c)
550
+
551
+ a = np.linspace(0, 2 * np.pi, n, dtype=np.float32)
552
+ expected_s = np.sin(a)
553
+ expected_c = np.cos(a)
554
+
555
+ assert_np_equal(result_s, expected_s, tol=1e-4)
556
+ assert_np_equal(result_c, expected_c, tol=1e-4)
557
+
558
+
559
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
560
+ def test_ffi_jax_kernel_diagonal(test, device):
561
+ # no inputs and one output
562
+ from warp.jax_experimental.ffi import jax_kernel
563
+
564
+ jax_diagonal = jax_kernel(diagonal_kernel)
565
+
566
+ @jax.jit
567
+ def f():
568
+ # launch dimensions determine output size
569
+ return jax_diagonal(launch_dims=4)
570
+
571
+ wp.synchronize_device(device)
572
+
573
+ with jax.default_device(wp.device_to_jax(device)):
574
+ (d,) = f()
575
+
576
+ result = np.asarray(d)
577
+ expected = np.array(
578
+ [
579
+ [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
580
+ [[2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 6.0]],
581
+ [[3.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 9.0]],
582
+ [[4.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 12.0]],
583
+ ],
584
+ dtype=np.float32,
585
+ )
586
+
587
+ assert_np_equal(result, expected)
588
+
589
+
590
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
591
+ def test_ffi_jax_kernel_in_out(test, device):
592
+ # in-out args
593
+ import jax.numpy as jp
594
+
595
+ from warp.jax_experimental.ffi import jax_kernel
596
+
597
+ jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
598
+
599
+ f = jax.jit(jax_func)
600
+
601
+ with jax.default_device(wp.device_to_jax(device)):
602
+ a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
603
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
604
+ b, c = f(a, b)
605
+
606
+ wp.synchronize_device(device)
607
+
608
+ assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
609
+ assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))
610
+
611
+
612
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
613
+ def test_ffi_jax_kernel_scale_vec_constant(test, device):
614
+ # multiply vectors by scalar (constant)
615
+ import jax.numpy as jp
616
+
617
+ from warp.jax_experimental.ffi import jax_kernel
618
+
619
+ jax_scale_vec = jax_kernel(scale_vec_kernel)
620
+
621
+ @jax.jit
622
+ def f():
623
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # array of vec2
624
+ s = 2.0
625
+ return jax_scale_vec(a, s)
626
+
627
+ with jax.default_device(wp.device_to_jax(device)):
628
+ (b,) = f()
629
+
630
+ wp.synchronize_device(device)
631
+
632
+ expected = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
633
+
634
+ assert_np_equal(b, expected)
635
+
636
+
637
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
638
+ def test_ffi_jax_kernel_scale_vec_static(test, device):
639
+ # multiply vectors by scalar (static arg)
640
+ import jax.numpy as jp
641
+
642
+ from warp.jax_experimental.ffi import jax_kernel
643
+
644
+ jax_scale_vec = jax_kernel(scale_vec_kernel)
645
+
646
+ # NOTE: scalar arguments must be static compile-time constants
647
+ @partial(jax.jit, static_argnames=["s"])
648
+ def f(a, s):
649
+ return jax_scale_vec(a, s)
650
+
651
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # array of vec2
652
+ s = 3.0
653
+
654
+ with jax.default_device(wp.device_to_jax(device)):
655
+ (b,) = f(a, s)
656
+
657
+ wp.synchronize_device(device)
658
+
659
+ expected = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
660
+
661
+ assert_np_equal(b, expected)
662
+
663
+
664
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
665
+ def test_ffi_jax_kernel_launch_dims_default(test, device):
666
+ # specify default launch dims
667
+ import jax.numpy as jp
668
+
669
+ from warp.jax_experimental.ffi import jax_kernel
670
+
671
+ N, M, K = 3, 4, 2
672
+
673
+ jax_matmul = jax_kernel(matmul_kernel, launch_dims=(N, M))
674
+
675
+ @jax.jit
676
+ def f():
677
+ a = jp.full((N, K), 2, dtype=jp.float32)
678
+ b = jp.full((K, M), 3, dtype=jp.float32)
679
+
680
+ # use default launch dims
681
+ return jax_matmul(a, b)
682
+
683
+ with jax.default_device(wp.device_to_jax(device)):
684
+ (result,) = f()
685
+
686
+ wp.synchronize_device(device)
687
+
688
+ expected = np.full((3, 4), 12, dtype=np.float32)
689
+
690
+ test.assertEqual(result.shape, expected.shape)
691
+ assert_np_equal(result, expected)
692
+
693
+
694
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
695
+ def test_ffi_jax_kernel_launch_dims_custom(test, device):
696
+ # specify custom launch dims per call
697
+ import jax.numpy as jp
698
+
699
+ from warp.jax_experimental.ffi import jax_kernel
700
+
701
+ jax_matmul = jax_kernel(matmul_kernel)
702
+
703
+ @jax.jit
704
+ def f():
705
+ N1, M1, K1 = 3, 4, 2
706
+ a1 = jp.full((N1, K1), 2, dtype=jp.float32)
707
+ b1 = jp.full((K1, M1), 3, dtype=jp.float32)
708
+
709
+ # use custom launch dims
710
+ result1 = jax_matmul(a1, b1, launch_dims=(N1, M1))
711
+
712
+ N2, M2, K2 = 4, 3, 2
713
+ a2 = jp.full((N2, K2), 2, dtype=jp.float32)
714
+ b2 = jp.full((K2, M2), 3, dtype=jp.float32)
715
+
716
+ # use different custom launch dims
717
+ result2 = jax_matmul(a2, b2, launch_dims=(N2, M2))
718
+
719
+ return result1[0], result2[0]
720
+
721
+ with jax.default_device(wp.device_to_jax(device)):
722
+ result1, result2 = f()
723
+
724
+ wp.synchronize_device(device)
725
+
726
+ expected1 = np.full((3, 4), 12, dtype=np.float32)
727
+ expected2 = np.full((4, 3), 12, dtype=np.float32)
728
+
729
+ test.assertEqual(result1.shape, expected1.shape)
730
+ test.assertEqual(result2.shape, expected2.shape)
731
+ assert_np_equal(result1, expected1)
732
+ assert_np_equal(result2, expected2)
733
+
734
+
735
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
736
+ def test_ffi_jax_callable_scale_constant(test, device):
737
+ # scale two arrays using a constant
738
+ import jax.numpy as jp
739
+
740
+ from warp.jax_experimental.ffi import jax_callable
741
+
742
+ jax_func = jax_callable(scale_func, num_outputs=2)
743
+
744
+ @jax.jit
745
+ def f():
746
+ # inputs
747
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
748
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # wp.vec2
749
+ s = 2.0
750
+
751
+ # output shapes
752
+ output_dims = {"c": a.shape, "d": b.shape}
753
+
754
+ c, d = jax_func(a, b, s, output_dims=output_dims)
755
+
756
+ return c, d
757
+
758
+ with jax.default_device(wp.device_to_jax(device)):
759
+ result1, result2 = f()
760
+
761
+ wp.synchronize_device(device)
762
+
763
+ expected1 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32)
764
+ expected2 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
765
+
766
+ assert_np_equal(result1, expected1)
767
+ assert_np_equal(result2, expected2)
768
+
769
+
770
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
771
+ def test_ffi_jax_callable_scale_static(test, device):
772
+ # scale two arrays using a static arg
773
+ import jax.numpy as jp
774
+
775
+ from warp.jax_experimental.ffi import jax_callable
776
+
777
+ jax_func = jax_callable(scale_func, num_outputs=2)
778
+
779
+ # NOTE: scalar arguments must be static compile-time constants
780
+ @partial(jax.jit, static_argnames=["s"])
781
+ def f(a, b, s):
782
+ # output shapes
783
+ output_dims = {"c": a.shape, "d": b.shape}
784
+
785
+ c, d = jax_func(a, b, s, output_dims=output_dims)
786
+
787
+ return c, d
788
+
789
+ with jax.default_device(wp.device_to_jax(device)):
790
+ # inputs
791
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
792
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # wp.vec2
793
+ s = 3.0
794
+ result1, result2 = f(a, b, s)
795
+
796
+ wp.synchronize_device(device)
797
+
798
+ expected1 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32)
799
+ expected2 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
800
+
801
+ assert_np_equal(result1, expected1)
802
+ assert_np_equal(result2, expected2)
803
+
804
+
805
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
806
+ def test_ffi_jax_callable_in_out(test, device):
807
+ # in-out arguments
808
+ import jax.numpy as jp
809
+
810
+ from warp.jax_experimental.ffi import jax_callable
811
+
812
+ jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
813
+
814
+ f = jax.jit(jax_func)
815
+
816
+ with jax.default_device(wp.device_to_jax(device)):
817
+ a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
818
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
819
+ b, c = f(a, b)
820
+
821
+ wp.synchronize_device(device)
822
+
823
+ assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
824
+ assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))
825
+
826
+
827
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
828
+ def test_ffi_jax_callable_graph_cache(test, device):
829
+ # test graph caching limits
830
+ import jax
831
+ import jax.numpy as jp
832
+
833
+ from warp.jax_experimental.ffi import (
834
+ GraphMode,
835
+ clear_jax_callable_graph_cache,
836
+ get_jax_callable_default_graph_cache_max,
837
+ jax_callable,
838
+ set_jax_callable_default_graph_cache_max,
839
+ )
840
+
841
+ # --- test with default cache settings ---
842
+
843
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
844
+ f = jax.jit(jax_double)
845
+ arrays = []
846
+
847
+ test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
848
+
849
+ with jax.default_device(wp.device_to_jax(device)):
850
+ for i in range(10):
851
+ n = 10 + i
852
+ a = jp.arange(n, dtype=jp.float32)
853
+ (b,) = f(a)
854
+
855
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
856
+
857
+ # ensure graph cache is always growing
858
+ test.assertEqual(jax_double.graph_cache_size, i + 1)
859
+
860
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture each time
861
+ arrays.append(a)
862
+
863
+ # --- test clearing one callable's cache ---
864
+
865
+ clear_jax_callable_graph_cache(jax_double)
866
+
867
+ test.assertEqual(jax_double.graph_cache_size, 0)
868
+
869
+ # --- test with a custom cache limit ---
870
+
871
+ graph_cache_max = 5
872
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP, graph_cache_max=graph_cache_max)
873
+ f = jax.jit(jax_double)
874
+ arrays = []
875
+
876
+ test.assertEqual(jax_double.graph_cache_max, graph_cache_max)
877
+
878
+ with jax.default_device(wp.device_to_jax(device)):
879
+ for i in range(10):
880
+ n = 10 + i
881
+ a = jp.arange(n, dtype=jp.float32)
882
+ (b,) = f(a)
883
+
884
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
885
+
886
+ # ensure graph cache size is capped
887
+ test.assertEqual(jax_double.graph_cache_size, min(i + 1, graph_cache_max))
888
+
889
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
890
+ arrays.append(a)
891
+
892
+ # --- test clearing all callables' caches ---
893
+
894
+ clear_jax_callable_graph_cache()
895
+
896
+ with wp.jax_experimental.ffi._FFI_REGISTRY_LOCK:
897
+ for c in wp.jax_experimental.ffi._FFI_CALLABLE_REGISTRY.values():
898
+ test.assertEqual(c.graph_cache_size, 0)
899
+
900
+ # --- test with a custom default cache limit ---
901
+
902
+ saved_max = get_jax_callable_default_graph_cache_max()
903
+ try:
904
+ set_jax_callable_default_graph_cache_max(5)
905
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
906
+ f = jax.jit(jax_double)
907
+ arrays = []
908
+
909
+ test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
910
+
911
+ with jax.default_device(wp.device_to_jax(device)):
912
+ for i in range(10):
913
+ n = 10 + i
914
+ a = jp.arange(n, dtype=jp.float32)
915
+ (b,) = f(a)
916
+
917
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
918
+
919
+ # ensure graph cache size is capped
920
+ test.assertEqual(
921
+ jax_double.graph_cache_size,
922
+ min(i + 1, get_jax_callable_default_graph_cache_max()),
923
+ )
924
+
925
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
926
+ arrays.append(a)
927
+
928
+ clear_jax_callable_graph_cache()
929
+
930
+ finally:
931
+ set_jax_callable_default_graph_cache_max(saved_max)
932
+
933
+
934
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
935
+ def test_ffi_jax_callable_pmap_mul(test, device):
936
+ import jax
937
+ import jax.numpy as jp
938
+
939
+ from warp.jax_experimental.ffi import jax_callable
940
+
941
+ j = jax_callable(double_func, num_outputs=1)
942
+
943
+ ndev = jax.local_device_count()
944
+ per_device = max(ARRAY_SIZE // ndev, 64)
945
+ x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
946
+
947
+ def per_device_func(v):
948
+ (y,) = j(v)
949
+ return y
950
+
951
+ y = jax.pmap(per_device_func)(x)
952
+
953
+ wp.synchronize()
954
+
955
+ assert_np_equal(np.asarray(y), 2 * np.asarray(x))
956
+
957
+
958
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
959
+ def test_ffi_jax_callable_pmap_multi_output(test, device):
960
+ import jax
961
+ import jax.numpy as jp
962
+
963
+ from warp.jax_experimental.ffi import jax_callable
964
+
965
+ def multi_out_py(
966
+ a: wp.array(dtype=float),
967
+ b: wp.array(dtype=float),
968
+ s: float,
969
+ c: wp.array(dtype=float),
970
+ d: wp.array(dtype=float),
971
+ ):
972
+ wp.launch(multi_out_kernel, dim=a.shape, inputs=[a, b, s], outputs=[c, d])
973
+
974
+ j = jax_callable(multi_out_py, num_outputs=2)
975
+
976
+ ndev = jax.local_device_count()
977
+ per_device = max(ARRAY_SIZE // ndev, 64)
978
+ a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
979
+ b = jp.ones((ndev, per_device), dtype=jp.float32)
980
+ s = 3.0
981
+
982
+ def per_device_func(aa, bb):
983
+ c, d = j(aa, bb, s)
984
+ return c + d # simple combine to exercise both outputs
985
+
986
+ out = jax.pmap(per_device_func)(a, b)
987
+
988
+ wp.synchronize()
989
+
990
+ a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
991
+ b_np = np.ones((ndev, per_device), dtype=np.float32)
992
+ ref = (a_np + b_np) + s * a_np
993
+ assert_np_equal(np.asarray(out), ref)
994
+
995
+
996
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
997
+ def test_ffi_jax_callable_pmap_multi_stage(test, device):
998
+ import jax
999
+ import jax.numpy as jp
1000
+
1001
+ from warp.jax_experimental.ffi import jax_callable
1002
+
1003
+ def multi_stage_py(
1004
+ a: wp.array(dtype=float),
1005
+ b: wp.array(dtype=float),
1006
+ alpha: float,
1007
+ tmp: wp.array(dtype=float),
1008
+ out: wp.array(dtype=float),
1009
+ ):
1010
+ wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
1011
+ wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])
1012
+
1013
+ j = jax_callable(multi_stage_py, num_outputs=2)
1014
+
1015
+ ndev = jax.local_device_count()
1016
+ per_device = max(ARRAY_SIZE // ndev, 64)
1017
+ a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
1018
+ b = jp.ones((ndev, per_device), dtype=jp.float32)
1019
+ alpha = 2.5
1020
+
1021
+ def per_device_func(aa, bb):
1022
+ tmp, out = j(aa, bb, alpha)
1023
+ return tmp + out
1024
+
1025
+ combined = jax.pmap(per_device_func)(a, b)
1026
+
1027
+ wp.synchronize()
1028
+
1029
+ a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
1030
+ b_np = np.ones((ndev, per_device), dtype=np.float32)
1031
+ tmp_ref = a_np + b_np
1032
+ out_ref = alpha * (a_np + b_np) + b_np
1033
+ ref = tmp_ref + out_ref
1034
+ assert_np_equal(np.asarray(combined), ref)
1035
+
1036
+
1037
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1038
+ def test_ffi_callback(test, device):
1039
+ # in-out arguments
1040
+ import jax.numpy as jp
1041
+
1042
+ from warp.jax_experimental.ffi import register_ffi_callback
1043
+
1044
+ # the Python function to call
1045
+ def warp_func(inputs, outputs, attrs, ctx):
1046
+ # input arrays
1047
+ a = inputs[0]
1048
+ b = inputs[1]
1049
+
1050
+ # scalar attributes
1051
+ s = attrs["scale"]
1052
+
1053
+ # output arrays
1054
+ c = outputs[0]
1055
+ d = outputs[1]
1056
+
1057
+ device = wp.device_from_jax(get_jax_device())
1058
+ stream = wp.Stream(device, cuda_stream=ctx.stream)
1059
+
1060
+ with wp.ScopedStream(stream):
1061
+ # launch with arrays of scalars
1062
+ wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
1063
+
1064
+ # launch with arrays of vec2
1065
+ # NOTE: the input shapes are from JAX arrays, we need to strip the inner dimension for vec2 arrays
1066
+ wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])
1067
+
1068
+ # register callback
1069
+ register_ffi_callback("warp_func", warp_func)
1070
+
1071
+ n = ARRAY_SIZE
1072
+
1073
+ with jax.default_device(wp.device_to_jax(device)):
1074
+ # inputs
1075
+ a = jp.arange(n, dtype=jp.float32)
1076
+ b = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2)) # array of wp.vec2
1077
+ s = 2.0
1078
+
1079
+ # set up call
1080
+ out_types = [
1081
+ jax.ShapeDtypeStruct(a.shape, jp.float32),
1082
+ jax.ShapeDtypeStruct(b.shape, jp.float32), # array of wp.vec2
1083
+ ]
1084
+ call = jax.ffi.ffi_call("warp_func", out_types)
1085
+
1086
+ # call it
1087
+ c, d = call(a, b, scale=s)
1088
+
1089
+ wp.synchronize_device(device)
1090
+
1091
+ assert_np_equal(c, 2 * np.arange(ARRAY_SIZE, dtype=np.float32))
1092
+ assert_np_equal(d, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2)))
1093
+
1094
+
1095
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_simple(test, device):
+     from functools import partial
+
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(
+         scale_sum_square_kernel,
+         num_outputs=1,
+         enable_backward=True,
+     )
+
+     @partial(jax.jit, static_argnames=["s"])
+     def loss(a, b, s):
+         out = jax_func(a, b, s)[0]
+         return jp.sum(out)
+
+     n = ARRAY_SIZE
+     a = jp.arange(n, dtype=jp.float32)
+     b = jp.ones(n, dtype=jp.float32)
+     s = 2.0
+
+     with jax.default_device(wp.device_to_jax(device)):
+         da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
+
+     wp.synchronize_device(device)
+
+     # reference gradients (elementwise):
+     # d/da_i sum((a*s + b)^2) = 2 * (a_i*s + b_i) * s
+     # d/db_i sum((a*s + b)^2) = 2 * (a_i*s + b_i)
+     a_np = np.arange(n, dtype=np.float32)
+     b_np = np.ones(n, dtype=np.float32)
+     ref_da = 2.0 * (a_np * s + b_np) * s
+     ref_db = 2.0 * (a_np * s + b_np)
+
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
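+ # NOTE: in the autodiff tests, the scalar kernel parameter `s` is kept out of tracing
+ # (static_argnames under jit, or a plain Python float) so that only the array arguments
+ # participate in differentiation; with enable_backward=True the kernel's adjoint
+ # supplies the VJP for those arrays.
+
+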
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_jit_of_grad_simple(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
+
+     def loss(a, b, s):
+         out = jax_func(a, b, s)[0]
+         return jp.sum(out)
+
+     grad_fn = jax.grad(loss, argnums=(0, 1))
+
+     # more typical: jit(grad(...)) with static scalar
+     jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
+
+     n = ARRAY_SIZE
+     a = jp.arange(n, dtype=jp.float32)
+     b = jp.ones(n, dtype=jp.float32)
+     s = 2.0
+
+     with jax.default_device(wp.device_to_jax(device)):
+         da, db = jitted_grad(a, b, s)
+
+     wp.synchronize_device(device)
+
+     a_np = np.arange(n, dtype=np.float32)
+     b_np = np.ones(n, dtype=np.float32)
+     ref_da = 2.0 * (a_np * s + b_np) * s
+     ref_db = 2.0 * (a_np * s + b_np)
+
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_multi_output(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
+
+     def caller(fn, a, b, s):
+         c, d = fn(a, b, s)
+         return jp.sum(c + d)
+
+     @jax.jit
+     def grads(a, b, s):
+         # fix the scalar s to a Python constant via a closure so it is not traced under jit
+         def _inner(a, b, s):
+             return caller(jax_func, a, b, s)
+
+         return jax.grad(lambda a, b: _inner(a, b, 2.0), argnums=(0, 1))(a, b)
+
+     n = ARRAY_SIZE
+     a = jp.arange(n, dtype=jp.float32)
+     b = jp.ones(n, dtype=jp.float32)
+     s = 2.0
+
+     with jax.default_device(wp.device_to_jax(device)):
+         da, db = grads(a, b, s)
+
+     wp.synchronize_device(device)
+
+     a_np = np.arange(n, dtype=np.float32)
+     b_np = np.ones(n, dtype=np.float32)
+     # d/da sum(c+d) = 2*a + b*s
+     ref_da = 2.0 * a_np + b_np * s
+     # d/db sum(c+d) = a*s
+     ref_db = a_np * s
+
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
+
+     def loss(a, b, s):
+         c, d = jax_func(a, b, s)
+         return jp.sum(c + d)
+
+     grad_fn = jax.grad(loss, argnums=(0, 1))
+     jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
+
+     n = ARRAY_SIZE
+     a = jp.arange(n, dtype=jp.float32)
+     b = jp.ones(n, dtype=jp.float32)
+     s = 2.0
+
+     with jax.default_device(wp.device_to_jax(device)):
+         da, db = jitted_grad(a, b, s)
+
+     wp.synchronize_device(device)
+
+     a_np = np.arange(n, dtype=np.float32)
+     b_np = np.ones(n, dtype=np.float32)
+     ref_da = 2.0 * a_np + b_np * s
+     ref_db = a_np * s
+
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_2d(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(inc_2d_kernel, num_outputs=1, enable_backward=True)
+
+     @jax.jit
+     def loss(a):
+         out = jax_func(a)[0]
+         return jp.sum(out)
+
+     n, m = 8, 6
+     a = jp.arange(n * m, dtype=jp.float32).reshape((n, m))
+
+     with jax.default_device(wp.device_to_jax(device)):
+         (da,) = jax.grad(loss, argnums=(0,))(a)
+
+     wp.synchronize_device(device)
+
+     ref = np.ones((n, m), dtype=np.float32)
+     assert_np_equal(np.asarray(da), ref)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_vec2(test, device):
+     from functools import partial
+
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(scale_vec_kernel, num_outputs=1, enable_backward=True)
+
+     @partial(jax.jit, static_argnames=("s",))
+     def loss(a, s):
+         out = jax_func(a, s)[0]
+         return jp.sum(out)
+
+     n = ARRAY_SIZE
+     a = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2))
+     s = 3.0
+
+     with jax.default_device(wp.device_to_jax(device)):
+         (da,) = jax.grad(loss, argnums=(0,))(a, s)
+
+     wp.synchronize_device(device)
+
+     # d/da sum(a*s) = s
+     ref = np.full_like(np.asarray(a), s)
+     assert_np_equal(np.asarray(da), ref)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_mat22(test, device):
+     from functools import partial
+
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     @wp.kernel
+     def scale_mat_kernel(a: wp.array(dtype=wp.mat22), s: float, out: wp.array(dtype=wp.mat22)):
+         tid = wp.tid()
+         out[tid] = a[tid] * s
+
+     jax_func = jax_kernel(scale_mat_kernel, num_outputs=1, enable_backward=True)
+
+     @partial(jax.jit, static_argnames=("s",))
+     def loss(a, s):
+         out = jax_func(a, s)[0]
+         return jp.sum(out)
+
+     n = 12  # must be divisible by 4 for 2x2 matrices
+     a = jp.arange(n, dtype=jp.float32).reshape((n // 4, 2, 2))
+     s = 2.5
+
+     with jax.default_device(wp.device_to_jax(device)):
+         (da,) = jax.grad(loss, argnums=(0,))(a, s)
+
+     wp.synchronize_device(device)
+
+     ref = np.full((n // 4, 2, 2), s, dtype=np.float32)
+     assert_np_equal(np.asarray(da), ref)
+
+
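+ # NOTE: for vector- and matrix-valued arrays the JAX shape spells out the component
+ # dimensions: (n // 2, 2) maps to a length-(n // 2) array of wp.vec2 and (n // 4, 2, 2)
+ # to a length-(n // 4) array of wp.mat22, and the gradients returned by jax.grad keep
+ # those same shapes.
+
+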
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_static_required(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     # the scalar s must remain a plain (static) Python float; grad is applied here without jit
+     jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
+
+     def loss(a, b, s):
+         out = jax_func(a, b, s)[0]
+         return jp.sum(out)
+
+     n = ARRAY_SIZE
+     a = jp.arange(n, dtype=jp.float32)
+     b = jp.ones(n, dtype=jp.float32)
+     s = 1.5
+
+     with jax.default_device(wp.device_to_jax(device)):
+         da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
+
+     wp.synchronize_device(device)
+
+     a_np = np.arange(n, dtype=np.float32)
+     b_np = np.ones(n, dtype=np.float32)
+     ref_da = 2.0 * (a_np * s + b_np) * s
+     ref_db = 2.0 * (a_np * s + b_np)
+
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_pmap_triple(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_mul = jax_kernel(triple_kernel, num_outputs=1, enable_backward=True)
+
+     ndev = jax.local_device_count()
+     per_device = ARRAY_SIZE // ndev
+     x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+
+     def per_device_loss(x):
+         y = jax_mul(x)[0]
+         return jp.sum(y)
+
+     grads = jax.pmap(jax.grad(per_device_loss))(x)
+
+     wp.synchronize()
+
+     assert_np_equal(np.asarray(grads), np.full((ndev, per_device), 3.0, dtype=np.float32))
+
+
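+ # NOTE: under pmap, jax.grad runs the kernel's adjoint independently on each device's
+ # shard; since triple_kernel scales its input by 3 (as the reference gradient of 3.0
+ # reflects), the per-element gradient of the summed output is constant.
+
+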
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_pmap_multi_output(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_mo = jax_kernel(multi_out_kernel_v2, num_outputs=2, enable_backward=True)
+
+     ndev = jax.local_device_count()
+     per_device = ARRAY_SIZE // ndev
+     a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+     b = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+     s = 2.0
+
+     def per_dev_loss(aa, bb):
+         c, d = jax_mo(aa, bb, s)
+         return jp.sum(c + d)
+
+     da, db = jax.pmap(jax.grad(per_dev_loss, argnums=(0, 1)))(a, b)
+
+     wp.synchronize()
+
+     a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
+     b_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
+     ref_da = 2.0 * a_np + b_np * s
+     ref_db = a_np * s
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
+ class TestJax(unittest.TestCase):
+     pass
+
+
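+ # NOTE: TestJax is intentionally empty; test methods are attached to it dynamically
+ # below via add_function_test(), guarded by the Jax import and device compatibility
+ # checks.
+
+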
+ # try to register the Jax tests, but only if Jax imports and initializes correctly
+ try:
+     # prevent Jax from gobbling up GPU memory
+     os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+     os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+     import jax
+
+     # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
+     jax.config.update("jax_enable_x64", True)
+
+     # check which Warp devices work with Jax
+     # CUDA devices may fail if Jax cannot find a CUDA Toolkit
+     test_devices = get_test_devices()
+     jax_compatible_devices = []
+     jax_compatible_cuda_devices = []
+     for d in test_devices:
+         try:
+             with jax.default_device(wp.device_to_jax(d)):
+                 j = jax.numpy.arange(10, dtype=jax.numpy.float32)
+                 j += 1
+             jax_compatible_devices.append(d)
+             if d.is_cuda:
+                 jax_compatible_cuda_devices.append(d)
+         except Exception as e:
+             print(f"Skipping Jax tests on device '{d}' due to exception: {e}")
+
+     add_function_test(TestJax, "test_dtype_from_jax", test_dtype_from_jax, devices=None)
+     add_function_test(TestJax, "test_dtype_to_jax", test_dtype_to_jax, devices=None)
+
+     if jax_compatible_devices:
+         add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
+
+     if jax_compatible_cuda_devices:
+         # tests for both custom_call and ffi variants of jax_kernel(), selected by installed JAX version
+         if jax.__version_info__ < (0, 4, 25):
+             # no interop supported
+             ffi_opts = []
+         elif jax.__version_info__ < (0, 5, 0):
+             # only custom_call supported
+             ffi_opts = [False]
+         elif jax.__version_info__ < (0, 8, 0):
+             # both custom_call and ffi supported
+             ffi_opts = [False, True]
+         else:
+             # only ffi supported
+             ffi_opts = [True]
+
+         for use_ffi in ffi_opts:
+             suffix = "ffi" if use_ffi else "cc"
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_basic_{suffix}",
+                 test_jax_kernel_basic,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_scalar_{suffix}",
+                 test_jax_kernel_scalar,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_vecmat_{suffix}",
+                 test_jax_kernel_vecmat,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_multiarg_{suffix}",
+                 test_jax_kernel_multiarg,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_launch_dims_{suffix}",
+                 test_jax_kernel_launch_dims,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+
+         # ffi.jax_kernel() tests
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_add", test_ffi_jax_kernel_add, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_sincos", test_ffi_jax_kernel_sincos, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_diagonal", test_ffi_jax_kernel_diagonal, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_in_out", test_ffi_jax_kernel_in_out, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_scale_vec_constant",
+             test_ffi_jax_kernel_scale_vec_constant,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_scale_vec_static",
+             test_ffi_jax_kernel_scale_vec_static,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_launch_dims_default",
+             test_ffi_jax_kernel_launch_dims_default,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_launch_dims_custom",
+             test_ffi_jax_kernel_launch_dims_custom,
+             devices=jax_compatible_cuda_devices,
+         )
+
+         # ffi.jax_callable() tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_scale_constant",
+             test_ffi_jax_callable_scale_constant,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_scale_static",
+             test_ffi_jax_callable_scale_static,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_graph_cache",
+             test_ffi_jax_callable_graph_cache,
+             devices=jax_compatible_cuda_devices,
+         )
+
+         # pmap tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_pmap_multi_output",
+             test_ffi_jax_callable_pmap_multi_output,
+             devices=None,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_pmap_mul",
+             test_ffi_jax_callable_pmap_mul,
+             devices=None,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_pmap_multi_stage",
+             test_ffi_jax_callable_pmap_multi_stage,
+             devices=None,
+         )
+
+         # ffi callback tests
+         add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
+
+         # autodiff tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_simple",
+             test_ffi_jax_kernel_autodiff_simple,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_jit_of_grad_simple",
+             test_ffi_jax_kernel_autodiff_jit_of_grad_simple,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_multi_output",
+             test_ffi_jax_kernel_autodiff_multi_output,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output",
+             test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_2d",
+             test_ffi_jax_kernel_autodiff_2d,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_vec2",
+             test_ffi_jax_kernel_autodiff_vec2,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_mat22",
+             test_ffi_jax_kernel_autodiff_mat22,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_static_required",
+             test_ffi_jax_kernel_autodiff_static_required,
+             devices=jax_compatible_cuda_devices,
+         )
+
+         # autodiff with pmap tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_pmap_triple",
+             test_ffi_jax_kernel_autodiff_pmap_triple,
+             devices=None,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_pmap_multi_output",
+             test_ffi_jax_kernel_autodiff_pmap_multi_output,
+             devices=None,
+         )
+
+ except Exception as e:
+     print(f"Skipping Jax tests due to exception: {e}")
+
+
+ if __name__ == "__main__":
+     wp.clear_kernel_cache()
+     unittest.main(verbosity=2)