warp-lang 1.6.1-py3-none-macosx_10_13_universal2.whl → 1.7.0-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/jax_experimental/ffi.py (new file)
@@ -0,0 +1,698 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import ctypes
+import traceback
+from typing import Callable
+
+import jax
+
+import warp as wp
+from warp.codegen import get_full_arg_spec, make_full_qualified_name
+from warp.jax import get_jax_device
+from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
+
+from .xla_ffi import *
+
+
+def jax_kernel(kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None):
+    """Create a JAX callback from a Warp kernel.
+
+    NOTE: This is an experimental feature under development.
+
+    Args:
+        kernel: The Warp kernel to launch.
+        num_outputs: Optional. Specify the number of output arguments if greater than 1.
+        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
+            This argument can also be specified for individual calls.
+        launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
+            dimensions are inferred from the shape of the first array argument.
+            This argument can also be specified for individual calls.
+        output_dims: Optional. Specify the default dimensions of output arrays. If None, output
+            dimensions are inferred from the launch dimensions.
+            This argument can also be specified for individual calls.
+
+    Limitations:
+        - All kernel arguments must be contiguous arrays or scalars.
+        - Scalars must be static arguments in JAX.
+        - Input arguments are followed by output arguments in the Warp kernel definition.
+        - There must be at least one output argument.
+        - Only the CUDA backend is supported.
+    """
+
+    return FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims)
+
+
+def jax_callable(
+    func: Callable,
+    num_outputs: int = 1,
+    graph_compatible: bool = True,
+    vmap_method: str = "broadcast_all",
+    output_dims=None,
+):
+    """Create a JAX callback from an annotated Python function.
+
+    The Python function arguments must have type annotations like Warp kernels.
+
+    NOTE: This is an experimental feature under development.
+
+    Args:
+        func: The Python function to call.
+        num_outputs: Optional. Specify the number of output arguments if greater than 1.
+        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+        vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
+            This argument can also be specified for individual calls.
+        output_dims: Optional. Specify the default dimensions of output arrays.
+            If ``None``, output dimensions are inferred from the launch dimensions.
+            This argument can also be specified for individual calls.
+
+    Limitations:
+        - All kernel arguments must be contiguous arrays or scalars.
+        - Scalars must be static arguments in JAX.
+        - Input arguments are followed by output arguments in the Warp kernel definition.
+        - There must be at least one output argument.
+        - Only the CUDA backend is supported.
+    """
+
+    return FfiCallable(func, num_outputs, graph_compatible, vmap_method, output_dims)
+
+
+class FfiArg:
+    def __init__(self, name, type):
+        self.name = name
+        self.type = type
+        self.is_array = isinstance(type, wp.array)
+
+        if self.is_array:
+            if hasattr(type.dtype, "_wp_scalar_type_"):
+                self.dtype_shape = type.dtype._shape_
+                self.dtype_ndim = len(self.dtype_shape)
+                self.jax_scalar_type = wp.dtype_to_jax(type.dtype._wp_scalar_type_)
+                self.jax_ndim = type.ndim + self.dtype_ndim
+            elif type.dtype in wp.types.value_types:
+                self.dtype_ndim = 0
+                self.dtype_shape = ()
+                self.jax_scalar_type = wp.dtype_to_jax(type.dtype)
+                self.jax_ndim = type.ndim
+            else:
+                raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")
+            self.warp_ndim = type.ndim
+        elif type in wp.types.value_types:
+            self.dtype_ndim = 0
+            self.dtype_shape = ()
+            self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
+            self.jax_ndim = 0
+            self.warp_ndim = 0
+        else:
+            raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
+
+
+class FfiLaunchDesc:
+    def __init__(self, static_inputs, launch_dims):
+        self.static_inputs = static_inputs
+        self.launch_dims = launch_dims
+
+
+class FfiKernel:
+    def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims):
+        self.kernel = kernel
+        self.name = generate_unique_name(kernel.func)
+        self.num_outputs = num_outputs
+        self.vmap_method = vmap_method
+        self.launch_dims = launch_dims
+        self.output_dims = output_dims
+        self.first_array_arg = None
+        self.launch_id = 0
+        self.launch_descriptors = {}
+
+        self.num_kernel_args = len(kernel.adj.args)
+        self.num_inputs = self.num_kernel_args - num_outputs
+        if self.num_outputs < 1:
+            raise ValueError("At least one output is required")
+        if self.num_outputs > self.num_kernel_args:
+            raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+
+        # process input args
+        self.input_args = []
+        for i in range(self.num_inputs):
+            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            if arg.is_array:
+                # keep track of the first input array argument
+                if self.first_array_arg is None:
+                    self.first_array_arg = i
+            self.input_args.append(arg)
+
+        # process output args
+        self.output_args = []
+        for i in range(self.num_inputs, self.num_kernel_args):
+            arg = FfiArg(kernel.adj.args[i].label, kernel.adj.args[i].type)
+            if not arg.is_array:
+                raise TypeError("All output arguments must be arrays")
+            self.output_args.append(arg)
+
+        # register the callback
+        FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
+        self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
+        ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
+        ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
+        jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
+
+    def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
+        num_inputs = len(args)
+        if num_inputs != self.num_inputs:
+            raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
+
+        # default argument fallback
+        if launch_dims is None:
+            launch_dims = self.launch_dims
+        if output_dims is None:
+            output_dims = self.output_dims
+        if vmap_method is None:
+            vmap_method = self.vmap_method
+
+        # process inputs
+        static_inputs = {}
+        for i in range(num_inputs):
+            input_arg = self.input_args[i]
+            input_value = args[i]
+            if input_arg.is_array:
+                # check dtype
+                if input_value.dtype != input_arg.jax_scalar_type:
+                    raise TypeError(
+                        f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
+                    )
+                # check ndim
+                if input_value.ndim != input_arg.jax_ndim:
+                    raise TypeError(
+                        f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
+                    )
+                # check inner dims
+                for d in range(input_arg.dtype_ndim):
+                    if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
+                        raise TypeError(
+                            f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
                        )
+            else:
+                # make sure scalar is not a traced variable, should be static
+                if isinstance(input_value, jax.core.Tracer):
+                    raise ValueError(f"Argument '{input_arg.name}' must be a static value")
+                # stash the value to be retrieved by callback
+                static_inputs[input_arg.name] = input_arg.type(input_value)
+
+        # launch dimensions
+        if launch_dims is None:
+            # use the shape of the first input array
+            if self.first_array_arg is not None:
+                launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
+            else:
+                raise RuntimeError("Failed to determine launch dimensions")
+        elif isinstance(launch_dims, int):
+            launch_dims = (launch_dims,)
+        else:
+            launch_dims = tuple(launch_dims)
+
+        # output types
+        out_types = []
+        if isinstance(output_dims, dict):
+            # assume a dictionary of shapes keyed on argument name
+            for output_arg in self.output_args:
+                dims = output_dims.get(output_arg.name)
+                if dims is None:
+                    raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
+                out_types.append(get_jax_output_type(output_arg, dims))
+        else:
+            if output_dims is None:
+                # use launch dimensions
+                output_dims = launch_dims
+            elif isinstance(output_dims, int):
+                output_dims = (output_dims,)
+            # assume same dimensions for all outputs
+            for output_arg in self.output_args:
+                out_types.append(get_jax_output_type(output_arg, output_dims))
+
+        call = jax.ffi.ffi_call(
+            self.name,
+            out_types,
+            vmap_method=vmap_method,
+        )
+
+        # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
+        device = wp.device_from_jax(get_jax_device())
+        self.kernel.module.load(device)
+
+        # save launch data to be retrieved by callback
+        launch_id = self.launch_id
+        self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
+        self.launch_id += 1
+
+        return call(*args, launch_id=launch_id)
+
+    def ffi_callback(self, call_frame):
+        try:
+            # On the first call, XLA runtime will query the API version and traits
+            # metadata using the |extension| field. Let us respond to that query
+            # if the metadata extension is present.
+            extension = call_frame.contents.extension_start
+            if extension:
+                # Try to set the version metadata.
+                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
+                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
+                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
+                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
+                    # Turn on CUDA graphs for this handler.
+                    metadata_ext.contents.metadata.contents.traits = (
+                        XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
+                    )
+                return None
+
+            # retrieve call info
+            attrs = decode_attrs(call_frame.contents.attrs)
+            launch_id = int(attrs["launch_id"])
+            launch_desc = self.launch_descriptors[launch_id]
+
+            num_inputs = call_frame.contents.args.size
+            inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+
+            num_outputs = call_frame.contents.rets.size
+            outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+
+            assert num_inputs == self.num_inputs
+            assert num_outputs == self.num_outputs
+
+            launch_bounds = launch_bounds_t(launch_desc.launch_dims)
+
+            # first kernel param is the launch bounds
+            kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
+            kernel_params[0] = ctypes.addressof(launch_bounds)
+
+            arg_refs = []
+
+            # inputs
+            for i in range(num_inputs):
+                input_arg = self.input_args[i]
+                if input_arg.is_array:
+                    buffer = inputs[i].contents
+                    shape = buffer.dims[: input_arg.type.ndim]
+                    strides = strides_from_shape(shape, input_arg.type.dtype)
+                    arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
+                    kernel_params[i + 1] = ctypes.addressof(arg)
+                    arg_refs.append(arg)  # keep a reference
+                else:
+                    # scalar argument, get stashed value
+                    value = launch_desc.static_inputs[input_arg.name]
+                    arg = input_arg.type._type_(value)
+                    kernel_params[i + 1] = ctypes.addressof(arg)
+                    arg_refs.append(arg)  # keep a reference
+
+            # outputs
+            for i in range(num_outputs):
+                output_arg = self.output_args[i]
+                buffer = outputs[i].contents
+                shape = buffer.dims[: output_arg.type.ndim]
+                strides = strides_from_shape(shape, output_arg.type.dtype)
+                arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
+                kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
+                arg_refs.append(arg)  # keep a reference
+
+            # get device and stream
+            device = wp.device_from_jax(get_jax_device())
+            stream = get_stream_from_callframe(call_frame.contents)
+
+            # get kernel hooks
+            hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
+            assert hooks.forward, "Failed to find kernel entry point"
+
+            # launch the kernel
+            wp.context.runtime.core.cuda_launch_kernel(
+                device.context,
+                hooks.forward,
+                launch_bounds.size,
+                0,
+                256,
+                hooks.forward_smem_bytes,
+                kernel_params,
+                stream,
+            )
+
+        except Exception as e:
+            print(traceback.format_exc())
+            return create_ffi_error(
+                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
+            )
+
+
+class FfiCallDesc:
+    def __init__(self, static_inputs):
+        self.static_inputs = static_inputs
+
+
+class FfiCallable:
+    def __init__(self, func, num_outputs, graph_compatible, vmap_method, output_dims):
+        self.func = func
+        self.name = generate_unique_name(func)
+        self.num_outputs = num_outputs
+        self.vmap_method = vmap_method
+        self.graph_compatible = graph_compatible
+        self.output_dims = output_dims
+        self.first_array_arg = None
+        self.has_static_args = False
+        self.call_id = 0
+        self.call_descriptors = {}
+
+        # get arguments and annotations
+        argspec = get_full_arg_spec(func)
+
+        num_args = len(argspec.args)
+        self.num_inputs = num_args - num_outputs
+        if self.num_outputs < 1:
+            raise ValueError("At least one output is required")
+        if self.num_outputs > num_args:
+            raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
+
+        if len(argspec.annotations) < num_args:
+            raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
+
+        # parse type annotations
+        self.args = []
+        arg_idx = 0
+        for arg_name, arg_type in argspec.annotations.items():
+            if arg_name == "return":
+                if arg_type is not None:
+                    raise TypeError("Function must not return a value")
+            else:
+                arg = FfiArg(arg_name, arg_type)
+                if arg.is_array:
+                    if arg_idx < self.num_inputs and self.first_array_arg is None:
+                        self.first_array_arg = arg_idx
+                else:
+                    self.has_static_args = True
+                self.args.append(arg)
+                arg_idx += 1
+
+        self.input_args = self.args[: self.num_inputs]
+        self.output_args = self.args[self.num_inputs :]
+
+        # register the callback
+        FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
+        self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
+        ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
+        ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
+        jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
+
+    def __call__(self, *args, output_dims=None, vmap_method=None):
+        num_inputs = len(args)
+        if num_inputs != self.num_inputs:
+            raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
+
+        # default argument fallback
+        if vmap_method is None:
+            vmap_method = self.vmap_method
+        if output_dims is None:
+            output_dims = self.output_dims
+
+        # process inputs
+        static_inputs = {}
+        for i in range(num_inputs):
+            input_arg = self.input_args[i]
+            input_value = args[i]
+            if input_arg.is_array:
+                # check dtype
+                if input_value.dtype != input_arg.jax_scalar_type:
+                    raise TypeError(
+                        f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
+                    )
+                # check ndim
+                if input_value.ndim != input_arg.jax_ndim:
+                    raise TypeError(
+                        f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
+                    )
+                # check inner dims
+                for d in range(input_arg.dtype_ndim):
+                    if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
+                        raise TypeError(
+                            f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
+                        )
+            else:
+                # make sure scalar is not a traced variable, should be static
+                if isinstance(input_value, jax.core.Tracer):
+                    raise ValueError(f"Argument '{input_arg.name}' must be a static value")
+                # stash the value to be retrieved by callback
+                static_inputs[input_arg.name] = input_arg.type(input_value)
+
+        if output_dims is None and self.first_array_arg is not None:
+            # use the shape of the first input array
+            output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
+
+        # output types
+        out_types = []
+        if isinstance(output_dims, dict):
+            # assume a dictionary of shapes keyed on argument name
+            for output_arg in self.output_args:
+                dims = output_dims.get(output_arg.name)
+                if dims is None:
+                    raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
+                out_types.append(get_jax_output_type(output_arg, dims))
+        else:
+            if output_dims is None:
+                raise ValueError("Unable to determine output dimensions")
+            elif isinstance(output_dims, int):
+                output_dims = (output_dims,)
+            # assume same dimensions for all outputs
+            for output_arg in self.output_args:
+                out_types.append(get_jax_output_type(output_arg, output_dims))
+
+        call = jax.ffi.ffi_call(
+            self.name,
+            out_types,
+            vmap_method=vmap_method,
+            # has_side_effect=True,  # force this function to execute even if outputs aren't used
+        )
+
+        # load the module
+        # NOTE: if the target function uses kernels from different modules, they will not be loaded here
+        device = wp.device_from_jax(get_jax_device())
+        module = wp.get_module(self.func.__module__)
+        module.load(device)
+
+        if self.has_static_args:
+            # save call data to be retrieved by callback
+            call_id = self.call_id
+            self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
+            self.call_id += 1
+            return call(*args, call_id=call_id)
+        else:
+            return call(*args)
+
+    def ffi_callback(self, call_frame):
+        try:
+            # TODO Try-catch around the body and return XLA_FFI_Error on error.
+            extension = call_frame.contents.extension_start
+            # On the first call, XLA runtime will query the API version and traits
+            # metadata using the |extension| field. Let us respond to that query
+            # if the metadata extension is present.
+            if extension:
+                # Try to set the version metadata.
+                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
+                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
+                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
+                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
+                    # Turn on CUDA graphs for this handler.
+                    if self.graph_compatible:
+                        metadata_ext.contents.metadata.contents.traits = (
+                            XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
+                        )
+                return None
+
+            if self.has_static_args:
+                # retrieve call info
+                attrs = decode_attrs(call_frame.contents.attrs)
+                call_id = int(attrs["call_id"])
+                call_desc = self.call_descriptors[call_id]
+
+            num_inputs = call_frame.contents.args.size
+            inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+
+            num_outputs = call_frame.contents.rets.size
+            outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+
+            assert num_inputs == self.num_inputs
+            assert num_outputs == self.num_outputs
+
+            device = wp.device_from_jax(get_jax_device())
+            cuda_stream = get_stream_from_callframe(call_frame.contents)
+            stream = wp.Stream(device, cuda_stream=cuda_stream)
+
+            # reconstruct the argument list
+            arg_list = []
+
+            # inputs
+            for i in range(num_inputs):
+                arg = self.input_args[i]
+                if arg.is_array:
+                    buffer = inputs[i].contents
+                    shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
+                    arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
+                    arg_list.append(arr)
+                else:
+                    # scalar argument, get stashed value
+                    value = call_desc.static_inputs[arg.name]
+                    arg_list.append(value)
+
+            # outputs
+            for i in range(num_outputs):
+                arg = self.output_args[i]
+                buffer = outputs[i].contents
+                shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
+                arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
+                arg_list.append(arr)
+
+            # call the Python function with reconstructed arguments
+            with wp.ScopedStream(stream, sync_enter=False):
+                self.func(*arg_list)
+
+        except Exception as e:
+            print(traceback.format_exc())
+            return create_ffi_error(
+                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
+            )
+
+        return None
+
+
+###############################################################################
+#
+# Generic FFI callbacks for Python functions of the form
+# func(inputs, outputs, attrs, ctx)
+#
+###############################################################################
+
+# Holder for the custom callbacks to keep them alive.
+ffi_callbacks = {}
+
+
+def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
+    """Create a JAX callback from a Python function.
+
+    The Python function must have the form ``func(inputs, outputs, attrs, ctx)``.
+
+    NOTE: This is an experimental feature under development.
+
+    Args:
+        name: A unique FFI callback name.
+        func: The Python function to call.
+        graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
+    """
+
+    # TODO check that the name is not already registered
+
+    def ffi_callback(call_frame):
+        try:
+            # TODO Try-catch around the body and return XLA_FFI_Error on error.
+            extension = call_frame.contents.extension_start
+            # On the first call, XLA runtime will query the API version and traits
+            # metadata using the |extension| field. Let us respond to that query
+            # if the metadata extension is present.
+            if extension:
+                # Try to set the version metadata.
+                if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
+                    metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
+                    metadata_ext.contents.metadata.contents.api_version.major_version = 0
+                    metadata_ext.contents.metadata.contents.api_version.minor_version = 1
+                    if graph_compatible:
+                        # Turn on CUDA graphs for this handler.
+                        metadata_ext.contents.metadata.contents.traits = (
+                            XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
+                        )
+                return None
+
+            attrs = decode_attrs(call_frame.contents.attrs)
+
+            input_count = call_frame.contents.args.size
+            inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+            inputs = [FfiBuffer(inputs[i].contents) for i in range(input_count)]
+
+            output_count = call_frame.contents.rets.size
+            outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+            outputs = [FfiBuffer(outputs[i].contents) for i in range(output_count)]
+
+            ctx = ExecutionContext(call_frame.contents)
+
+            func(inputs, outputs, attrs, ctx)
+        except Exception as e:
+            print(traceback.format_exc())
+            return create_ffi_error(
+                call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
+            )
+
+        return None
+
+    FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
+    callback_func = FFI_CCALLFUNC(ffi_callback)
+    ffi_callbacks[name] = callback_func
+    ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
+    ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
+    jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
+
+
+###############################################################################
+#
+# Utilities
+#
+###############################################################################
+
+# ensure unique FFI callback names
+ffi_name_counts = {}
+
+
+def generate_unique_name(func) -> str:
+    key = make_full_qualified_name(func)
+    unique_id = ffi_name_counts.get(key, 0)
+    ffi_name_counts[key] = unique_id + 1
+    return f"{key}_{unique_id}"
+
+
+def get_warp_shape(arg, dims):
+    if arg.dtype_ndim > 0:
+        # vector/matrix array
+        return dims[: arg.warp_ndim]
+    else:
+        # scalar array
+        return dims
+
+
+def get_jax_output_type(arg, dims):
+    if isinstance(dims, int):
+        dims = (dims,)
+
+    ndim = len(dims)
+
+    if arg.dtype_ndim > 0:
+        # vector/matrix array
+        if ndim == arg.warp_ndim:
+            return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)
+        elif ndim == arg.jax_ndim:
+            # make sure inner dimensions match
+            inner_dims = dims[-arg.dtype_ndim :]
+            for i in range(arg.dtype_ndim):
+                if inner_dims[i] != arg.dtype_shape[i]:
+                    raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
+            return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
+        else:
+            raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
+    else:
+        # scalar array
+        if ndim != arg.warp_ndim:
+            raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
+        return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)