warp_lang-1.10.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of warp-lang has been flagged as potentially problematic.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
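
The diff shown below is for the new JAX FFI bridge module, warp/_src/jax_experimental/ffi.py (+1286 lines). For orientation, here is a minimal usage sketch of its jax_kernel() entry point, based on the docstring further down. It assumes a CUDA-enabled JAX installation; the kernel, sizes, and names are illustrative, and the return convention for single-output calls may vary between versions.

    import jax
    import jax.numpy as jnp
    import warp as wp
    from warp.jax_experimental.ffi import jax_kernel

    @wp.kernel
    def scale(x: wp.array(dtype=float), alpha: float, y: wp.array(dtype=float)):
        # inputs come first, outputs last; the scalar is passed as a static value
        i = wp.tid()
        y[i] = alpha * x[i]

    # wrap the Warp kernel as a JAX callback (num_outputs defaults to 1)
    jax_scale = jax_kernel(scale)

    @jax.jit
    def run(x):
        # launch dimensions are inferred from the shape of the first array argument
        return jax_scale(x, 2.0)

    x = jnp.arange(16, dtype=jnp.float32)
    y = run(x)  # depending on the JAX/Warp versions this is a single array or a one-element list
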
warp/_src/jax_experimental/ffi.py
@@ -0,0 +1,1286 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import collections
17
+ import ctypes
18
+ import inspect
19
+ import threading
20
+ import traceback
21
+ from enum import IntEnum
22
+ from typing import Callable, Optional
23
+
24
+ import jax
25
+
26
+ import warp as wp
27
+ from warp._src.codegen import get_full_arg_spec, make_full_qualified_name
28
+ from warp._src.jax import get_jax_device
29
+ from warp._src.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
30
+
31
+ from .xla_ffi import *
32
+
33
+ _wp_module_name_ = "warp.jax_experimental.ffi"
34
+
35
+ # Type alias for differentiable kernel cache key
36
+ DiffKernelCacheKey = tuple[Callable, tuple, int, str, tuple[str, ...]]
37
+
38
+ # Holders for the custom callbacks to keep them alive.
39
+ _FFI_KERNEL_REGISTRY: dict[str, "FfiKernel"] = {}
40
+ _FFI_DIFF_KERNEL_REGISTRY: dict[DiffKernelCacheKey, Callable] = {}
41
+ _FFI_CALLABLE_REGISTRY: dict[str, "FfiCallable"] = {}
42
+ _FFI_CALLBACK_REGISTRY: dict[str, ctypes.CFUNCTYPE] = {}
43
+ _FFI_REGISTRY_LOCK = threading.Lock()
44
+
45
+ # Lock when XLA invokes callbacks from multiple threads.
46
+ _FFI_CALLBACK_LOCK = threading.Lock()
47
+
48
+
49
+ def check_jax_version():
50
+ # check if JAX version supports this
51
+ if jax.__version_info__ < (0, 5, 0):
52
+ msg = (
53
+ "This version of jax_kernel() requires JAX version 0.5.0 or higher, "
54
+ f"but installed JAX version is {jax.__version_info__}."
55
+ )
56
+ if jax.__version_info__ >= (0, 4, 25):
57
+ msg += " Please use warp.jax_experimental.custom_call.jax_kernel instead."
58
+ raise RuntimeError(msg)
59
+
60
+
61
+ class GraphMode(IntEnum):
62
+ NONE = 0 # don't capture a graph
63
+ JAX = 1 # let JAX capture a graph
64
+ WARP = 2 # let Warp capture a graph
65
+
66
+
67
+ class ModulePreloadMode(IntEnum):
68
+ NONE = 0 # don't preload modules
69
+ CURRENT_DEVICE = 1 # preload on currently active device
70
+ ALL_DEVICES = 2 # preload on all supported devices
71
+
72
+
73
+ class FfiArg:
74
+ def __init__(self, name, type, in_out=False):
75
+ self.name = name
76
+ self.type = type
77
+ self.in_out = in_out
78
+ self.is_array = isinstance(type, wp.array)
79
+
80
+ if self.is_array:
81
+ if hasattr(type.dtype, "_wp_scalar_type_"):
82
+ self.dtype_shape = type.dtype._shape_
83
+ self.dtype_ndim = len(self.dtype_shape)
84
+ self.jax_scalar_type = wp.dtype_to_jax(type.dtype._wp_scalar_type_)
85
+ self.jax_ndim = type.ndim + self.dtype_ndim
86
+ elif type.dtype in wp._src.types.value_types:
87
+ self.dtype_ndim = 0
88
+ self.dtype_shape = ()
89
+ self.jax_scalar_type = wp.dtype_to_jax(type.dtype)
90
+ self.jax_ndim = type.ndim
91
+ else:
92
+ raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")
93
+ self.warp_ndim = type.ndim
94
+ elif type in wp._src.types.value_types:
95
+ self.dtype_ndim = 0
96
+ self.dtype_shape = ()
97
+ self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
98
+ self.jax_ndim = 0
99
+ self.warp_ndim = 0
100
+ else:
101
+ raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
102
+
103
+
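
To make the element-type bookkeeping in FfiArg concrete, the sketch below shows what the class computes for a 2-D array of wp.vec3 (assuming vec3 reports an inner shape of (3,); the argument name is illustrative, and JAX must be installed since the module imports it at load time). The inner vector dimensions are carried explicitly on the JAX side.

    import warp as wp
    # FfiArg is the helper defined above (internal module path assumed)
    from warp._src.jax_experimental.ffi import FfiArg

    arg = FfiArg("points", wp.array(dtype=wp.vec3, ndim=2))

    print(arg.warp_ndim)        # 2: dimensions of the Warp array itself
    print(arg.dtype_shape)      # (3,): inner shape contributed by the vec3 element type
    print(arg.jax_ndim)         # 3: a matching JAX buffer has shape (m, n, 3)
    print(arg.jax_scalar_type)  # float32: scalar dtype expected on the JAX side
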
104
+ class FfiLaunchDesc:
105
+ def __init__(self, static_inputs, launch_dims):
106
+ self.static_inputs = static_inputs
107
+ self.launch_dims = launch_dims
108
+
109
+
110
+ class FfiKernel:
111
+ def __init__(
112
+ self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode
113
+ ):
114
+ self.kernel = kernel
115
+ self.name = generate_unique_name(kernel.func)
116
+ self.num_outputs = num_outputs
117
+ self.vmap_method = vmap_method
118
+ self.launch_dims = launch_dims
119
+ self.output_dims = output_dims
120
+ self.module_preload_mode = module_preload_mode
121
+ self.first_array_arg = None
122
+ self.launch_id = 0
123
+ self.launch_descriptors = {}
124
+
125
+ in_out_argnames_list = in_out_argnames or []
126
+ in_out_argnames = set(in_out_argnames_list)
127
+ if len(in_out_argnames_list) != len(in_out_argnames):
128
+ raise AssertionError("in_out_argnames must not contain duplicate names")
129
+
130
+ self.num_kernel_args = len(kernel.adj.args)
131
+ self.num_in_out = len(in_out_argnames)
132
+ self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
133
+ if self.num_outputs < 1:
134
+ raise ValueError("At least one output is required")
135
+ if self.num_outputs > self.num_kernel_args:
136
+ raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
137
+ if self.num_outputs < self.num_in_out:
138
+ raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
139
+
140
+ # process input args
141
+ self.input_args = []
142
+ for i in range(self.num_inputs):
143
+ arg_name = kernel.adj.args[i].label
144
+ arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
145
+ if arg_name in in_out_argnames:
146
+ in_out_argnames.remove(arg_name)
147
+ if arg.is_array:
148
+ # keep track of the first input array argument
149
+ if self.first_array_arg is None:
150
+ self.first_array_arg = i
151
+ self.input_args.append(arg)
152
+
153
+ # process output args
154
+ self.output_args = []
155
+ for i in range(self.num_inputs, self.num_kernel_args):
156
+ arg_name = kernel.adj.args[i].label
157
+ if arg_name in in_out_argnames:
158
+ raise AssertionError(
159
+ f"Expected an output-only argument for argument {arg_name}."
160
+ " in_out arguments should be placed before output-only arguments."
161
+ )
162
+ arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
163
+ if not arg.is_array:
164
+ raise TypeError("All output arguments must be arrays")
165
+ self.output_args.append(arg)
166
+
167
+ if in_out_argnames:
168
+ raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
169
+
170
+ # Build input output aliases.
171
+ out_id = 0
172
+ input_output_aliases = {}
173
+ for in_id, arg in enumerate(self.input_args):
174
+ if not arg.in_out:
175
+ continue
176
+ input_output_aliases[in_id] = out_id
177
+ out_id += 1
178
+ self.input_output_aliases = input_output_aliases
179
+
180
+ # register the callback
181
+ FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
182
+ self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
183
+ ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
184
+ ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
185
+ jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
186
+
187
+ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
188
+ num_inputs = len(args)
189
+ if num_inputs != self.num_inputs:
190
+ raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
191
+
192
+ # default argument fallback
193
+ if launch_dims is None:
194
+ launch_dims = self.launch_dims
195
+ if output_dims is None:
196
+ output_dims = self.output_dims
197
+ if vmap_method is None:
198
+ vmap_method = self.vmap_method
199
+
200
+ # output types
201
+ out_types = []
202
+
203
+ # process inputs
204
+ static_inputs = {}
205
+ for i in range(num_inputs):
206
+ input_arg = self.input_args[i]
207
+ input_value = args[i]
208
+ if input_arg.is_array:
209
+ # check dtype
210
+ if input_value.dtype != input_arg.jax_scalar_type:
211
+ raise TypeError(
212
+ f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
213
+ )
214
+ # check ndim
215
+ if input_value.ndim != input_arg.jax_ndim:
216
+ raise TypeError(
217
+ f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
218
+ )
219
+ # check inner dims
220
+ for d in range(input_arg.dtype_ndim):
221
+ if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
222
+ raise TypeError(
223
+ f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
224
+ )
225
+ else:
226
+ # make sure scalar is not a traced variable, should be static
227
+ if isinstance(input_value, jax.core.Tracer):
228
+ raise ValueError(f"Argument '{input_arg.name}' must be a static value")
229
+ # stash the value to be retrieved by callback
230
+ static_inputs[input_arg.name] = input_arg.type(input_value)
231
+
232
+ # append in-out arg to output types
233
+ if input_arg.in_out:
234
+ out_types.append(get_jax_output_type(input_arg, input_value.shape))
235
+
236
+ # launch dimensions
237
+ if launch_dims is None:
238
+ # use the shape of the first input array
239
+ if self.first_array_arg is not None:
240
+ launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
241
+ else:
242
+ raise RuntimeError("Failed to determine launch dimensions")
243
+ elif isinstance(launch_dims, int):
244
+ launch_dims = (launch_dims,)
245
+ else:
246
+ launch_dims = tuple(launch_dims)
247
+
248
+ # output shapes
249
+ if isinstance(output_dims, dict):
250
+ # assume a dictionary of shapes keyed on argument name
251
+ for output_arg in self.output_args:
252
+ dims = output_dims.get(output_arg.name)
253
+ if dims is None:
254
+ raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
255
+ out_types.append(get_jax_output_type(output_arg, dims))
256
+ else:
257
+ if output_dims is None:
258
+ # use launch dimensions
259
+ output_dims = launch_dims
260
+ elif isinstance(output_dims, int):
261
+ output_dims = (output_dims,)
262
+ # assume same dimensions for all outputs
263
+ for output_arg in self.output_args:
264
+ out_types.append(get_jax_output_type(output_arg, output_dims))
265
+
266
+ call = jax.ffi.ffi_call(
267
+ self.name,
268
+ out_types,
269
+ vmap_method=vmap_method,
270
+ input_output_aliases=self.input_output_aliases,
271
+ )
272
+
273
+ # preload on the specified devices
274
+ if self.module_preload_mode == ModulePreloadMode.CURRENT_DEVICE:
275
+ device = wp.device_from_jax(get_jax_device())
276
+ self.kernel.module.load(device)
277
+ elif self.module_preload_mode == ModulePreloadMode.ALL_DEVICES:
278
+ for d in jax.local_devices():
279
+ try:
280
+ dev = wp.device_from_jax(d)
281
+ except Exception:
282
+ # ignore unsupported devices like TPUs
283
+ continue
284
+ # we only support CUDA devices for now
285
+ if dev.is_cuda:
286
+ self.kernel.module.load(dev)
287
+
288
+ # save launch data to be retrieved by callback
289
+ launch_id = self.launch_id
290
+ self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
291
+ self.launch_id += 1
292
+
293
+ return call(*args, launch_id=launch_id)
294
+
295
+ def ffi_callback(self, call_frame):
296
+ try:
297
+ # On the first call, XLA runtime will query the API version and traits
298
+ # metadata using the |extension| field. Let us respond to that query
299
+ # if the metadata extension is present.
300
+ extension = call_frame.contents.extension_start
301
+ if extension:
302
+ # Try to set the version metadata.
303
+ if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
304
+ metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
305
+ metadata_ext.contents.metadata.contents.api_version.major_version = 0
306
+ metadata_ext.contents.metadata.contents.api_version.minor_version = 1
307
+ # Turn on CUDA graphs for this handler.
308
+ metadata_ext.contents.metadata.contents.traits = (
309
+ XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
310
+ )
311
+ return None
312
+
313
+ # Lock is required to prevent race conditions when callback is invoked
314
+ # from multiple threads, like with pmap.
315
+ with _FFI_CALLBACK_LOCK:
316
+ # retrieve call info
317
+ attrs = decode_attrs(call_frame.contents.attrs)
318
+ launch_id = int(attrs["launch_id"])
319
+ launch_desc = self.launch_descriptors[launch_id]
320
+
321
+ num_inputs = call_frame.contents.args.size
322
+ inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
323
+
324
+ num_outputs = call_frame.contents.rets.size
325
+ outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
326
+
327
+ assert num_inputs == self.num_inputs
328
+ assert num_outputs == self.num_outputs
329
+
330
+ launch_bounds = launch_bounds_t(launch_desc.launch_dims)
331
+
332
+ # first kernel param is the launch bounds
333
+ kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
334
+ kernel_params[0] = ctypes.addressof(launch_bounds)
335
+
336
+ arg_refs = []
337
+
338
+ # input and in-out args
339
+ for i, input_arg in enumerate(self.input_args):
340
+ if input_arg.is_array:
341
+ buffer = inputs[i].contents
342
+ shape = buffer.dims[: input_arg.type.ndim]
343
+ strides = strides_from_shape(shape, input_arg.type.dtype)
344
+ arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
345
+ kernel_params[i + 1] = ctypes.addressof(arg)
346
+ arg_refs.append(arg) # keep a reference
347
+ else:
348
+ # scalar argument, get stashed value
349
+ value = launch_desc.static_inputs[input_arg.name]
350
+ arg = input_arg.type._type_(value)
351
+ kernel_params[i + 1] = ctypes.addressof(arg)
352
+ arg_refs.append(arg) # keep a reference
353
+
354
+ # pure output args (skip in-out FFI buffers)
355
+ for i, output_arg in enumerate(self.output_args):
356
+ buffer = outputs[i + self.num_in_out].contents
357
+ shape = buffer.dims[: output_arg.type.ndim]
358
+ strides = strides_from_shape(shape, output_arg.type.dtype)
359
+ arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
360
+ kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
361
+ arg_refs.append(arg) # keep a reference
362
+
363
+ # get device and stream
364
+ device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents))
365
+ stream = get_stream_from_callframe(call_frame.contents)
366
+
367
+ # get kernel hooks
368
+ hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
369
+ assert hooks.forward, "Failed to find kernel entry point"
370
+
371
+ # launch the kernel
372
+ wp._src.context.runtime.core.wp_cuda_launch_kernel(
373
+ device.context,
374
+ hooks.forward,
375
+ launch_bounds.size,
376
+ 0,
377
+ 256,
378
+ hooks.forward_smem_bytes,
379
+ kernel_params,
380
+ stream,
381
+ )
382
+
383
+ except Exception as e:
384
+ print(traceback.format_exc())
385
+ return create_ffi_error(
386
+ call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
387
+ )
388
+
389
+
390
+ class FfiCallDesc:
391
+ def __init__(self, static_inputs):
392
+ self.static_inputs = static_inputs
393
+
394
+
395
+ class FfiCallable:
396
+ default_graph_cache_max: int | None = 32
397
+
398
+ def __init__(
399
+ self,
400
+ func,
401
+ num_outputs,
402
+ graph_mode,
403
+ vmap_method,
404
+ output_dims,
405
+ in_out_argnames,
406
+ graph_cache_max,
407
+ module_preload_mode,
408
+ ):
409
+ self.func = func
410
+ self.name = generate_unique_name(func)
411
+ self.num_outputs = num_outputs
412
+ self.vmap_method = vmap_method
413
+ self.graph_mode = graph_mode
414
+ self.output_dims = output_dims
415
+ self.module_preload_mode = module_preload_mode
416
+ self.first_array_arg = None
417
+ self.call_id = 0
418
+ self.call_descriptors = {}
419
+
420
+ # LRU cache of graphs captured by Warp
421
+ self._graph_cache_max = graph_cache_max
422
+ self.captures = collections.OrderedDict()
423
+
424
+ in_out_argnames_list = in_out_argnames or []
425
+ in_out_argnames = set(in_out_argnames_list)
426
+ if len(in_out_argnames_list) != len(in_out_argnames):
427
+ raise AssertionError("in_out_argnames must not contain duplicate names")
428
+
429
+ # get arguments and annotations
430
+ argspec = get_full_arg_spec(func)
431
+
432
+ num_args = len(argspec.args)
433
+ self.num_in_out = len(in_out_argnames)
434
+ self.num_inputs = num_args - num_outputs + self.num_in_out
435
+ if self.num_outputs < 1:
436
+ raise ValueError("At least one output is required")
437
+ if self.num_outputs > num_args:
438
+ raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
439
+ if self.num_outputs < self.num_in_out:
440
+ raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
441
+
442
+ if len(argspec.annotations) < num_args:
443
+ raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
444
+
445
+ # parse type annotations
446
+ self.args = []
447
+ arg_idx = 0
448
+ for arg_name, arg_type in argspec.annotations.items():
449
+ if arg_name == "return":
450
+ if arg_type is not None:
451
+ raise TypeError("Function must not return a value")
452
+ continue
453
+ else:
454
+ arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
455
+ if arg_name in in_out_argnames:
456
+ in_out_argnames.remove(arg_name)
457
+ if arg.is_array:
458
+ if arg_idx < self.num_inputs and self.first_array_arg is None:
459
+ self.first_array_arg = arg_idx
460
+ self.args.append(arg)
461
+
462
+ if arg.in_out and arg_idx >= self.num_inputs:
463
+ raise AssertionError(
464
+ f"Expected an output-only argument for argument {arg_name}."
465
+ " in_out arguments should be placed before output-only arguments."
466
+ )
467
+
468
+ arg_idx += 1
469
+
470
+ if in_out_argnames:
471
+ raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
472
+
473
+ self.input_args = self.args[: self.num_inputs] # includes in-out args
474
+ self.output_args = self.args[self.num_inputs :] # pure output args
475
+
476
+ # Buffer indices for array arguments in callback.
477
+ # In-out buffers are the same pointers in the XLA call frame,
478
+ # so we only include them for inputs and skip them for outputs.
479
+ self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
480
+ self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
481
+
482
+ # Build input output aliases.
483
+ out_id = 0
484
+ input_output_aliases = {}
485
+ for in_id, arg in enumerate(self.input_args):
486
+ if not arg.in_out:
487
+ continue
488
+ input_output_aliases[in_id] = out_id
489
+ out_id += 1
490
+ self.input_output_aliases = input_output_aliases
491
+
492
+ # register the callback
493
+ FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
494
+ self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
495
+ ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
496
+ ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
497
+ jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
498
+
499
+ def __call__(self, *args, output_dims=None, vmap_method=None):
500
+ num_inputs = len(args)
501
+ if num_inputs != self.num_inputs:
502
+ input_names = ", ".join(arg.name for arg in self.input_args)
503
+ s = "" if self.num_inputs == 1 else "s"
504
+ raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
505
+
506
+ # default argument fallback
507
+ if vmap_method is None:
508
+ vmap_method = self.vmap_method
509
+ if output_dims is None:
510
+ output_dims = self.output_dims
511
+
512
+ # output types
513
+ out_types = []
514
+
515
+ # process inputs
516
+ static_inputs = {}
517
+ for i in range(num_inputs):
518
+ input_arg = self.input_args[i]
519
+ input_value = args[i]
520
+ if input_arg.is_array:
521
+ # check dtype
522
+ if input_value.dtype != input_arg.jax_scalar_type:
523
+ raise TypeError(
524
+ f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
525
+ )
526
+ # check ndim
527
+ if input_value.ndim != input_arg.jax_ndim:
528
+ raise TypeError(
529
+ f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
530
+ )
531
+ # check inner dims
532
+ for d in range(input_arg.dtype_ndim):
533
+ if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
534
+ raise TypeError(
535
+ f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
536
+ )
537
+ else:
538
+ # make sure scalar is not a traced variable, should be static
539
+ if isinstance(input_value, jax.core.Tracer):
540
+ raise ValueError(f"Argument '{input_arg.name}' must be a static value")
541
+ # stash the value to be retrieved by callback
542
+ static_inputs[input_arg.name] = input_arg.type(input_value)
543
+
544
+ # append in-out arg to output types
545
+ if input_arg.in_out:
546
+ out_types.append(get_jax_output_type(input_arg, input_value.shape))
547
+
548
+ # output shapes
549
+ if isinstance(output_dims, dict):
550
+ # assume a dictionary of shapes keyed on argument name
551
+ for output_arg in self.output_args:
552
+ dims = output_dims.get(output_arg.name)
553
+ if dims is None:
554
+ raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
555
+ out_types.append(get_jax_output_type(output_arg, dims))
556
+ else:
557
+ if output_dims is None:
558
+ if self.first_array_arg is None:
559
+ raise ValueError("Unable to determine output dimensions")
560
+ output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
561
+ elif isinstance(output_dims, int):
562
+ output_dims = (output_dims,)
563
+ # assume same dimensions for all outputs
564
+ for output_arg in self.output_args:
565
+ out_types.append(get_jax_output_type(output_arg, output_dims))
566
+
567
+ call = jax.ffi.ffi_call(
568
+ self.name,
569
+ out_types,
570
+ vmap_method=vmap_method,
571
+ input_output_aliases=self.input_output_aliases,
572
+ # has_side_effect=True, # force this function to execute even if outputs aren't used
573
+ )
574
+
575
+ # preload on the specified devices
576
+ # NOTE: if the target function uses kernels from different modules, they will not be loaded here
577
+ module = wp.get_module(self.func.__module__)
578
+ if self.module_preload_mode == ModulePreloadMode.CURRENT_DEVICE:
579
+ device = wp.device_from_jax(get_jax_device())
580
+ module.load(device)
581
+ elif self.module_preload_mode == ModulePreloadMode.ALL_DEVICES:
582
+ for d in jax.local_devices():
583
+ try:
584
+ dev = wp.device_from_jax(d)
585
+ except Exception:
586
+ # ignore unsupported devices like TPUs
587
+ continue
588
+ # we only support CUDA devices for now
589
+ if dev.is_cuda:
590
+ module.load(dev)
591
+
592
+ # save call data to be retrieved by callback
593
+ call_id = self.call_id
594
+ self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
595
+ self.call_id += 1
596
+ return call(*args, call_id=call_id)
597
+
598
+ def ffi_callback(self, call_frame):
599
+ try:
600
+ # On the first call, XLA runtime will query the API version and traits
601
+ # metadata using the |extension| field. Let us respond to that query
602
+ # if the metadata extension is present.
603
+ extension = call_frame.contents.extension_start
604
+ if extension:
605
+ # Try to set the version metadata.
606
+ if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
607
+ metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
608
+ metadata_ext.contents.metadata.contents.api_version.major_version = 0
609
+ metadata_ext.contents.metadata.contents.api_version.minor_version = 1
610
+ # Turn on CUDA graphs for this handler.
611
+ if self.graph_mode is GraphMode.JAX:
612
+ metadata_ext.contents.metadata.contents.traits = (
613
+ XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
614
+ )
615
+ return None
616
+
617
+ # Lock is required to prevent race conditions when callback is invoked
618
+ # from multiple threads, like with pmap.
619
+ with _FFI_CALLBACK_LOCK:
620
+ # retrieve call info
621
+ # NOTE: this assumes that there's only one attribute - call_id (int64).
622
+ # A more general but slower approach is this:
623
+ # attrs = decode_attrs(call_frame.contents.attrs)
624
+ # call_id = int(attrs["call_id"])
625
+ attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
626
+ call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
627
+ call_desc = self.call_descriptors[call_id]
628
+
629
+ num_inputs = call_frame.contents.args.size
630
+ inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
631
+
632
+ num_outputs = call_frame.contents.rets.size
633
+ outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
634
+
635
+ assert num_inputs == self.num_inputs
636
+ assert num_outputs == self.num_outputs
637
+
638
+ cuda_stream = get_stream_from_callframe(call_frame.contents)
639
+
640
+ if self.graph_mode == GraphMode.WARP:
641
+ # check if we already captured an identical call
642
+ ip = [inputs[i].contents.data for i in self.array_input_indices]
643
+ op = [outputs[i].contents.data for i in self.array_output_indices]
644
+ capture_key = hash((call_id, *ip, *op))
645
+ capture = self.captures.get(capture_key)
646
+
647
+ # launch existing graph
648
+ if capture is not None:
649
+ # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
650
+ # This code should match wp.capture_launch().
651
+ graph = capture.graph
652
+ if graph.graph_exec is None:
653
+ g = ctypes.c_void_p()
654
+ if not wp._src.context.runtime.core.wp_cuda_graph_create_exec(
655
+ graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
656
+ ):
657
+ raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
658
+ graph.graph_exec = g
659
+
660
+ if not wp._src.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
661
+ raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
662
+
663
+ # update the graph cache to keep recently used graphs alive
664
+ self.captures.move_to_end(capture_key)
665
+
666
+ # early out
667
+ return
668
+
669
+ device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
670
+ device = wp.get_cuda_device(device_ordinal)
671
+ stream = wp.Stream(device, cuda_stream=cuda_stream)
672
+
673
+ # reconstruct the argument list
674
+ arg_list = []
675
+
676
+ # input and in-out args
677
+ for i, arg in enumerate(self.input_args):
678
+ if arg.is_array:
679
+ buffer = inputs[i].contents
680
+ shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
681
+ arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
682
+ arg_list.append(arr)
683
+ else:
684
+ # scalar argument, get stashed value
685
+ value = call_desc.static_inputs[arg.name]
686
+ arg_list.append(value)
687
+
688
+ # pure output args (skip in-out FFI buffers)
689
+ for i, arg in enumerate(self.output_args):
690
+ buffer = outputs[i + self.num_in_out].contents
691
+ shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
692
+ arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
693
+ arg_list.append(arr)
694
+
695
+ # call the Python function with reconstructed arguments
696
+ with wp.ScopedStream(stream, sync_enter=False):
697
+ if stream.is_capturing:
698
+ # capturing with JAX
699
+ with wp.ScopedCapture(external=True) as capture:
700
+ self.func(*arg_list)
701
+ # keep a reference to the capture object to prevent required modules getting unloaded
702
+ call_desc.capture = capture
703
+ elif self.graph_mode == GraphMode.WARP:
704
+ # capturing with WARP
705
+ with wp.ScopedCapture() as capture:
706
+ self.func(*arg_list)
707
+ wp.capture_launch(capture.graph)
708
+ # keep a reference to the capture object and reuse it with same buffers
709
+ self.captures[capture_key] = capture
710
+ # respect the cache size limit if specified
711
+ if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max:
712
+ self.captures.popitem(last=False)
713
+ else:
714
+ # not capturing
715
+ self.func(*arg_list)
716
+
717
+ except Exception as e:
718
+ print(traceback.format_exc())
719
+ return create_ffi_error(
720
+ call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
721
+ )
722
+
723
+ return None
724
+
725
+ @property
726
+ def graph_cache_max(self) -> int | None:
727
+ return self._graph_cache_max
728
+
729
+ @graph_cache_max.setter
730
+ def graph_cache_max(self, value: int | None):
731
+ if value != self._graph_cache_max:
732
+ if value is not None and (self._graph_cache_max is None or value < self._graph_cache_max):
733
+ # trim the cache if needed
734
+ while len(self.captures) > value:
735
+ self.captures.popitem(last=False)
736
+ self._graph_cache_max = value
737
+
738
+ @property
739
+ def graph_cache_size(self) -> int:
740
+ return len(self.captures)
741
+
742
+
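
The Warp-side graph cache used by FfiCallable above is a plain collections.OrderedDict driven as an LRU: hits are moved to the back, and once graph_cache_max is exceeded the oldest entry is evicted from the front. The cache key combines the call id with the raw input/output buffer pointers, so a captured graph is only reused when XLA hands back the same buffers. A self-contained sketch of that pattern (names are illustrative; "capture" stands in for a wp.ScopedCapture result):

    import collections

    GRAPH_CACHE_MAX = 32  # mirrors FfiCallable.default_graph_cache_max
    captures = collections.OrderedDict()

    def cache_get(key):
        capture = captures.get(key)
        if capture is not None:
            # keep recently used graphs alive by marking them most recently used
            captures.move_to_end(key)
        return capture

    def cache_put(key, capture):
        captures[key] = capture
        if GRAPH_CACHE_MAX is not None and len(captures) > GRAPH_CACHE_MAX:
            # evict the least recently used capture
            captures.popitem(last=False)
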
743
+ def jax_kernel(
744
+ kernel,
745
+ num_outputs=1,
746
+ vmap_method="broadcast_all",
747
+ launch_dims=None,
748
+ output_dims=None,
749
+ in_out_argnames=None,
750
+ module_preload_mode=ModulePreloadMode.CURRENT_DEVICE,
751
+ enable_backward: bool = False,
752
+ ):
753
+ """Create a JAX callback from a Warp kernel.
754
+
755
+ NOTE: This is an experimental feature under development.
756
+
757
+ Args:
758
+ kernel: The Warp kernel to launch.
759
+ num_outputs: Specify the number of output arguments if greater than 1.
760
+ This must include the number of arguments listed in ``in_out_argnames``.
761
+ vmap_method: String specifying how the callback transforms under ``vmap()``.
762
+ This argument can also be specified for individual calls.
763
+ launch_dims: Specify the default kernel launch dimensions. If None, launch
764
+ dimensions are inferred from the shape of the first array argument.
765
+ This argument can also be specified for individual calls.
766
+ output_dims: Specify the default dimensions of output arrays. If None, output
767
+ dimensions are inferred from the launch dimensions.
768
+ This argument can also be specified for individual calls.
769
+ in_out_argnames: Names of arguments that are both inputs and outputs (aliased buffers).
770
+ These must be array arguments that appear before any pure output arguments in the
771
+ kernel signature. The number of in-out arguments is included in ``num_outputs``.
772
+ Not supported when ``enable_backward=True``.
773
+ module_preload_mode: Specify the devices where the module should be preloaded.
774
+ enable_backward: Enable automatic differentiation for this kernel.
775
+
776
+ Limitations:
777
+ - All kernel arguments must be contiguous arrays or scalars.
778
+ - Scalars must be static arguments in JAX.
779
+ - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
780
+ - There must be at least one output or input-output argument.
781
+ - Only the CUDA backend is supported.
782
+ """
783
+
784
+ check_jax_version()
785
+
786
+ if not enable_backward:
787
+ key = (
788
+ kernel.func,
789
+ kernel.sig,
790
+ num_outputs,
791
+ vmap_method,
792
+ tuple(launch_dims) if launch_dims else launch_dims,
793
+ tuple(sorted(output_dims.items())) if output_dims else output_dims,
794
+ module_preload_mode,
795
+ )
796
+
797
+ with _FFI_REGISTRY_LOCK:
798
+ if key not in _FFI_KERNEL_REGISTRY:
799
+ new_kernel = FfiKernel(
800
+ kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode
801
+ )
802
+ _FFI_KERNEL_REGISTRY[key] = new_kernel
803
+
804
+ return _FFI_KERNEL_REGISTRY[key]
805
+
806
+ # make sure the arguments are compatible with autodiff
807
+ if in_out_argnames:
808
+ raise NotImplementedError(
809
+ "jax_kernel(): Input-output arguments (in_out_argnames) are not supported when enable_backward=True."
810
+ )
811
+
812
+ # TODO: we should support passing these to the forward and backward callables
813
+ if launch_dims is not None or output_dims is not None:
814
+ raise NotImplementedError(
815
+ "jax_kernel(): Custom dimensions (launch_dims, output_dims) are not supported when enable_backward=True."
816
+ )
817
+
818
+ # Differentiable path: build a custom VJP wrapper inline.
819
+ # Infer the original kernel signature (names and annotations)
820
+ signature = inspect.signature(kernel.func)
821
+
822
+ parameters = [p for p in signature.parameters.values() if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
823
+ parameter_count = len(parameters)
824
+ num_inputs = parameter_count - num_outputs
825
+
826
+ # determine static argument indices
827
+ static_args = []
828
+ for i, p in enumerate(parameters[:num_inputs]):
829
+ param_type = p.annotation
830
+ if not isinstance(param_type, wp.array):
831
+ if param_type in wp._src.types.value_types:
832
+ static_args.append(i)
833
+ else:
834
+ raise TypeError(f"Invalid type for argument '{p.name}', expected array or scalar, got {param_type}")
835
+
836
+ def _resolve_launch_dims(call_args):
837
+ # determine launch dimensions from the shape of the first input array
838
+ for i, p in enumerate(parameters[:num_inputs]):
839
+ param_type = p.annotation
840
+ if isinstance(param_type, wp.array):
841
+ arg = call_args[i]
842
+ arg_shape = tuple(arg.shape)
843
+ if hasattr(param_type.dtype, "_wp_scalar_type_"):
844
+ # vector/matrix array, trim trailing dimensions of JAX input array
845
+ return arg_shape[: param_type.ndim]
846
+ else:
847
+ # scalar array
848
+ return arg_shape
849
+ raise RuntimeError("Unable to determine launch dimensions, at least one input array is required")
850
+
851
+ # Forward kernel wrapper: simply launches the kernel
852
+ def fwd_kernel_wrapper(*args):
853
+ wp.launch(kernel, dim=_resolve_launch_dims(args), inputs=args[:num_inputs], outputs=args[num_inputs:])
854
+
855
+ # update forward signature and annotations so jax_callable() sees a fully annotated function
856
+ fwd_kernel_wrapper.__signature__ = signature
857
+ fwd_kernel_wrapper.__annotations__ = {p.name: p.annotation for p in parameters}
858
+ fwd_kernel_wrapper.__annotations__["return"] = None
859
+
860
+ jax_fwd_kernel = jax_callable(fwd_kernel_wrapper, num_outputs=num_outputs, vmap_method=vmap_method)
861
+
862
+ # backward arguments only include static args once
863
+ bwd_arg_count = 2 * parameter_count - len(static_args)
864
+
865
+ # Backward wrapper: launches adjoint with provided output gradients
866
+ def bwd_kernel_wrapper(*args):
867
+ if len(args) != bwd_arg_count:
868
+ raise RuntimeError(f"Invalid backward argument count, expected {bwd_arg_count} but got {len(args)}")
869
+
870
+ inputs = list(args[:num_inputs])
871
+ outputs = list(args[num_inputs:parameter_count])
872
+ grad_out = list(args[parameter_count : parameter_count + num_outputs])
873
+ grad_in = list(args[parameter_count + num_outputs :])
874
+
875
+ for i in static_args:
876
+ grad_in.insert(i, inputs[i])
877
+
878
+ for gi in grad_in:
879
+ if isinstance(gi, wp.array):
880
+ try:
881
+ gi.zero_()
882
+ except Exception as e:
883
+ wp.utils.warn(f"Failed to zero gradient array: {e}", stacklevel=2)
884
+ raise e
885
+
886
+ # NOTE: We cannot use a user-supplied launch_dims here: the backward rule doesn't receive it (and it could be wrong under pmap/vmap).
887
+ # Instead, infer the launch dimensions from the inputs.
888
+ wp.launch(
889
+ kernel,
890
+ dim=_resolve_launch_dims(inputs),
891
+ inputs=inputs,
892
+ outputs=outputs,
893
+ adj_inputs=grad_in,
894
+ adj_outputs=grad_out,
895
+ adjoint=True,
896
+ )
897
+
898
+ # Build the backward wrapper signature expected by jax_callable
899
+ bwd_input_params = parameters[:num_inputs]
900
+ bwd_output_params = parameters[num_inputs:parameter_count]
901
+ bwd_grad_output_params = [
902
+ inspect.Parameter(
903
+ f"adj_{p.name}",
904
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
905
+ default=p.default,
906
+ annotation=p.annotation,
907
+ )
908
+ for p in bwd_output_params
909
+ ]
910
+
911
+ bwd_grad_input_params = [
912
+ inspect.Parameter(
913
+ f"adj_{p.name}",
914
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
915
+ default=p.default,
916
+ annotation=p.annotation,
917
+ )
918
+ for p in bwd_input_params
919
+ ]
920
+ for i in reversed(static_args):
921
+ del bwd_grad_input_params[i]
922
+
923
+ # update backward signature and annotations so jax_callable() sees a fully annotated function
924
+ bwd_signature = bwd_input_params + bwd_output_params + bwd_grad_output_params + bwd_grad_input_params
925
+ bwd_kernel_wrapper.__signature__ = inspect.Signature(bwd_signature)
926
+ bwd_annotations = {}
927
+ for p in bwd_input_params:
928
+ bwd_annotations[p.name] = p.annotation
929
+ for p in bwd_output_params:
930
+ bwd_annotations[p.name] = p.annotation
931
+ for p in bwd_grad_output_params:
932
+ bwd_annotations[p.name] = p.annotation
933
+ for p in bwd_grad_input_params:
934
+ bwd_annotations[p.name] = p.annotation
935
+ bwd_annotations["return"] = None
936
+ bwd_kernel_wrapper.__annotations__ = bwd_annotations
937
+
938
+ jax_bwd_kernel = jax_callable(
939
+ bwd_kernel_wrapper,
940
+ num_outputs=len(bwd_input_params) - len(static_args),
941
+ vmap_method=vmap_method,
942
+ )
943
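# Editor's illustration of the generated backward signature, using the hypothetical
# ``scale(x, alpha, y)`` kernel above with ``alpha`` static:
#
#   bwd_kernel_wrapper(x, alpha, y,   # primal inputs and outputs (residuals)
#                      adj_y,         # cotangent of the output
#                      adj_x)         # gradient buffer allocated by the FFI layer
#
# i.e. bwd_arg_count = 2 * 3 - 1 = 5; the static ``alpha`` is re-inserted into the adjoint
# inputs before the adjoint launch, and adj_x is zeroed first.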
+
944
+ differentiable_input_indices = [i for i in range(num_inputs) if i not in static_args]
945
+ differentiable_input_names = [parameters[i].name for i in differentiable_input_indices]
946
+
947
+ def fwd_function(*args):
948
+ outputs = jax_fwd_kernel(*args)
949
+ non_static_inputs = list(args)
950
+ for i in reversed(static_args):
951
+ del non_static_inputs[i]
952
+ # Normalize to tuple for consistent handling
953
+ if num_outputs == 1:
954
+ outputs_tuple = (outputs,) if not isinstance(outputs, (list, tuple)) else (outputs[0],)
955
+ else:
956
+ outputs_tuple = outputs if isinstance(outputs, tuple) else tuple(outputs)
957
+ return outputs, (tuple(non_static_inputs), outputs_tuple)
958
+
959
+ def bwd_function(*bwd_args):
960
+ nondiff_vals = list(bwd_args[: len(static_args)])
961
+ residuals = bwd_args[len(static_args)]
962
+ grad_out_args = bwd_args[len(static_args) + 1 :]
963
+
964
+ non_static_inputs, output_vals_tuple = residuals
965
+
966
+ input_vals = list(non_static_inputs)
967
+ for i, v in zip(static_args, nondiff_vals):
968
+ input_vals.insert(i, v)
969
+
970
+ # Normalize grad outputs and handle nested containers (e.g., single tuple for multi-output)
971
+ if num_outputs == 1:
972
+ go = grad_out_args[0]
973
+ grad_out_tuple = tuple(go) if isinstance(go, (list, tuple)) else (go,)
974
+ else:
975
+ if len(grad_out_args) == 1 and isinstance(grad_out_args[0], (list, tuple)):
976
+ grad_out_tuple = tuple(grad_out_args[0])
977
+ else:
978
+ grad_out_tuple = tuple(grad_out_args)
979
+ bwd_call_args = list(input_vals) + list(output_vals_tuple) + list(grad_out_tuple)
980
+
981
+ out_dims_map = {}
982
+ param_ann = {p.name: p.annotation for p in parameters[:num_inputs]}
983
+ for name, val in zip(differentiable_input_names, non_static_inputs):
984
+ ann = param_ann.get(name)
985
+ if ann is None:
986
+ continue
987
+ # Check if annotation is a warp array type (annotation is an instance of wp.array)
988
+ is_array_ann = isinstance(ann, wp.array)
989
+ if not is_array_ann:
990
+ continue
991
+ dtype_ndim = 0
992
+ # Extract dtype ndim if it's a vector/matrix type
993
+ if hasattr(ann, "dtype") and hasattr(ann.dtype, "_wp_scalar_type_"):
994
+ dtype_ndim = len(ann.dtype._shape_)
995
+ warp_ndim = getattr(ann, "ndim", 0)
996
+ vshape = tuple(val.shape)
997
+ if warp_ndim == 0:
998
+ continue
999
+ if dtype_ndim > 0:
1000
+ core_rank = max(0, len(vshape) - dtype_ndim)
1001
+ warp_dims = vshape[max(0, core_rank - warp_ndim) : core_rank]
1002
+ else:
1003
+ warp_dims = vshape[-warp_ndim:]
1004
+ out_dims_map[f"adj_{name}"] = tuple(warp_dims)
1005
+
1006
+ non_static_input_grads = jax_bwd_kernel(*bwd_call_args, output_dims=out_dims_map)
1007
+ # custom_vjp expects a tuple of cotangents; a single gradient may come back as a bare array
+ if isinstance(non_static_input_grads, (list, tuple)):
+ return tuple(non_static_input_grads)
+ return (non_static_input_grads,)
1008
+
1009
+ jax_func = jax.custom_vjp(jax_fwd_kernel, nondiff_argnums=tuple(static_args))
1010
+ jax_func.defvjp(fwd_function, bwd_function)
1011
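# End-to-end differentiation sketch (editor's example, hypothetical kernel and values).
# ``alpha`` is static, so it is passed by keyword to the jitted wrapper returned below
# (static_argnames=("alpha",)):
#
#   scale_jax = jax_kernel(scale, enable_backward=True)
#
#   def loss(x):
#       return jnp.sum(scale_jax(x, alpha=2.0))
#
#   g = jax.grad(loss)(jnp.ones(8, dtype=jnp.float32))   # expected: 2.0 everywhere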
+
1012
+ if static_args:
1013
+ static_names = [parameters[i].name for i in static_args]
1014
+
1015
+ def _user_callable(*args):
1016
+ return jax_func(*args)
1017
+
1018
+ _user_callable.__signature__ = signature
1019
+
1020
+ # Cache differentiable wrapper
1021
+ key = (kernel.func, kernel.sig, num_outputs, vmap_method, tuple(sorted(static_names)))
1022
+ with _FFI_REGISTRY_LOCK:
1023
+ cached = _FFI_DIFF_KERNEL_REGISTRY.get(key)
1024
+ if cached is None:
1025
+ cached = jax.jit(_user_callable, static_argnames=tuple(static_names))
1026
+ _FFI_DIFF_KERNEL_REGISTRY[key] = cached
1027
+ return _FFI_DIFF_KERNEL_REGISTRY[key]
1028
+
1029
+ # Cache differentiable wrapper (no static args)
1030
+ key = (kernel.func, kernel.sig, num_outputs, vmap_method, ())
1031
+ with _FFI_REGISTRY_LOCK:
1032
+ cached = _FFI_DIFF_KERNEL_REGISTRY.get(key)
1033
+ if cached is None:
1034
+ _FFI_DIFF_KERNEL_REGISTRY[key] = jax_func
1035
+ cached = jax_func
1036
+ return cached
1037
+
1038
+
1039
+ def jax_callable(
1040
+ func: Callable,
1041
+ num_outputs: int = 1,
1042
+ graph_compatible: Optional[bool] = None, # deprecated
1043
+ graph_mode: GraphMode = GraphMode.JAX,
1044
+ vmap_method: Optional[str] = "broadcast_all",
1045
+ output_dims=None,
1046
+ in_out_argnames=None,
1047
+ graph_cache_max: int | None = None,
1048
+ module_preload_mode: ModulePreloadMode = ModulePreloadMode.CURRENT_DEVICE,
1049
+ ):
1050
+ """Create a JAX callback from an annotated Python function.
1051
+
1052
+ The Python function arguments must have type annotations like Warp kernels.
1053
+
1054
+ NOTE: This is an experimental feature under development.
1055
+
1056
+ Args:
1057
+ func: The Python function to call.
1058
+ num_outputs: Specify the number of output arguments if greater than 1.
1059
+ This must include the number of in-out arguments (see ``in_out_argnames``).
1060
+ graph_compatible: Whether the function can be called during CUDA graph capture.
1061
+ This argument is deprecated, use ``graph_mode`` instead.
1062
+ graph_mode: CUDA graph capture mode.
1063
+ ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing JAX capture.
1064
+ ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subgraph,
1065
+ such as when the callable uses conditional graph nodes.
1066
+ ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
1067
+ such as host synchronization.
1068
+ vmap_method: String specifying how the callback transforms under ``vmap()``.
1069
+ This argument can also be specified for individual calls.
1070
+ output_dims: Specify the default dimensions of output arrays.
1071
+ If ``None``, output dimensions are inferred from the launch dimensions.
1072
+ This argument can also be specified for individual calls.
1073
+ in_out_argnames: Names of arguments that are both inputs and outputs (aliased buffers).
1074
+ These must be array arguments that appear before any pure output arguments in the
1075
+ function signature. The number of in-out arguments is included in ``num_outputs``.
1076
+ graph_cache_max: Maximum number of cached graphs captured using ``GraphMode.WARP``.
1077
+ If ``None``, use ``warp.jax_experimental.get_jax_callable_default_graph_cache_max()``.
1078
+ module_preload_mode: Specify the devices where the module should be preloaded.
1079
+
1080
+ Limitations:
1081
+ - All arguments of ``func`` must be contiguous arrays or scalars.
1082
+ - Scalars must be static arguments in JAX.
1083
+ - Input and input-output arguments must precede the output arguments in the ``func`` definition.
1084
+ - There must be at least one output or input-output argument.
1085
+ - Only the CUDA backend is supported.
1086
+ """
1087
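# Illustrative usage (editor's sketch, not part of the shipped module): wrapping an
# annotated Python function that launches the hypothetical ``scale`` kernel; the call-site
# ``output_dims`` names the output argument, as documented above.
#
#   def step(x: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
#       wp.launch(scale, dim=x.shape, inputs=[x, alpha], outputs=[out])
#
#   jax_step = jax_callable(step)
#   out = jax_step(jnp.ones(1024, dtype=jnp.float32), 2.0, output_dims={"out": 1024})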
+
1088
+ check_jax_version()
1089
+
1090
+ if graph_compatible is not None:
1091
+ wp._src.utils.warn(
1092
+ "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
1093
+ DeprecationWarning,
1094
+ stacklevel=3,
1095
+ )
1096
+ if graph_compatible is False:
1097
+ graph_mode = GraphMode.NONE
1098
+
1099
+ if graph_cache_max is None:
1100
+ graph_cache_max = FfiCallable.default_graph_cache_max
1101
+
1102
+ # Note: graph_cache_max is not included in the cache key; it is applied to the cached callable below.
1103
+ key = (
1104
+ func,
1105
+ num_outputs,
1106
+ graph_mode,
1107
+ vmap_method,
1108
+ tuple(sorted(output_dims.items())) if output_dims else output_dims,
1109
+ module_preload_mode,
1110
+ )
1111
+
1112
+ with _FFI_REGISTRY_LOCK:
1113
+ callable = _FFI_CALLABLE_REGISTRY.get(key)
1114
+ if callable is None:
1115
+ callable = FfiCallable(
1116
+ func,
1117
+ num_outputs,
1118
+ graph_mode,
1119
+ vmap_method,
1120
+ output_dims,
1121
+ in_out_argnames,
1122
+ graph_cache_max,
1123
+ module_preload_mode,
1124
+ )
1125
+ _FFI_CALLABLE_REGISTRY[key] = callable
1126
+ else:
1127
+ # make sure we're using the latest graph cache max
1128
+ callable.graph_cache_max = graph_cache_max
1129
+
1130
+ return callable
1131
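# Editor's note: a sketch of choosing a capture mode for callables that cannot run inside a
# CUDA graph (e.g. they synchronize with the host), per the docstring above:
#
#   jax_step = jax_callable(step, graph_mode=GraphMode.NONE)
#
# GraphMode.WARP lets Warp own the capture (needed e.g. for conditional graph nodes); those
# captures are cached, bounded by ``graph_cache_max``.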
+
1132
+
1133
+ def get_jax_callable_default_graph_cache_max():
1134
+ """
1135
+ Get the default maximum number of cached graphs captured using ``GraphMode.WARP``; ``None`` means unlimited.
1136
+ """
1137
+ return FfiCallable.default_graph_cache_max
1138
+
1139
+
1140
+ def set_jax_callable_default_graph_cache_max(cache_max: int | None):
1141
+ """
1142
+ Set the default maximum number of cached graphs captured using ``GraphMode.WARP``; ``None`` means unlimited.
1143
+ """
1144
+ FfiCallable.default_graph_cache_max = cache_max
1145
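# Usage sketch for the two accessors above (editor's example):
#
#   set_jax_callable_default_graph_cache_max(8)
#   assert get_jax_callable_default_graph_cache_max() == 8
#   set_jax_callable_default_graph_cache_max(None)   # unlimited again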
+
1146
+
1147
+ def clear_jax_callable_graph_cache(callable: FfiCallable | None = None):
1148
+ """Clear the graph cache of the given callable or all callables if ``None``."""
1149
+
1150
+ if callable is not None:
1151
+ callable.captures.clear()
1152
+ else:
1153
+ # apply to all callables
1154
+ with _FFI_REGISTRY_LOCK:
1155
+ for callable in _FFI_CALLABLE_REGISTRY.values():
1156
+ callable.captures.clear()
1157
+
1158
+
1159
+ ###############################################################################
1160
+ #
1161
+ # Generic FFI callbacks for Python functions of the form
1162
+ # func(inputs, outputs, attrs, ctx)
1163
+ #
1164
+ ###############################################################################
1165
+
1166
+
1167
+ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
1168
+ """Create a JAX callback from a Python function.
1169
+
1170
+ The Python function must have the form ``func(inputs, outputs, attrs, ctx)``.
1171
+
1172
+ NOTE: This is an experimental feature under development.
1173
+
1174
+ Args:
1175
+ name: A unique FFI callback name.
1176
+ func: The Python function to call.
1177
+ graph_compatible: Whether the function can be called during CUDA graph capture.
1178
+ """
1179
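# Illustrative usage (editor's sketch; the callback name, body, and ``x`` are hypothetical,
# and a recent JAX with the ``jax.ffi`` module is assumed): register a raw callback and
# invoke it through jax.ffi.ffi_call.
#
#   def double_it(inputs, outputs, attrs, ctx):
#       ...  # inputs/outputs are FfiBuffer views; do Warp work on the XLA stream via ctx
#
#   register_ffi_callback("double_it", double_it)
#
#   out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)   # x: a CUDA-resident JAX array
#   y = jax.ffi.ffi_call("double_it", out_type)(x)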
+
1180
+ check_jax_version()
1181
+
1182
+ # TODO check that the name is not already registered
1183
+
1184
+ def ffi_callback(call_frame):
1185
+ try:
1186
+ extension = call_frame.contents.extension_start
1187
+ # On the first call, the XLA runtime will query the API version and traits
1188
+ # metadata using the |extension| field. Let us respond to that query
1189
+ # if the metadata extension is present.
1190
+ if extension:
1191
+ # Try to set the version metadata.
1192
+ if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
1193
+ metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
1194
+ metadata_ext.contents.metadata.contents.api_version.major_version = 0
1195
+ metadata_ext.contents.metadata.contents.api_version.minor_version = 1
1196
+ if graph_compatible:
1197
+ # Turn on CUDA graphs for this handler.
1198
+ metadata_ext.contents.metadata.contents.traits = (
1199
+ XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
1200
+ )
1201
+ return None
1202
+
1203
+ # Lock is required to prevent race conditions when callback is invoked
1204
+ # from multiple threads, like with pmap.
1205
+ with _FFI_CALLBACK_LOCK:
1206
+ attrs = decode_attrs(call_frame.contents.attrs)
1207
+
1208
+ input_count = call_frame.contents.args.size
1209
+ inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
1210
+ inputs = [FfiBuffer(inputs[i].contents) for i in range(input_count)]
1211
+
1212
+ output_count = call_frame.contents.rets.size
1213
+ outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
1214
+ outputs = [FfiBuffer(outputs[i].contents) for i in range(output_count)]
1215
+
1216
+ ctx = ExecutionContext(call_frame.contents)
1217
+
1218
+ func(inputs, outputs, attrs, ctx)
1219
+
1220
+ except Exception as e:
1221
+ print(traceback.format_exc())
1222
+ return create_ffi_error(
1223
+ call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
1224
+ )
1225
+
1226
+ return None
1227
+
1228
+ FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
1229
+ callback_func = FFI_CCALLFUNC(ffi_callback)
1230
+ with _FFI_REGISTRY_LOCK:
1231
+ _FFI_CALLBACK_REGISTRY[name] = callback_func
1232
+ ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
1233
+ ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
1234
+ jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
1235
+
1236
+
1237
+ ###############################################################################
1238
+ #
1239
+ # Utilities
1240
+ #
1241
+ ###############################################################################
1242
+
1243
+ # ensure unique FFI callback names
1244
+ ffi_name_counts = {}
1245
+
1246
+
1247
+ def generate_unique_name(func) -> str:
1248
+ key = make_full_qualified_name(func)
1249
+ unique_id = ffi_name_counts.get(key, 0)
1250
+ ffi_name_counts[key] = unique_id + 1
1251
+ return f"{key}_{unique_id}"
1252
+
1253
+
1254
+ def get_warp_shape(arg, dims):
1255
+ if arg.dtype_ndim > 0:
1256
+ # vector/matrix array
1257
+ return dims[: arg.warp_ndim]
1258
+ else:
1259
+ # scalar array
1260
+ return dims
1261
+
1262
+
1263
+ def get_jax_output_type(arg, dims):
1264
+ if isinstance(dims, int):
1265
+ dims = (dims,)
1266
+
1267
+ ndim = len(dims)
1268
+
1269
+ if arg.dtype_ndim > 0:
1270
+ # vector/matrix array
1271
+ if ndim == arg.warp_ndim:
1272
+ return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)
1273
+ elif ndim == arg.jax_ndim:
1274
+ # make sure inner dimensions match
1275
+ inner_dims = dims[-arg.dtype_ndim :]
1276
+ for i in range(arg.dtype_ndim):
1277
+ if inner_dims[i] != arg.dtype_shape[i]:
1278
+ raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
1279
+ return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
1280
+ else:
1281
+ raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
1282
+ else:
1283
+ # scalar array
1284
+ if ndim != arg.warp_ndim:
1285
+ raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
1286
+ return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
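# Worked example for the two helpers above (editor's note, hypothetical argument): for an
# output annotated wp.array(dtype=wp.mat33), dtype_shape is (3, 3) and dtype_ndim is 2, so
# dims=(16,) yields a JAX result of shape (16, 3, 3), while dims=(16, 3, 3) is accepted
# as-is after the inner (3, 3) dimensions are validated.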