warp-lang 1.9.0-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl

This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
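
Most of this release is a source reorganization: implementation modules move into a private warp/_src/ package while the original public modules shrink to thin wrappers (for example, warp/fem/linalg.py drops from several hundred lines to a small stub), the warp.sim package together with its examples and tests is removed, and the JAX interop layer under warp/_src/jax_experimental/ is rewritten around XLA's FFI. A minimal sketch of what the layout change means for user code, assuming the public modules keep re-exporting the same names (the wrapper contents are not shown in this excerpt):

    import warp as wp

    # public entry points are expected to keep working unchanged
    from warp.jax_experimental.ffi import jax_kernel, jax_callable

    # the implementation now lives in a private package; importing it directly
    # relies on internal layout that may change between releases
    from warp._src.jax_experimental import ffi as _ffi_impl

The remainder of this page excerpts one of the new files from that reorganization.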
warp/_src/jax_experimental/ffi.py (new file)
@@ -0,0 +1,1284 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import collections
17
+ import ctypes
18
+ import inspect
19
+ import threading
20
+ import traceback
21
+ from enum import IntEnum
22
+ from typing import Callable, Optional
23
+
24
+ import jax
25
+
26
+ import warp as wp
27
+ from warp._src.codegen import get_full_arg_spec, make_full_qualified_name
28
+ from warp._src.jax import get_jax_device
29
+ from warp._src.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
30
+
31
+ from .xla_ffi import *
32
+
33
+ # Type alias for differentiable kernel cache key
34
+ DiffKernelCacheKey = tuple[Callable, tuple, int, str, tuple[str, ...]]
35
+
36
+ # Holders for the custom callbacks to keep them alive.
37
+ _FFI_KERNEL_REGISTRY: dict[str, "FfiKernel"] = {}
38
+ _FFI_DIFF_KERNEL_REGISTRY: dict[DiffKernelCacheKey, Callable] = {}
39
+ _FFI_CALLABLE_REGISTRY: dict[str, "FfiCallable"] = {}
40
+ _FFI_CALLBACK_REGISTRY: dict[str, ctypes.CFUNCTYPE] = {}
41
+ _FFI_REGISTRY_LOCK = threading.Lock()
42
+
43
+ # Lock when XLA invokes callbacks from multiple threads.
44
+ _FFI_CALLBACK_LOCK = threading.Lock()
45
+
46
+
47
+ def check_jax_version():
48
+ # check if JAX version supports this
49
+ if jax.__version_info__ < (0, 5, 0):
50
+ msg = (
51
+ "This version of jax_kernel() requires JAX version 0.5.0 or higher, "
52
+ f"but installed JAX version is {jax.__version_info__}."
53
+ )
54
+ if jax.__version_info__ >= (0, 4, 25):
55
+ msg += " Please use warp.jax_experimental.custom_call.jax_kernel instead."
56
+ raise RuntimeError(msg)
57
+
58
+
59
+ class GraphMode(IntEnum):
60
+ NONE = 0 # don't capture a graph
61
+ JAX = 1 # let JAX capture a graph
62
+ WARP = 2 # let Warp capture a graph
63
+
64
+
65
+ class ModulePreloadMode(IntEnum):
66
+ NONE = 0 # don't preload modules
67
+ CURRENT_DEVICE = 1 # preload on currently active device
68
+ ALL_DEVICES = 2 # preload on all supported devices
69
+
70
+
71
+ class FfiArg:
72
+ def __init__(self, name, type, in_out=False):
73
+ self.name = name
74
+ self.type = type
75
+ self.in_out = in_out
76
+ self.is_array = isinstance(type, wp.array)
77
+
78
+ if self.is_array:
79
+ if hasattr(type.dtype, "_wp_scalar_type_"):
80
+ self.dtype_shape = type.dtype._shape_
81
+ self.dtype_ndim = len(self.dtype_shape)
82
+ self.jax_scalar_type = wp.dtype_to_jax(type.dtype._wp_scalar_type_)
83
+ self.jax_ndim = type.ndim + self.dtype_ndim
84
+ elif type.dtype in wp._src.types.value_types:
85
+ self.dtype_ndim = 0
86
+ self.dtype_shape = ()
87
+ self.jax_scalar_type = wp.dtype_to_jax(type.dtype)
88
+ self.jax_ndim = type.ndim
89
+ else:
90
+ raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")
91
+ self.warp_ndim = type.ndim
92
+ elif type in wp._src.types.value_types:
93
+ self.dtype_ndim = 0
94
+ self.dtype_shape = ()
95
+ self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
96
+ self.jax_ndim = 0
97
+ self.warp_ndim = 0
98
+ else:
99
+ raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
100
+
101
+
102
+ class FfiLaunchDesc:
103
+ def __init__(self, static_inputs, launch_dims):
104
+ self.static_inputs = static_inputs
105
+ self.launch_dims = launch_dims
106
+
107
+
108
+ class FfiKernel:
109
+ def __init__(
110
+ self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode
111
+ ):
112
+ self.kernel = kernel
113
+ self.name = generate_unique_name(kernel.func)
114
+ self.num_outputs = num_outputs
115
+ self.vmap_method = vmap_method
116
+ self.launch_dims = launch_dims
117
+ self.output_dims = output_dims
118
+ self.module_preload_mode = module_preload_mode
119
+ self.first_array_arg = None
120
+ self.launch_id = 0
121
+ self.launch_descriptors = {}
122
+
123
+ in_out_argnames_list = in_out_argnames or []
124
+ in_out_argnames = set(in_out_argnames_list)
125
+ if len(in_out_argnames_list) != len(in_out_argnames):
126
+ raise AssertionError("in_out_argnames must not contain duplicate names")
127
+
128
+ self.num_kernel_args = len(kernel.adj.args)
129
+ self.num_in_out = len(in_out_argnames)
130
+ self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
131
+ if self.num_outputs < 1:
132
+ raise ValueError("At least one output is required")
133
+ if self.num_outputs > self.num_kernel_args:
134
+ raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
135
+ if self.num_outputs < self.num_in_out:
136
+ raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
137
+
138
+ # process input args
139
+ self.input_args = []
140
+ for i in range(self.num_inputs):
141
+ arg_name = kernel.adj.args[i].label
142
+ arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
143
+ if arg_name in in_out_argnames:
144
+ in_out_argnames.remove(arg_name)
145
+ if arg.is_array:
146
+ # keep track of the first input array argument
147
+ if self.first_array_arg is None:
148
+ self.first_array_arg = i
149
+ self.input_args.append(arg)
150
+
151
+ # process output args
152
+ self.output_args = []
153
+ for i in range(self.num_inputs, self.num_kernel_args):
154
+ arg_name = kernel.adj.args[i].label
155
+ if arg_name in in_out_argnames:
156
+ raise AssertionError(
157
+ f"Expected an output-only argument for argument {arg_name}."
158
+ " in_out arguments should be placed before output-only arguments."
159
+ )
160
+ arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
161
+ if not arg.is_array:
162
+ raise TypeError("All output arguments must be arrays")
163
+ self.output_args.append(arg)
164
+
165
+ if in_out_argnames:
166
+ raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
167
+
168
+ # Build input output aliases.
169
+ out_id = 0
170
+ input_output_aliases = {}
171
+ for in_id, arg in enumerate(self.input_args):
172
+ if not arg.in_out:
173
+ continue
174
+ input_output_aliases[in_id] = out_id
175
+ out_id += 1
176
+ self.input_output_aliases = input_output_aliases
177
+
178
+ # register the callback
179
+ FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
180
+ self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
181
+ ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
182
+ ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
183
+ jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
184
+
185
+ def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
186
+ num_inputs = len(args)
187
+ if num_inputs != self.num_inputs:
188
+ raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
189
+
190
+ # default argument fallback
191
+ if launch_dims is None:
192
+ launch_dims = self.launch_dims
193
+ if output_dims is None:
194
+ output_dims = self.output_dims
195
+ if vmap_method is None:
196
+ vmap_method = self.vmap_method
197
+
198
+ # output types
199
+ out_types = []
200
+
201
+ # process inputs
202
+ static_inputs = {}
203
+ for i in range(num_inputs):
204
+ input_arg = self.input_args[i]
205
+ input_value = args[i]
206
+ if input_arg.is_array:
207
+ # check dtype
208
+ if input_value.dtype != input_arg.jax_scalar_type:
209
+ raise TypeError(
210
+ f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
211
+ )
212
+ # check ndim
213
+ if input_value.ndim != input_arg.jax_ndim:
214
+ raise TypeError(
215
+ f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
216
+ )
217
+ # check inner dims
218
+ for d in range(input_arg.dtype_ndim):
219
+ if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
220
+ raise TypeError(
221
+ f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
222
+ )
223
+ else:
224
+ # make sure scalar is not a traced variable, should be static
225
+ if isinstance(input_value, jax.core.Tracer):
226
+ raise ValueError(f"Argument '{input_arg.name}' must be a static value")
227
+ # stash the value to be retrieved by callback
228
+ static_inputs[input_arg.name] = input_arg.type(input_value)
229
+
230
+ # append in-out arg to output types
231
+ if input_arg.in_out:
232
+ out_types.append(get_jax_output_type(input_arg, input_value.shape))
233
+
234
+ # launch dimensions
235
+ if launch_dims is None:
236
+ # use the shape of the first input array
237
+ if self.first_array_arg is not None:
238
+ launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
239
+ else:
240
+ raise RuntimeError("Failed to determine launch dimensions")
241
+ elif isinstance(launch_dims, int):
242
+ launch_dims = (launch_dims,)
243
+ else:
244
+ launch_dims = tuple(launch_dims)
245
+
246
+ # output shapes
247
+ if isinstance(output_dims, dict):
248
+ # assume a dictionary of shapes keyed on argument name
249
+ for output_arg in self.output_args:
250
+ dims = output_dims.get(output_arg.name)
251
+ if dims is None:
252
+ raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
253
+ out_types.append(get_jax_output_type(output_arg, dims))
254
+ else:
255
+ if output_dims is None:
256
+ # use launch dimensions
257
+ output_dims = launch_dims
258
+ elif isinstance(output_dims, int):
259
+ output_dims = (output_dims,)
260
+ # assume same dimensions for all outputs
261
+ for output_arg in self.output_args:
262
+ out_types.append(get_jax_output_type(output_arg, output_dims))
263
+
264
+ call = jax.ffi.ffi_call(
265
+ self.name,
266
+ out_types,
267
+ vmap_method=vmap_method,
268
+ input_output_aliases=self.input_output_aliases,
269
+ )
270
+
271
+ # preload on the specified devices
272
+ if self.module_preload_mode == ModulePreloadMode.CURRENT_DEVICE:
273
+ device = wp.device_from_jax(get_jax_device())
274
+ self.kernel.module.load(device)
275
+ elif self.module_preload_mode == ModulePreloadMode.ALL_DEVICES:
276
+ for d in jax.local_devices():
277
+ try:
278
+ dev = wp.device_from_jax(d)
279
+ except Exception:
280
+ # ignore unsupported devices like TPUs
281
+ pass
282
+ # we only support CUDA devices for now
283
+ if dev.is_cuda:
284
+ self.kernel.module.load(dev)
285
+
286
+ # save launch data to be retrieved by callback
287
+ launch_id = self.launch_id
288
+ self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
289
+ self.launch_id += 1
290
+
291
+ return call(*args, launch_id=launch_id)
292
+
293
+ def ffi_callback(self, call_frame):
294
+ try:
295
+ # On the first call, XLA runtime will query the API version and traits
296
+ # metadata using the |extension| field. Let us respond to that query
297
+ # if the metadata extension is present.
298
+ extension = call_frame.contents.extension_start
299
+ if extension:
300
+ # Try to set the version metadata.
301
+ if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
302
+ metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
303
+ metadata_ext.contents.metadata.contents.api_version.major_version = 0
304
+ metadata_ext.contents.metadata.contents.api_version.minor_version = 1
305
+ # Turn on CUDA graphs for this handler.
306
+ metadata_ext.contents.metadata.contents.traits = (
307
+ XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
308
+ )
309
+ return None
310
+
311
+ # Lock is required to prevent race conditions when callback is invoked
312
+ # from multiple threads, like with pmap.
313
+ with _FFI_CALLBACK_LOCK:
314
+ # retrieve call info
315
+ attrs = decode_attrs(call_frame.contents.attrs)
316
+ launch_id = int(attrs["launch_id"])
317
+ launch_desc = self.launch_descriptors[launch_id]
318
+
319
+ num_inputs = call_frame.contents.args.size
320
+ inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
321
+
322
+ num_outputs = call_frame.contents.rets.size
323
+ outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
324
+
325
+ assert num_inputs == self.num_inputs
326
+ assert num_outputs == self.num_outputs
327
+
328
+ launch_bounds = launch_bounds_t(launch_desc.launch_dims)
329
+
330
+ # first kernel param is the launch bounds
331
+ kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
332
+ kernel_params[0] = ctypes.addressof(launch_bounds)
333
+
334
+ arg_refs = []
335
+
336
+ # input and in-out args
337
+ for i, input_arg in enumerate(self.input_args):
338
+ if input_arg.is_array:
339
+ buffer = inputs[i].contents
340
+ shape = buffer.dims[: input_arg.type.ndim]
341
+ strides = strides_from_shape(shape, input_arg.type.dtype)
342
+ arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
343
+ kernel_params[i + 1] = ctypes.addressof(arg)
344
+ arg_refs.append(arg) # keep a reference
345
+ else:
346
+ # scalar argument, get stashed value
347
+ value = launch_desc.static_inputs[input_arg.name]
348
+ arg = input_arg.type._type_(value)
349
+ kernel_params[i + 1] = ctypes.addressof(arg)
350
+ arg_refs.append(arg) # keep a reference
351
+
352
+ # pure output args (skip in-out FFI buffers)
353
+ for i, output_arg in enumerate(self.output_args):
354
+ buffer = outputs[i + self.num_in_out].contents
355
+ shape = buffer.dims[: output_arg.type.ndim]
356
+ strides = strides_from_shape(shape, output_arg.type.dtype)
357
+ arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
358
+ kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
359
+ arg_refs.append(arg) # keep a reference
360
+
361
+ # get device and stream
362
+ device = wp.get_cuda_device(get_device_ordinal_from_callframe(call_frame.contents))
363
+ stream = get_stream_from_callframe(call_frame.contents)
364
+
365
+ # get kernel hooks
366
+ hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
367
+ assert hooks.forward, "Failed to find kernel entry point"
368
+
369
+ # launch the kernel
370
+ wp._src.context.runtime.core.wp_cuda_launch_kernel(
371
+ device.context,
372
+ hooks.forward,
373
+ launch_bounds.size,
374
+ 0,
375
+ 256,
376
+ hooks.forward_smem_bytes,
377
+ kernel_params,
378
+ stream,
379
+ )
380
+
381
+ except Exception as e:
382
+ print(traceback.format_exc())
383
+ return create_ffi_error(
384
+ call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
385
+ )
386
+
387
+
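# Editorial example, not part of the diff: FfiKernel objects are returned by
# jax_kernel(), defined later in this file. The __call__ method above accepts
# per-call overrides of launch_dims, output_dims, and vmap_method; output_dims
# may be a dict keyed by output argument name when an output's shape differs
# from the launch shape. The kernel and variable names below are illustrative.

import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_kernel

@wp.kernel
def repeat2(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[2 * tid] = x[tid]
    y[2 * tid + 1] = x[tid]

jax_repeat2 = jax_kernel(repeat2)

x = jnp.arange(8, dtype=jnp.float32)
# launch dims default to x.shape (8 threads); y gets an explicit shape of (16,)
result = jax_repeat2(x, output_dims={"y": (16,)})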
388
+ class FfiCallDesc:
389
+ def __init__(self, static_inputs):
390
+ self.static_inputs = static_inputs
391
+
392
+
393
+ class FfiCallable:
394
+ default_graph_cache_max: int | None = 32
395
+
396
+ def __init__(
397
+ self,
398
+ func,
399
+ num_outputs,
400
+ graph_mode,
401
+ vmap_method,
402
+ output_dims,
403
+ in_out_argnames,
404
+ graph_cache_max,
405
+ module_preload_mode,
406
+ ):
407
+ self.func = func
408
+ self.name = generate_unique_name(func)
409
+ self.num_outputs = num_outputs
410
+ self.vmap_method = vmap_method
411
+ self.graph_mode = graph_mode
412
+ self.output_dims = output_dims
413
+ self.module_preload_mode = module_preload_mode
414
+ self.first_array_arg = None
415
+ self.call_id = 0
416
+ self.call_descriptors = {}
417
+
418
+ # LRU cache of graphs captured by Warp
419
+ self._graph_cache_max = graph_cache_max
420
+ self.captures = collections.OrderedDict()
421
+
422
+ in_out_argnames_list = in_out_argnames or []
423
+ in_out_argnames = set(in_out_argnames_list)
424
+ if len(in_out_argnames_list) != len(in_out_argnames):
425
+ raise AssertionError("in_out_argnames must not contain duplicate names")
426
+
427
+ # get arguments and annotations
428
+ argspec = get_full_arg_spec(func)
429
+
430
+ num_args = len(argspec.args)
431
+ self.num_in_out = len(in_out_argnames)
432
+ self.num_inputs = num_args - num_outputs + self.num_in_out
433
+ if self.num_outputs < 1:
434
+ raise ValueError("At least one output is required")
435
+ if self.num_outputs > num_args:
436
+ raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
437
+ if self.num_outputs < self.num_in_out:
438
+ raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
439
+
440
+ if len(argspec.annotations) < num_args:
441
+ raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
442
+
443
+ # parse type annotations
444
+ self.args = []
445
+ arg_idx = 0
446
+ for arg_name, arg_type in argspec.annotations.items():
447
+ if arg_name == "return":
448
+ if arg_type is not None:
449
+ raise TypeError("Function must not return a value")
450
+ continue
451
+ else:
452
+ arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
453
+ if arg_name in in_out_argnames:
454
+ in_out_argnames.remove(arg_name)
455
+ if arg.is_array:
456
+ if arg_idx < self.num_inputs and self.first_array_arg is None:
457
+ self.first_array_arg = arg_idx
458
+ self.args.append(arg)
459
+
460
+ if arg.in_out and arg_idx >= self.num_inputs:
461
+ raise AssertionError(
462
+ f"Expected an output-only argument for argument {arg_name}."
463
+ " in_out arguments should be placed before output-only arguments."
464
+ )
465
+
466
+ arg_idx += 1
467
+
468
+ if in_out_argnames:
469
+ raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
470
+
471
+ self.input_args = self.args[: self.num_inputs] # includes in-out args
472
+ self.output_args = self.args[self.num_inputs :] # pure output args
473
+
474
+ # Buffer indices for array arguments in callback.
475
+ # In-out buffers are the same pointers in the XLA call frame,
476
+ # so we only include them for inputs and skip them for outputs.
477
+ self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
478
+ self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
479
+
480
+ # Build input output aliases.
481
+ out_id = 0
482
+ input_output_aliases = {}
483
+ for in_id, arg in enumerate(self.input_args):
484
+ if not arg.in_out:
485
+ continue
486
+ input_output_aliases[in_id] = out_id
487
+ out_id += 1
488
+ self.input_output_aliases = input_output_aliases
489
+
490
+ # register the callback
491
+ FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
492
+ self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
493
+ ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
494
+ ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
495
+ jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
496
+
497
+ def __call__(self, *args, output_dims=None, vmap_method=None):
498
+ num_inputs = len(args)
499
+ if num_inputs != self.num_inputs:
500
+ input_names = ", ".join(arg.name for arg in self.input_args)
501
+ s = "" if self.num_inputs == 1 else "s"
502
+ raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
503
+
504
+ # default argument fallback
505
+ if vmap_method is None:
506
+ vmap_method = self.vmap_method
507
+ if output_dims is None:
508
+ output_dims = self.output_dims
509
+
510
+ # output types
511
+ out_types = []
512
+
513
+ # process inputs
514
+ static_inputs = {}
515
+ for i in range(num_inputs):
516
+ input_arg = self.input_args[i]
517
+ input_value = args[i]
518
+ if input_arg.is_array:
519
+ # check dtype
520
+ if input_value.dtype != input_arg.jax_scalar_type:
521
+ raise TypeError(
522
+ f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
523
+ )
524
+ # check ndim
525
+ if input_value.ndim != input_arg.jax_ndim:
526
+ raise TypeError(
527
+ f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
528
+ )
529
+ # check inner dims
530
+ for d in range(input_arg.dtype_ndim):
531
+ if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
532
+ raise TypeError(
533
+ f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
534
+ )
535
+ else:
536
+ # make sure scalar is not a traced variable, should be static
537
+ if isinstance(input_value, jax.core.Tracer):
538
+ raise ValueError(f"Argument '{input_arg.name}' must be a static value")
539
+ # stash the value to be retrieved by callback
540
+ static_inputs[input_arg.name] = input_arg.type(input_value)
541
+
542
+ # append in-out arg to output types
543
+ if input_arg.in_out:
544
+ out_types.append(get_jax_output_type(input_arg, input_value.shape))
545
+
546
+ # output shapes
547
+ if isinstance(output_dims, dict):
548
+ # assume a dictionary of shapes keyed on argument name
549
+ for output_arg in self.output_args:
550
+ dims = output_dims.get(output_arg.name)
551
+ if dims is None:
552
+ raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
553
+ out_types.append(get_jax_output_type(output_arg, dims))
554
+ else:
555
+ if output_dims is None:
556
+ if self.first_array_arg is None:
557
+ raise ValueError("Unable to determine output dimensions")
558
+ output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
559
+ elif isinstance(output_dims, int):
560
+ output_dims = (output_dims,)
561
+ # assume same dimensions for all outputs
562
+ for output_arg in self.output_args:
563
+ out_types.append(get_jax_output_type(output_arg, output_dims))
564
+
565
+ call = jax.ffi.ffi_call(
566
+ self.name,
567
+ out_types,
568
+ vmap_method=vmap_method,
569
+ input_output_aliases=self.input_output_aliases,
570
+ # has_side_effect=True, # force this function to execute even if outputs aren't used
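# Editorial example, not part of the diff: GraphMode and ModulePreloadMode are
# consumed by jax_callable(), defined later in this file (jax_kernel() accepts
# only module_preload_mode). Names and values below are illustrative; the
# private import path matches this file, the public module likely re-exports it.

import warp as wp
from warp._src.jax_experimental.ffi import GraphMode, ModulePreloadMode, jax_callable

def copy_func(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # any mix of Warp launches or library calls; a plain copy keeps the sketch short
    wp.copy(y, x)

jax_copy = jax_callable(
    copy_func,
    graph_mode=GraphMode.WARP,                          # Warp captures and replays the CUDA graph
    module_preload_mode=ModulePreloadMode.ALL_DEVICES,  # load the module on every local CUDA device
    graph_cache_max=8,                                  # bound the LRU of captured graphs
)
# GraphMode.NONE disables capture (e.g. when the callable synchronizes with the
# host); GraphMode.JAX, the default, marks the handler command-buffer compatible
# so JAX can include it in its own graph captures.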
571
+ )
572
+
573
+ # preload on the specified devices
574
+ # NOTE: if the target function uses kernels from different modules, they will not be loaded here
575
+ module = wp.get_module(self.func.__module__)
576
+ if self.module_preload_mode == ModulePreloadMode.CURRENT_DEVICE:
577
+ device = wp.device_from_jax(get_jax_device())
578
+ module.load(device)
579
+ elif self.module_preload_mode == ModulePreloadMode.ALL_DEVICES:
580
+ for d in jax.local_devices():
581
+ try:
582
+ dev = wp.device_from_jax(d)
583
+ except Exception:
584
+ # ignore unsupported devices like TPUs
585
+ pass
586
+ # we only support CUDA devices for now
587
+ if dev.is_cuda:
588
+ module.load(dev)
589
+
590
+ # save call data to be retrieved by callback
591
+ call_id = self.call_id
592
+ self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
593
+ self.call_id += 1
594
+ return call(*args, call_id=call_id)
595
+
596
+ def ffi_callback(self, call_frame):
597
+ try:
598
+ # On the first call, XLA runtime will query the API version and traits
599
+ # metadata using the |extension| field. Let us respond to that query
600
+ # if the metadata extension is present.
601
+ extension = call_frame.contents.extension_start
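# Editorial note, not part of the diff: a worked trace of the mapping above.
# For an annotation wp.array(dtype=wp.vec3, ndim=1):
#     dtype_shape = (3,), dtype_ndim = 1, jax_scalar_type = float32,
#     jax_ndim = type.ndim + dtype_ndim = 2, warp_ndim = 1
# so the matching JAX buffer has shape (n, 3) for a Warp array of n vec3s.
# For a scalar annotation such as wp.float32: jax_ndim = warp_ndim = 0, and the
# value must be passed as a static (non-traced) argument at call time.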
602
+ if extension:
603
+ # Try to set the version metadata.
604
+ if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
605
+ metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
606
+ metadata_ext.contents.metadata.contents.api_version.major_version = 0
607
+ metadata_ext.contents.metadata.contents.api_version.minor_version = 1
608
+ # Turn on CUDA graphs for this handler.
609
+ if self.graph_mode is GraphMode.JAX:
610
+ metadata_ext.contents.metadata.contents.traits = (
611
+ XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
612
+ )
613
+ return None
614
+
615
+ # Lock is required to prevent race conditions when callback is invoked
616
+ # from multiple threads, like with pmap.
617
+ with _FFI_CALLBACK_LOCK:
618
+ # retrieve call info
619
+ # NOTE: this assumes that there's only one attribute - call_id (int64).
620
+ # A more general but slower approach is this:
621
+ # attrs = decode_attrs(call_frame.contents.attrs)
622
+ # call_id = int(attrs["call_id"])
623
+ attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
624
+ call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
625
+ call_desc = self.call_descriptors[call_id]
626
+
627
+ num_inputs = call_frame.contents.args.size
628
+ inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
629
+
630
+ num_outputs = call_frame.contents.rets.size
631
+ outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
632
+
633
+ assert num_inputs == self.num_inputs
634
+ assert num_outputs == self.num_outputs
635
+
636
+ cuda_stream = get_stream_from_callframe(call_frame.contents)
637
+
638
+ if self.graph_mode == GraphMode.WARP:
639
+ # check if we already captured an identical call
640
+ ip = [inputs[i].contents.data for i in self.array_input_indices]
641
+ op = [outputs[i].contents.data for i in self.array_output_indices]
642
+ capture_key = hash((call_id, *ip, *op))
643
+ capture = self.captures.get(capture_key)
644
+
645
+ # launch existing graph
646
+ if capture is not None:
647
+ # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
648
+ # This code should match wp.capture_launch().
649
+ graph = capture.graph
650
+ if graph.graph_exec is None:
651
+ g = ctypes.c_void_p()
652
+ if not wp._src.context.runtime.core.wp_cuda_graph_create_exec(
653
+ graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
654
+ ):
655
+ raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
656
+ graph.graph_exec = g
657
+
658
+ if not wp._src.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
659
+ raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
660
+
661
+ # update the graph cache to keep recently used graphs alive
662
+ self.captures.move_to_end(capture_key)
663
+
664
+ # early out
665
+ return
666
+
667
+ device_ordinal = get_device_ordinal_from_callframe(call_frame.contents)
668
+ device = wp.get_cuda_device(device_ordinal)
669
+ stream = wp.Stream(device, cuda_stream=cuda_stream)
670
+
671
+ # reconstruct the argument list
672
+ arg_list = []
673
+
674
+ # input and in-out args
675
+ for i, arg in enumerate(self.input_args):
676
+ if arg.is_array:
677
+ buffer = inputs[i].contents
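# Editorial example, not part of the diff: the loop above aliases in-out inputs
# to outputs in declaration order, which is what tells XLA it may reuse (donate)
# those input buffers for the corresponding outputs. With inputs
# (state: in_out, params, counts: in_out), argument names illustrative, the
# mapping is {0: 0, 2: 1}:
in_out_flags = [True, False, True]
aliases, out_id = {}, 0
for in_id, flag in enumerate(in_out_flags):
    if flag:
        aliases[in_id] = out_id
        out_id += 1
assert aliases == {0: 0, 2: 1}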
678
+ shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
679
+ arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
680
+ arg_list.append(arr)
681
+ else:
682
+ # scalar argument, get stashed value
683
+ value = call_desc.static_inputs[arg.name]
684
+ arg_list.append(value)
685
+
686
+ # pure output args (skip in-out FFI buffers)
687
+ for i, arg in enumerate(self.output_args):
688
+ buffer = outputs[i + self.num_in_out].contents
689
+ shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
690
+ arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
691
+ arg_list.append(arr)
692
+
693
+ # call the Python function with reconstructed arguments
694
+ with wp.ScopedStream(stream, sync_enter=True):
695
+ if stream.is_capturing:
696
+ # capturing with JAX
697
+ with wp.ScopedCapture(external=True) as capture:
698
+ self.func(*arg_list)
699
+ # keep a reference to the capture object to prevent required modules getting unloaded
700
+ call_desc.capture = capture
701
+ elif self.graph_mode == GraphMode.WARP:
702
+ # capturing with WARP
703
+ with wp.ScopedCapture() as capture:
704
+ self.func(*arg_list)
705
+ wp.capture_launch(capture.graph)
706
+ # keep a reference to the capture object and reuse it with same buffers
707
+ self.captures[capture_key] = capture
708
+ # respect the cache size limit if specified
709
+ if self._graph_cache_max is not None and len(self.captures) > self._graph_cache_max:
710
+ self.captures.popitem(last=False)
711
+ else:
712
+ # not capturing
713
+ self.func(*arg_list)
714
+
715
+ except Exception as e:
716
+ print(traceback.format_exc())
717
+ return create_ffi_error(
718
+ call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
719
+ )
720
+
721
+ return None
722
+
723
+ @property
724
+ def graph_cache_max(self) -> int | None:
725
+ return self._graph_cache_max
726
+
727
+ @graph_cache_max.setter
728
+ def graph_cache_max(self, value: int | None):
729
+ if value != self._graph_cache_max:
730
+ if value is not None and (self._graph_cache_max is None or value < self._graph_cache_max):
731
+ # trim the cache if needed
732
+ while len(self.captures) > value:
733
+ self.captures.popitem(last=False)
734
+ self._graph_cache_max = value
735
+
736
+ @property
737
+ def graph_cache_size(self) -> int:
738
+ return len(self.captures)
739
+
740
+
741
+ def jax_kernel(
742
+ kernel,
743
+ num_outputs=1,
744
+ vmap_method="broadcast_all",
745
+ launch_dims=None,
746
+ output_dims=None,
747
+ in_out_argnames=None,
748
+ module_preload_mode=ModulePreloadMode.CURRENT_DEVICE,
749
+ enable_backward: bool = False,
750
+ ):
751
+ """Create a JAX callback from a Warp kernel.
752
+
753
+ NOTE: This is an experimental feature under development.
754
+
755
+ Args:
756
+ kernel: The Warp kernel to launch.
757
+ num_outputs: Specify the number of output arguments if greater than 1.
758
+ This must include the number of ``in_out_arguments``.
759
+ vmap_method: String specifying how the callback transforms under ``vmap()``.
760
+ This argument can also be specified for individual calls.
761
+ launch_dims: Specify the default kernel launch dimensions. If None, launch
762
+ dimensions are inferred from the shape of the first array argument.
763
+ This argument can also be specified for individual calls.
764
+ output_dims: Specify the default dimensions of output arrays. If None, output
765
+ dimensions are inferred from the launch dimensions.
766
+ This argument can also be specified for individual calls.
767
+ in_out_argnames: Names of arguments that are both inputs and outputs (aliased buffers).
768
+ These must be array arguments that appear before any pure output arguments in the
769
+ kernel signature. The number of in-out arguments is included in ``num_outputs``.
770
+ Not supported when ``enable_backward=True``.
771
+ module_preload_mode: Specify the devices where the module should be preloaded.
772
+ enable_backward: Enable automatic differentiation for this kernel.
773
+
774
+ Limitations:
775
+ - All kernel arguments must be contiguous arrays or scalars.
776
+ - Scalars must be static arguments in JAX.
777
+ - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
778
+ - There must be at least one output or input-output argument.
779
+ - Only the CUDA backend is supported.
780
+ """
781
+
782
+ check_jax_version()
783
+
784
+ if not enable_backward:
785
+ key = (
786
+ kernel.func,
787
+ kernel.sig,
788
+ num_outputs,
789
+ vmap_method,
790
+ tuple(launch_dims) if launch_dims else launch_dims,
791
+ tuple(sorted(output_dims.items())) if output_dims else output_dims,
792
+ module_preload_mode,
793
+ )
794
+
795
+ with _FFI_REGISTRY_LOCK:
796
+ if key not in _FFI_KERNEL_REGISTRY:
797
+ new_kernel = FfiKernel(
798
+ kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames, module_preload_mode
799
+ )
800
+ _FFI_KERNEL_REGISTRY[key] = new_kernel
801
+
802
+ return _FFI_KERNEL_REGISTRY[key]
803
+
804
+ # make sure the arguments are compatible with autodiff
805
+ if in_out_argnames:
806
+ raise NotImplementedError(
807
+ "jax_kernel(): Input-output arguments (in_out_argnames) are not supported when enable_backward=True."
808
+ )
809
+
810
+ # TODO: we should support passing these to the forward and backward callables
811
+ if launch_dims is not None or output_dims is not None:
812
+ raise NotImplementedError(
813
+ "jax_kernel(): Custom dimensions (launch_dims, output_dims) are not supported when enable_backward=True."
814
+ )
815
+
816
+ # Differentiable path: build a custom VJP wrapper inline.
817
+ # Infer the original kernel signature (names and annotations)
818
+ signature = inspect.signature(kernel.func)
819
+
820
+ parameters = [p for p in signature.parameters.values() if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
821
+ parameter_count = len(parameters)
822
+ num_inputs = parameter_count - num_outputs
823
+
824
+ # determine static argument indices
825
+ static_args = []
826
+ for i, p in enumerate(parameters[:num_inputs]):
827
+ param_type = p.annotation
828
+ if not isinstance(param_type, wp.array):
829
+ if param_type in wp._src.types.value_types:
830
+ static_args.append(i)
831
+ else:
832
+ raise TypeError(f"Invalid type for argument '{p.name}', expected array or scalar, got {type}")
833
+
834
+ def _resolve_launch_dims(call_args):
835
+ # determine launch dimensions from the shape of the first input array
836
+ for i, p in enumerate(parameters[:num_inputs]):
837
+ param_type = p.annotation
838
+ if isinstance(param_type, wp.array):
839
+ arg = call_args[i]
840
+ arg_shape = tuple(arg.shape)
841
+ if hasattr(param_type.dtype, "_wp_scalar_type_"):
842
+ # vector/matrix array, trim trailing dimensions of JAX input array
843
+ return arg_shape[: param_type.ndim]
844
+ else:
845
+ # scalar array
846
+ return arg_shape
847
+ raise RuntimeError("Unable to determine launch dimensions, at least one input array is required")
848
+
849
+ # Forward kernel wrapper: simply launches the kernel
850
+ def fwd_kernel_wrapper(*args):
851
+ wp.launch(kernel, dim=_resolve_launch_dims(args), inputs=args[:num_inputs], outputs=args[num_inputs:])
852
+
853
+ # update forward signature and annotations so jax_callable() sees a fully annotated function
854
+ fwd_kernel_wrapper.__signature__ = signature
855
+ fwd_kernel_wrapper.__annotations__ = {p.name: p.annotation for p in parameters}
856
+ fwd_kernel_wrapper.__annotations__["return"] = None
857
+
858
+ jax_fwd_kernel = jax_callable(fwd_kernel_wrapper, num_outputs=num_outputs, vmap_method=vmap_method)
859
+
860
+ # backward arguments only include static args once
861
+ bwd_arg_count = 2 * parameter_count - len(static_args)
862
+
863
+ # Backward wrapper: launches adjoint with provided output gradients
864
+ def bwd_kernel_wrapper(*args):
865
+ if len(args) != bwd_arg_count:
866
+ raise RuntimeError(f"Invalid backward argument count, expected {bwd_arg_count} but got {len(args)}")
867
+
868
+ inputs = list(args[:num_inputs])
869
+ outputs = list(args[num_inputs:parameter_count])
870
+ grad_out = list(args[parameter_count : parameter_count + num_outputs])
871
+ grad_in = list(args[parameter_count + num_outputs :])
872
+
873
+ for i in static_args:
874
+ grad_in.insert(i, inputs[i])
875
+
876
+ for gi in grad_in:
877
+ if isinstance(gi, wp.array):
878
+ try:
879
+ gi.zero_()
880
+ except Exception as e:
881
+ wp.utils.warn(f"Failed to zero gradient array: {e}", stacklevel=2)
882
+ raise e
883
+
884
+ # NOTE: We cannot use a passed launch_dims here, the backward rule doesn't receive it (and it could be wrong under pmap/vmap).
885
+ # We need to infer from the inputs.
886
+ wp.launch(
887
+ kernel,
888
+ dim=_resolve_launch_dims(inputs),
889
+ inputs=inputs,
890
+ outputs=outputs,
891
+ adj_inputs=grad_in,
892
+ adj_outputs=grad_out,
893
+ adjoint=True,
894
+ )
895
+
896
+ # Build the backward wrapper signature expected by jax_callable
897
+ bwd_input_params = parameters[:num_inputs]
+ bwd_output_params = parameters[num_inputs:parameter_count]
+ bwd_grad_output_params = [
+ inspect.Parameter(
+ f"adj_{p.name}",
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ default=p.default,
+ annotation=p.annotation,
+ )
+ for p in bwd_output_params
+ ]
+
+ bwd_grad_input_params = [
+ inspect.Parameter(
+ f"adj_{p.name}",
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ default=p.default,
+ annotation=p.annotation,
+ )
+ for p in bwd_input_params
+ ]
+ for i in reversed(static_args):
+ del bwd_grad_input_params[i]
+
+ # update backward signature and annotations so jax_callable() sees a fully annotated function
+ bwd_signature = bwd_input_params + bwd_output_params + bwd_grad_output_params + bwd_grad_input_params
+ bwd_kernel_wrapper.__signature__ = inspect.Signature(bwd_signature)
+ bwd_annotations = {}
+ for p in bwd_input_params:
+ bwd_annotations[p.name] = p.annotation
+ for p in bwd_output_params:
+ bwd_annotations[p.name] = p.annotation
+ for p in bwd_grad_output_params:
+ bwd_annotations[p.name] = p.annotation
+ for p in bwd_grad_input_params:
+ bwd_annotations[p.name] = p.annotation
+ bwd_annotations["return"] = None
+ bwd_kernel_wrapper.__annotations__ = bwd_annotations
+
+ jax_bwd_kernel = jax_callable(
+ bwd_kernel_wrapper,
+ num_outputs=len(bwd_input_params) - len(static_args),
+ vmap_method=vmap_method,
+ )
+
+ differentiable_input_indices = [i for i in range(num_inputs) if i not in static_args]
+ differentiable_input_names = [parameters[i].name for i in differentiable_input_indices]
+
+ def fwd_function(*args):
+ outputs = jax_fwd_kernel(*args)
+ non_static_inputs = list(args)
+ for i in reversed(static_args):
+ del non_static_inputs[i]
+ # Normalize to tuple for consistent handling
+ if num_outputs == 1:
+ outputs_tuple = (outputs,) if not isinstance(outputs, (list, tuple)) else (outputs[0],)
+ else:
+ outputs_tuple = outputs if isinstance(outputs, tuple) else tuple(outputs)
+ return outputs, (tuple(non_static_inputs), outputs_tuple)
+
+ def bwd_function(*bwd_args):
+ nondiff_vals = list(bwd_args[: len(static_args)])
+ residuals = bwd_args[len(static_args)]
+ grad_out_args = bwd_args[len(static_args) + 1 :]
+
+ non_static_inputs, output_vals_tuple = residuals
+
+ input_vals = list(non_static_inputs)
+ for i, v in zip(static_args, nondiff_vals):
+ input_vals.insert(i, v)
+
+ # Normalize grad outputs and handle nested containers (e.g., single tuple for multi-output)
+ if num_outputs == 1:
+ go = grad_out_args[0]
+ grad_out_tuple = tuple(go) if isinstance(go, (list, tuple)) else (go,)
+ else:
+ if len(grad_out_args) == 1 and isinstance(grad_out_args[0], (list, tuple)):
+ grad_out_tuple = tuple(grad_out_args[0])
+ else:
+ grad_out_tuple = tuple(grad_out_args)
+ bwd_call_args = list(input_vals) + list(output_vals_tuple) + list(grad_out_tuple)
+
+ out_dims_map = {}
+ param_ann = {p.name: p.annotation for p in parameters[:num_inputs]}
+ for name, val in zip(differentiable_input_names, non_static_inputs):
+ ann = param_ann.get(name)
+ if ann is None:
+ continue
+ # Check if annotation is a warp array type (annotation is an instance of wp.array)
+ is_array_ann = isinstance(ann, wp.array)
+ if not is_array_ann:
+ continue
+ dtype_ndim = 0
+ # Extract dtype ndim if it's a vector/matrix type
+ if hasattr(ann, "dtype") and hasattr(ann.dtype, "_wp_scalar_type_"):
+ dtype_ndim = len(ann.dtype._shape_)
+ warp_ndim = getattr(ann, "ndim", 0)
+ vshape = tuple(val.shape)
+ if warp_ndim == 0:
+ continue
+ if dtype_ndim > 0:
+ core_rank = max(0, len(vshape) - dtype_ndim)
+ warp_dims = vshape[max(0, core_rank - warp_ndim) : core_rank]
+ else:
+ warp_dims = vshape[-warp_ndim:]
+ out_dims_map[f"adj_{name}"] = tuple(warp_dims)
+
+ non_static_input_grads = jax_bwd_kernel(*bwd_call_args, output_dims=out_dims_map)
+ return tuple(non_static_input_grads)
+
+ jax_func = jax.custom_vjp(jax_fwd_kernel, nondiff_argnums=tuple(static_args))
+ jax_func.defvjp(fwd_function, bwd_function)
+
+ if static_args:
+ static_names = [parameters[i].name for i in static_args]
+
+ def _user_callable(*args):
+ return jax_func(*args)
+
+ _user_callable.__signature__ = signature
+
+ # Cache differentiable wrapper
+ key = (kernel.func, kernel.sig, num_outputs, vmap_method, tuple(sorted(static_names)))
+ with _FFI_REGISTRY_LOCK:
+ cached = _FFI_DIFF_KERNEL_REGISTRY.get(key)
+ if cached is None:
+ cached = jax.jit(_user_callable, static_argnames=tuple(static_names))
+ _FFI_DIFF_KERNEL_REGISTRY[key] = cached
+ return _FFI_DIFF_KERNEL_REGISTRY[key]
+
+ # Cache differentiable wrapper (no static args)
+ key = (kernel.func, kernel.sig, num_outputs, vmap_method, ())
+ with _FFI_REGISTRY_LOCK:
+ cached = _FFI_DIFF_KERNEL_REGISTRY.get(key)
+ if cached is None:
+ _FFI_DIFF_KERNEL_REGISTRY[key] = jax_func
+ cached = jax_func
+ return cached
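The block above is the heart of the differentiable path: the forward FFI callable and the generated backward wrapper are stitched together through JAX's custom-VJP machinery, with static arguments excluded from differentiation via ``nondiff_argnums``. For orientation, here is a minimal self-contained sketch of the same ``jax.custom_vjp``/``defvjp`` pattern, using a hypothetical ``scale`` function in place of the Warp-backed callables:

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the Warp-backed forward callable.
def scale_fwd(x, s):
    return s * x

# Mark `s` as static / non-differentiable, mirroring `static_args` above.
scale = jax.custom_vjp(scale_fwd, nondiff_argnums=(1,))

def scale_fwd_rule(x, s):
    # Return the primal output plus the residuals saved for the backward pass.
    return scale_fwd(x, s), (x,)

def scale_bwd_rule(s, residuals, g):
    # Non-diff args come first, then residuals, then the output cotangent.
    del residuals  # not needed for this toy function
    return (s * g,)  # one gradient per differentiable input

scale.defvjp(scale_fwd_rule, scale_bwd_rule)

grads = jax.grad(lambda x: scale(x, 3.0).sum())(jnp.ones(4))  # -> [3. 3. 3. 3.]

As in the wrapper above, the backward rule receives the non-differentiable values first, then the saved residuals, then the output cotangents, and returns exactly one gradient per differentiable input.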
+
+
+ def jax_callable(
+ func: Callable,
+ num_outputs: int = 1,
+ graph_compatible: Optional[bool] = None, # deprecated
+ graph_mode: GraphMode = GraphMode.JAX,
+ vmap_method: Optional[str] = "broadcast_all",
+ output_dims=None,
+ in_out_argnames=None,
+ graph_cache_max: int | None = None,
+ module_preload_mode: ModulePreloadMode = ModulePreloadMode.CURRENT_DEVICE,
+ ):
+ """Create a JAX callback from an annotated Python function.
+
+ The Python function arguments must have type annotations like Warp kernels.
+
+ NOTE: This is an experimental feature under development.
+
+ Args:
+ func: The Python function to call.
+ num_outputs: The number of output arguments; specify it if greater than 1.
+ This count must include the in-out arguments listed in ``in_out_argnames``.
+ graph_compatible: Whether the function can be called during CUDA graph capture.
+ This argument is deprecated, use ``graph_mode`` instead.
+ graph_mode: CUDA graph capture mode.
+ ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing JAX capture.
+ ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subgraph,
+ such as when the callable uses conditional graph nodes.
+ ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
+ such as host synchronization.
+ vmap_method: String specifying how the callback transforms under ``vmap()``.
+ This argument can also be specified for individual calls.
+ output_dims: Specify the default dimensions of output arrays.
+ If ``None``, output dimensions are inferred from the launch dimensions.
+ This argument can also be specified for individual calls.
+ in_out_argnames: Names of arguments that are both inputs and outputs (aliased buffers).
+ These must be array arguments that appear before any pure output arguments in the
+ function signature. The number of in-out arguments is included in ``num_outputs``.
+ graph_cache_max: Maximum number of cached graphs captured using ``GraphMode.WARP``.
+ If ``None``, use ``warp.jax_experimental.get_jax_callable_default_graph_cache_max()``.
+ module_preload_mode: Specify the devices where the module should be preloaded.
+
+ Limitations:
+ - All kernel arguments must be contiguous arrays or scalars.
+ - Scalars must be static arguments in JAX.
+ - Input and input-output arguments must precede the output arguments in the ``func`` definition.
+ - There must be at least one output or input-output argument.
+ - Only the CUDA backend is supported.
+ """
+
+ check_jax_version()
+
+ if graph_compatible is not None:
+ wp._src.utils.warn(
+ "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+ if graph_compatible is False:
+ graph_mode = GraphMode.NONE
+
+ if graph_cache_max is None:
+ graph_cache_max = FfiCallable.default_graph_cache_max
+
+ # Note: we don't include graph_cache_max in the key, it is applied below.
+ key = (
+ func,
+ num_outputs,
+ graph_mode,
+ vmap_method,
+ tuple(sorted(output_dims.items())) if output_dims else output_dims,
+ module_preload_mode,
+ )
+
+ with _FFI_REGISTRY_LOCK:
+ callable = _FFI_CALLABLE_REGISTRY.get(key)
+ if callable is None:
+ callable = FfiCallable(
+ func,
+ num_outputs,
+ graph_mode,
+ vmap_method,
+ output_dims,
+ in_out_argnames,
+ graph_cache_max,
+ module_preload_mode,
+ )
+ _FFI_CALLABLE_REGISTRY[key] = callable
+ else:
+ # make sure we're using the latest graph cache max
+ callable.graph_cache_max = graph_cache_max
+
+ return callable
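As a usage illustration of the API documented above, the following sketch wraps a trivial Warp launcher. The kernel, launcher, and import path are illustrative assumptions, not code from this package:

import warp as wp
import jax.numpy as jnp
from warp.jax_experimental.ffi import jax_callable  # assumed public import path

@wp.kernel
def double_kernel(a: wp.array(dtype=float), out: wp.array(dtype=float)):
    tid = wp.tid()
    out[tid] = 2.0 * a[tid]

# Annotated launcher: inputs first, then outputs, per the limitations listed above.
def double_launcher(a: wp.array(dtype=float), out: wp.array(dtype=float)):
    wp.launch(double_kernel, dim=a.shape, inputs=[a], outputs=[out])

jax_double = jax_callable(double_launcher, num_outputs=1)

x = jnp.arange(4, dtype=jnp.float32)
y = jax_double(x)  # output dims are inferred from the launch dimensions

Because instances are cached in ``_FFI_CALLABLE_REGISTRY``, repeated calls to ``jax_callable`` with the same arguments return the same ``FfiCallable`` object, only refreshing its ``graph_cache_max``.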
+
+
+ def get_jax_callable_default_graph_cache_max():
+ """
+ Get the maximum size of the graph cache for graphs captured using ``GraphMode.WARP``, unlimited if ``None``.
+ """
+ return FfiCallable.default_graph_cache_max
+
+
+ def set_jax_callable_default_graph_cache_max(cache_max: int | None):
+ """
+ Set the maximum size of the graph cache for graphs captured using ``GraphMode.WARP``, unlimited if ``None``.
+ """
+ FfiCallable.default_graph_cache_max = cache_max
+
+
+ def clear_jax_callable_graph_cache(callable: FfiCallable | None = None):
+ """Clear the graph cache of the given callable or all callables if ``None``."""
+
+ if callable is not None:
+ callable.captures.clear()
+ else:
+ # apply to all callables
+ with _FFI_REGISTRY_LOCK:
+ for callable in _FFI_CALLABLE_REGISTRY.values():
+ callable.captures.clear()
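A short sketch of how these cache controls fit together; it assumes the functions are re-exported from ``warp.jax_experimental``, as the ``graph_cache_max`` docstring above suggests:

from warp.jax_experimental import (
    get_jax_callable_default_graph_cache_max,
    set_jax_callable_default_graph_cache_max,
    clear_jax_callable_graph_cache,
)

# Cap the number of graphs kept per callable under GraphMode.WARP (illustrative value),
# then drop any graphs captured so far by all callables.
set_jax_callable_default_graph_cache_max(8)
assert get_jax_callable_default_graph_cache_max() == 8
clear_jax_callable_graph_cache()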
+
+
+ ###############################################################################
+ #
+ # Generic FFI callbacks for Python functions of the form
+ # func(inputs, outputs, attrs, ctx)
+ #
+ ###############################################################################
+
+
+ def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
+ """Create a JAX callback from a Python function.
+
+ The Python function must have the form ``func(inputs, outputs, attrs, ctx)``.
+
+ NOTE: This is an experimental feature under development.
+
+ Args:
+ name: A unique FFI callback name.
+ func: The Python function to call.
+ graph_compatible: Whether the function can be called during CUDA graph capture.
+ """
+
+ check_jax_version()
+
+ # TODO check that the name is not already registered
+
+ def ffi_callback(call_frame):
+ try:
+ extension = call_frame.contents.extension_start
+ # On the first call, XLA runtime will query the API version and traits
+ # metadata using the |extension| field. Let us respond to that query
+ # if the metadata extension is present.
+ if extension:
+ # Try to set the version metadata.
+ if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
+ metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
+ metadata_ext.contents.metadata.contents.api_version.major_version = 0
+ metadata_ext.contents.metadata.contents.api_version.minor_version = 1
+ if graph_compatible:
+ # Turn on CUDA graphs for this handler.
+ metadata_ext.contents.metadata.contents.traits = (
+ XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
+ )
+ return None
+
+ # Lock is required to prevent race conditions when callback is invoked
+ # from multiple threads, like with pmap.
+ with _FFI_CALLBACK_LOCK:
+ attrs = decode_attrs(call_frame.contents.attrs)
+
+ input_count = call_frame.contents.args.size
+ inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+ inputs = [FfiBuffer(inputs[i].contents) for i in range(input_count)]
+
+ output_count = call_frame.contents.rets.size
+ outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
+ outputs = [FfiBuffer(outputs[i].contents) for i in range(output_count)]
+
+ ctx = ExecutionContext(call_frame.contents)
+
+ func(inputs, outputs, attrs, ctx)
+
+ except Exception as e:
+ print(traceback.format_exc())
+ return create_ffi_error(
+ call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
+ )
+
+ return None
+
+ FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
+ callback_func = FFI_CCALLFUNC(ffi_callback)
+ with _FFI_REGISTRY_LOCK:
+ _FFI_CALLBACK_REGISTRY[name] = callback_func
+ ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
+ ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
+ jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
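For orientation, a minimal sketch of registering and invoking such a callback follows. The callback body is left schematic, and the target name, attribute, and shapes are illustrative:

import jax
import jax.numpy as jnp
from warp.jax_experimental.ffi import register_ffi_callback  # assumed public import path

def my_callback(inputs, outputs, attrs, ctx):
    # `inputs` and `outputs` are buffer views of the device arrays; a real callback would
    # wrap them (e.g. as Warp arrays) and launch work on the stream provided by `ctx`.
    # `attrs` carries the keyword attributes passed at the call site (here: "scale").
    ...

register_ffi_callback("my_callback", my_callback)

# Invoke through JAX's standard FFI mechanism; the target is registered for CUDA only.
out_type = jax.ShapeDtypeStruct((16,), jnp.float32)
call = jax.ffi.ffi_call("my_callback", out_type)
y = call(jnp.ones(16, dtype=jnp.float32), scale=2.0)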
+
+
+ ###############################################################################
+ #
+ # Utilities
+ #
+ ###############################################################################
+
+ # ensure unique FFI callback names
+ ffi_name_counts = {}
+
+
+ def generate_unique_name(func) -> str:
+ key = make_full_qualified_name(func)
+ unique_id = ffi_name_counts.get(key, 0)
+ ffi_name_counts[key] = unique_id + 1
+ return f"{key}_{unique_id}"
+
+
+ def get_warp_shape(arg, dims):
+ if arg.dtype_ndim > 0:
+ # vector/matrix array
+ return dims[: arg.warp_ndim]
+ else:
+ # scalar array
+ return dims
+
+
+ def get_jax_output_type(arg, dims):
+ if isinstance(dims, int):
+ dims = (dims,)
+
+ ndim = len(dims)
+
+ if arg.dtype_ndim > 0:
+ # vector/matrix array
+ if ndim == arg.warp_ndim:
+ return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)
+ elif ndim == arg.jax_ndim:
+ # make sure inner dimensions match
+ inner_dims = dims[-arg.dtype_ndim :]
+ for i in range(arg.dtype_ndim):
+ if inner_dims[i] != arg.dtype_shape[i]:
+ raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
+ return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
+ else:
+ raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
+ else:
+ # scalar array
+ if ndim != arg.warp_ndim:
+ raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
+ return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
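The two helpers above define the mapping between Warp launch dimensions and JAX array shapes: for vector/matrix dtypes, the dtype shape is appended to the Warp dimensions, and callers may pass either the Warp dims or the full JAX dims as long as the trailing dimensions match the dtype shape. A toy mirror of the vector/matrix branch, with hypothetical values:

import jax
import jax.numpy as jnp

def toy_output_type(warp_dims, dtype_shape, scalar_type):
    # Mirrors the vector/matrix branch of get_jax_output_type(): append the dtype
    # shape (e.g. (3,) for wp.vec3) to the Warp launch dimensions.
    return jax.ShapeDtypeStruct((*warp_dims, *dtype_shape), scalar_type)

# A wp.array(dtype=wp.vec3) output with launch dims (8,) is exposed to JAX
# as a float32 array of shape (8, 3).
print(toy_output_type((8,), (3,), jnp.float32))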