warp-lang 1.9.1-py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang has been flagged as potentially problematic.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -13,883 +13,27 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import ctypes
- import threading
- import traceback
- from enum import IntEnum
- from typing import Callable, Optional
+ # isort: skip_file

- import jax
+ from warp._src.jax_experimental.ffi import GraphMode as GraphMode
+ from warp._src.jax_experimental.ffi import jax_kernel as jax_kernel
+ from warp._src.jax_experimental.ffi import jax_callable as jax_callable
+ from warp._src.jax_experimental.ffi import register_ffi_callback as register_ffi_callback

- import warp as wp
- from warp.codegen import get_full_arg_spec, make_full_qualified_name
- from warp.jax import get_jax_device
- from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
+ from warp._src.jax_experimental.ffi import (
+     get_jax_callable_default_graph_cache_max as get_jax_callable_default_graph_cache_max,
+ )
+ from warp._src.jax_experimental.ffi import (
+     set_jax_callable_default_graph_cache_max as set_jax_callable_default_graph_cache_max,
+ )
+ from warp._src.jax_experimental.ffi import clear_jax_callable_graph_cache as clear_jax_callable_graph_cache

- from .xla_ffi import *
+ # TODO: Remove after cleaning up the public API.

+ from warp._src.jax_experimental import ffi as _ffi

- def check_jax_version():
-     # check if JAX version supports this
-     if jax.__version_info__ < (0, 5, 0):
-         msg = (
-             "This version of jax_kernel() requires JAX version 0.5.0 or higher, "
-             f"but installed JAX version is {jax.__version_info__}."
-         )
-         if jax.__version_info__ >= (0, 4, 25):
-             msg += " Please use warp.jax_experimental.custom_call.jax_kernel instead."
-         raise RuntimeError(msg)

+ def __getattr__(name):
+     from warp._src.utils import get_deprecated_api

- class GraphMode(IntEnum):
-     NONE = 0  # don't capture a graph
-     JAX = 1  # let JAX capture a graph
-     WARP = 2  # let Warp capture a graph
-
-
- class FfiArg:
-     def __init__(self, name, type, in_out=False):
-         self.name = name
-         self.type = type
-         self.in_out = in_out
-         self.is_array = isinstance(type, wp.array)
-
-         if self.is_array:
-             if hasattr(type.dtype, "_wp_scalar_type_"):
-                 self.dtype_shape = type.dtype._shape_
-                 self.dtype_ndim = len(self.dtype_shape)
-                 self.jax_scalar_type = wp.dtype_to_jax(type.dtype._wp_scalar_type_)
-                 self.jax_ndim = type.ndim + self.dtype_ndim
-             elif type.dtype in wp.types.value_types:
-                 self.dtype_ndim = 0
-                 self.dtype_shape = ()
-                 self.jax_scalar_type = wp.dtype_to_jax(type.dtype)
-                 self.jax_ndim = type.ndim
-             else:
-                 raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")
-             self.warp_ndim = type.ndim
-         elif type in wp.types.value_types:
-             self.dtype_ndim = 0
-             self.dtype_shape = ()
-             self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
-             self.jax_ndim = 0
-             self.warp_ndim = 0
-         else:
-             raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
-
-
- class FfiLaunchDesc:
-     def __init__(self, static_inputs, launch_dims):
-         self.static_inputs = static_inputs
-         self.launch_dims = launch_dims
-
-
- class FfiKernel:
-     def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
-         self.kernel = kernel
-         self.name = generate_unique_name(kernel.func)
-         self.num_outputs = num_outputs
-         self.vmap_method = vmap_method
-         self.launch_dims = launch_dims
-         self.output_dims = output_dims
-         self.first_array_arg = None
-         self.launch_id = 0
-         self.launch_descriptors = {}
-
-         in_out_argnames_list = in_out_argnames or []
-         in_out_argnames = set(in_out_argnames_list)
-         if len(in_out_argnames_list) != len(in_out_argnames):
-             raise AssertionError("in_out_argnames must not contain duplicate names")
-
-         self.num_kernel_args = len(kernel.adj.args)
-         self.num_in_out = len(in_out_argnames)
-         self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
-         if self.num_outputs < 1:
-             raise ValueError("At least one output is required")
-         if self.num_outputs > self.num_kernel_args:
-             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
-         if self.num_outputs < self.num_in_out:
-             raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
-
-         # process input args
-         self.input_args = []
-         for i in range(self.num_inputs):
-             arg_name = kernel.adj.args[i].label
-             arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
-             if arg_name in in_out_argnames:
-                 in_out_argnames.remove(arg_name)
-             if arg.is_array:
-                 # keep track of the first input array argument
-                 if self.first_array_arg is None:
-                     self.first_array_arg = i
-             self.input_args.append(arg)
-
-         # process output args
-         self.output_args = []
-         for i in range(self.num_inputs, self.num_kernel_args):
-             arg_name = kernel.adj.args[i].label
-             if arg_name in in_out_argnames:
-                 raise AssertionError(
-                     f"Expected an output-only argument for argument {arg_name}."
-                     " in_out arguments should be placed before output-only arguments."
-                 )
-             arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
-             if not arg.is_array:
-                 raise TypeError("All output arguments must be arrays")
-             self.output_args.append(arg)
-
-         if in_out_argnames:
-             raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
-
-         # Build input output aliases.
-         out_id = 0
-         input_output_aliases = {}
-         for in_id, arg in enumerate(self.input_args):
-             if not arg.in_out:
-                 continue
-             input_output_aliases[in_id] = out_id
-             out_id += 1
-         self.input_output_aliases = input_output_aliases
-
-         # register the callback
-         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
-         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
-         ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
-         ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
-         jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
-
-     def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
-         num_inputs = len(args)
-         if num_inputs != self.num_inputs:
-             raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
-
-         # default argument fallback
-         if launch_dims is None:
-             launch_dims = self.launch_dims
-         if output_dims is None:
-             output_dims = self.output_dims
-         if vmap_method is None:
-             vmap_method = self.vmap_method
-
-         # output types
-         out_types = []
-
-         # process inputs
-         static_inputs = {}
-         for i in range(num_inputs):
-             input_arg = self.input_args[i]
-             input_value = args[i]
-             if input_arg.is_array:
-                 # check dtype
-                 if input_value.dtype != input_arg.jax_scalar_type:
-                     raise TypeError(
-                         f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
-                     )
-                 # check ndim
-                 if input_value.ndim != input_arg.jax_ndim:
-                     raise TypeError(
-                         f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
-                     )
-                 # check inner dims
-                 for d in range(input_arg.dtype_ndim):
-                     if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
-                         raise TypeError(
-                             f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
-                         )
-             else:
-                 # make sure scalar is not a traced variable, should be static
-                 if isinstance(input_value, jax.core.Tracer):
-                     raise ValueError(f"Argument '{input_arg.name}' must be a static value")
-                 # stash the value to be retrieved by callback
-                 static_inputs[input_arg.name] = input_arg.type(input_value)
-
-             # append in-out arg to output types
-             if input_arg.in_out:
-                 out_types.append(get_jax_output_type(input_arg, input_value.shape))
-
-         # launch dimensions
-         if launch_dims is None:
-             # use the shape of the first input array
-             if self.first_array_arg is not None:
-                 launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
-             else:
-                 raise RuntimeError("Failed to determine launch dimensions")
-         elif isinstance(launch_dims, int):
-             launch_dims = (launch_dims,)
-         else:
-             launch_dims = tuple(launch_dims)
-
-         # output shapes
-         if isinstance(output_dims, dict):
-             # assume a dictionary of shapes keyed on argument name
-             for output_arg in self.output_args:
-                 dims = output_dims.get(output_arg.name)
-                 if dims is None:
-                     raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
-                 out_types.append(get_jax_output_type(output_arg, dims))
-         else:
-             if output_dims is None:
-                 # use launch dimensions
-                 output_dims = launch_dims
-             elif isinstance(output_dims, int):
-                 output_dims = (output_dims,)
-             # assume same dimensions for all outputs
-             for output_arg in self.output_args:
-                 out_types.append(get_jax_output_type(output_arg, output_dims))
-
-         call = jax.ffi.ffi_call(
-             self.name,
-             out_types,
-             vmap_method=vmap_method,
-             input_output_aliases=self.input_output_aliases,
-         )
-
-         # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
-         device = wp.device_from_jax(get_jax_device())
-         self.kernel.module.load(device)
-
-         # save launch data to be retrieved by callback
-         launch_id = self.launch_id
-         self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
-         self.launch_id += 1
-
-         return call(*args, launch_id=launch_id)
-
-     def ffi_callback(self, call_frame):
-         try:
-             # On the first call, XLA runtime will query the API version and traits
-             # metadata using the |extension| field. Let us respond to that query
-             # if the metadata extension is present.
-             extension = call_frame.contents.extension_start
-             if extension:
-                 # Try to set the version metadata.
-                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
-                     metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
-                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
-                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
-                     # Turn on CUDA graphs for this handler.
-                     metadata_ext.contents.metadata.contents.traits = (
-                         XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
-                     )
-                 return None
-
-             # retrieve call info
-             attrs = decode_attrs(call_frame.contents.attrs)
-             launch_id = int(attrs["launch_id"])
-             launch_desc = self.launch_descriptors[launch_id]
-
-             num_inputs = call_frame.contents.args.size
-             inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
-
-             num_outputs = call_frame.contents.rets.size
-             outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
-
-             assert num_inputs == self.num_inputs
-             assert num_outputs == self.num_outputs
-
-             launch_bounds = launch_bounds_t(launch_desc.launch_dims)
-
-             # first kernel param is the launch bounds
-             kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
-             kernel_params[0] = ctypes.addressof(launch_bounds)
-
-             arg_refs = []
-
-             # input and in-out args
-             for i, input_arg in enumerate(self.input_args):
-                 if input_arg.is_array:
-                     buffer = inputs[i].contents
-                     shape = buffer.dims[: input_arg.type.ndim]
-                     strides = strides_from_shape(shape, input_arg.type.dtype)
-                     arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
-                     kernel_params[i + 1] = ctypes.addressof(arg)
-                     arg_refs.append(arg)  # keep a reference
-                 else:
-                     # scalar argument, get stashed value
-                     value = launch_desc.static_inputs[input_arg.name]
-                     arg = input_arg.type._type_(value)
-                     kernel_params[i + 1] = ctypes.addressof(arg)
-                     arg_refs.append(arg)  # keep a reference
-
-             # pure output args (skip in-out FFI buffers)
-             for i, output_arg in enumerate(self.output_args):
-                 buffer = outputs[i + self.num_in_out].contents
-                 shape = buffer.dims[: output_arg.type.ndim]
-                 strides = strides_from_shape(shape, output_arg.type.dtype)
-                 arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
-                 kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
-                 arg_refs.append(arg)  # keep a reference
-
-             # get device and stream
-             device = wp.device_from_jax(get_jax_device())
-             stream = get_stream_from_callframe(call_frame.contents)
-
-             # get kernel hooks
-             hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
-             assert hooks.forward, "Failed to find kernel entry point"
-
-             # launch the kernel
-             wp.context.runtime.core.wp_cuda_launch_kernel(
-                 device.context,
-                 hooks.forward,
-                 launch_bounds.size,
-                 0,
-                 256,
-                 hooks.forward_smem_bytes,
-                 kernel_params,
-                 stream,
-             )
-
-         except Exception as e:
-             print(traceback.format_exc())
-             return create_ffi_error(
-                 call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
-             )
-
-
- class FfiCallDesc:
-     def __init__(self, static_inputs):
-         self.static_inputs = static_inputs
-         self.captures = {}
-
-
- class FfiCallable:
-     def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
-         self.func = func
-         self.name = generate_unique_name(func)
-         self.num_outputs = num_outputs
-         self.vmap_method = vmap_method
-         self.graph_mode = graph_mode
-         self.output_dims = output_dims
-         self.first_array_arg = None
-         self.call_id = 0
-         self.call_descriptors = {}
-
-         in_out_argnames_list = in_out_argnames or []
-         in_out_argnames = set(in_out_argnames_list)
-         if len(in_out_argnames_list) != len(in_out_argnames):
-             raise AssertionError("in_out_argnames must not contain duplicate names")
-
-         # get arguments and annotations
-         argspec = get_full_arg_spec(func)
-
-         num_args = len(argspec.args)
-         self.num_in_out = len(in_out_argnames)
-         self.num_inputs = num_args - num_outputs + self.num_in_out
-         if self.num_outputs < 1:
-             raise ValueError("At least one output is required")
-         if self.num_outputs > num_args:
-             raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
-         if self.num_outputs < self.num_in_out:
-             raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
-
-         if len(argspec.annotations) < num_args:
-             raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
-
-         # parse type annotations
-         self.args = []
-         arg_idx = 0
-         for arg_name, arg_type in argspec.annotations.items():
-             if arg_name == "return":
-                 if arg_type is not None:
-                     raise TypeError("Function must not return a value")
-                 continue
-             else:
-                 arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
-                 if arg_name in in_out_argnames:
-                     in_out_argnames.remove(arg_name)
-                 if arg.is_array:
-                     if arg_idx < self.num_inputs and self.first_array_arg is None:
-                         self.first_array_arg = arg_idx
-                 self.args.append(arg)
-
-                 if arg.in_out and arg_idx >= self.num_inputs:
-                     raise AssertionError(
-                         f"Expected an output-only argument for argument {arg_name}."
-                         " in_out arguments should be placed before output-only arguments."
-                     )
-
-                 arg_idx += 1
-
-         if in_out_argnames:
-             raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
-
-         self.input_args = self.args[: self.num_inputs]  # includes in-out args
-         self.output_args = self.args[self.num_inputs :]  # pure output args
-
-         # Buffer indices for array arguments in callback.
-         # In-out buffers are the same pointers in the XLA call frame,
-         # so we only include them for inputs and skip them for outputs.
-         self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
-         self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
-
-         # Build input output aliases.
-         out_id = 0
-         input_output_aliases = {}
-         for in_id, arg in enumerate(self.input_args):
-             if not arg.in_out:
-                 continue
-             input_output_aliases[in_id] = out_id
-             out_id += 1
-         self.input_output_aliases = input_output_aliases
-
-         # register the callback
-         FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
-         self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
-         ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
-         ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
-         jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
-
-     def __call__(self, *args, output_dims=None, vmap_method=None):
-         num_inputs = len(args)
-         if num_inputs != self.num_inputs:
-             input_names = ", ".join(arg.name for arg in self.input_args)
-             s = "" if self.num_inputs == 1 else "s"
-             raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
-
-         # default argument fallback
-         if vmap_method is None:
-             vmap_method = self.vmap_method
-         if output_dims is None:
-             output_dims = self.output_dims
-
-         # output types
-         out_types = []
-
-         # process inputs
-         static_inputs = {}
-         for i in range(num_inputs):
-             input_arg = self.input_args[i]
-             input_value = args[i]
-             if input_arg.is_array:
-                 # check dtype
-                 if input_value.dtype != input_arg.jax_scalar_type:
-                     raise TypeError(
-                         f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
-                     )
-                 # check ndim
-                 if input_value.ndim != input_arg.jax_ndim:
-                     raise TypeError(
-                         f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
-                     )
-                 # check inner dims
-                 for d in range(input_arg.dtype_ndim):
-                     if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
-                         raise TypeError(
-                             f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
-                         )
-             else:
-                 # make sure scalar is not a traced variable, should be static
-                 if isinstance(input_value, jax.core.Tracer):
-                     raise ValueError(f"Argument '{input_arg.name}' must be a static value")
-                 # stash the value to be retrieved by callback
-                 static_inputs[input_arg.name] = input_arg.type(input_value)
-
-             # append in-out arg to output types
-             if input_arg.in_out:
-                 out_types.append(get_jax_output_type(input_arg, input_value.shape))
-
-         # output shapes
-         if isinstance(output_dims, dict):
-             # assume a dictionary of shapes keyed on argument name
-             for output_arg in self.output_args:
-                 dims = output_dims.get(output_arg.name)
-                 if dims is None:
-                     raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
-                 out_types.append(get_jax_output_type(output_arg, dims))
-         else:
-             if output_dims is None:
-                 if self.first_array_arg is None:
-                     raise ValueError("Unable to determine output dimensions")
-                 output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
-             elif isinstance(output_dims, int):
-                 output_dims = (output_dims,)
-             # assume same dimensions for all outputs
-             for output_arg in self.output_args:
-                 out_types.append(get_jax_output_type(output_arg, output_dims))
-
-         call = jax.ffi.ffi_call(
-             self.name,
-             out_types,
-             vmap_method=vmap_method,
-             input_output_aliases=self.input_output_aliases,
-             # has_side_effect=True,  # force this function to execute even if outputs aren't used
-         )
-
-         # load the module
-         # NOTE: if the target function uses kernels from different modules, they will not be loaded here
-         device = wp.device_from_jax(get_jax_device())
-         module = wp.get_module(self.func.__module__)
-         module.load(device)
-
-         # save call data to be retrieved by callback
-         call_id = self.call_id
-         self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
-         self.call_id += 1
-         return call(*args, call_id=call_id)
-
-     def ffi_callback(self, call_frame):
-         try:
-             # On the first call, XLA runtime will query the API version and traits
-             # metadata using the |extension| field. Let us respond to that query
-             # if the metadata extension is present.
-             extension = call_frame.contents.extension_start
-             if extension:
-                 # Try to set the version metadata.
-                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
-                     metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
-                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
-                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
-                     # Turn on CUDA graphs for this handler.
-                     if self.graph_mode is GraphMode.JAX:
-                         metadata_ext.contents.metadata.contents.traits = (
-                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
-                         )
-                 return None
-
-             # retrieve call info
-             # NOTE: this assumes that there's only one attribute - call_id (int64).
-             # A more general but slower approach is this:
-             # attrs = decode_attrs(call_frame.contents.attrs)
-             # call_id = int(attrs["call_id"])
-             attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
-             call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
-             call_desc = self.call_descriptors[call_id]
-
-             num_inputs = call_frame.contents.args.size
-             inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
-
-             num_outputs = call_frame.contents.rets.size
-             outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
-
-             assert num_inputs == self.num_inputs
-             assert num_outputs == self.num_outputs
-
-             cuda_stream = get_stream_from_callframe(call_frame.contents)
-
-             if self.graph_mode == GraphMode.WARP:
-                 # check if we already captured an identical call
-                 ip = [inputs[i].contents.data for i in self.array_input_indices]
-                 op = [outputs[i].contents.data for i in self.array_output_indices]
-                 buffer_hash = hash((*ip, *op))
-                 capture = call_desc.captures.get(buffer_hash)
-
-                 # launch existing graph
-                 if capture is not None:
-                     # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
-                     # This code should match wp.capture_launch().
-                     graph = capture.graph
-                     if graph.graph_exec is None:
-                         g = ctypes.c_void_p()
-                         if not wp.context.runtime.core.wp_cuda_graph_create_exec(
-                             graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
-                         ):
-                             raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
-                         graph.graph_exec = g
-
-                     if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
-                         raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
-
-                     # early out
-                     return
-
-             device = wp.device_from_jax(get_jax_device())
-             stream = wp.Stream(device, cuda_stream=cuda_stream)
-
-             # reconstruct the argument list
-             arg_list = []
-
-             # input and in-out args
-             for i, arg in enumerate(self.input_args):
-                 if arg.is_array:
-                     buffer = inputs[i].contents
-                     shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
-                     arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
-                     arg_list.append(arr)
-                 else:
-                     # scalar argument, get stashed value
-                     value = call_desc.static_inputs[arg.name]
-                     arg_list.append(value)
-
-             # pure output args (skip in-out FFI buffers)
-             for i, arg in enumerate(self.output_args):
-                 buffer = outputs[i + self.num_in_out].contents
-                 shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
-                 arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
-                 arg_list.append(arr)
-
-             # call the Python function with reconstructed arguments
-             with wp.ScopedStream(stream, sync_enter=False):
-                 if stream.is_capturing:
-                     # capturing with JAX
-                     with wp.ScopedCapture(external=True) as capture:
-                         self.func(*arg_list)
-                     # keep a reference to the capture object to prevent required modules getting unloaded
-                     call_desc.capture = capture
-                 elif self.graph_mode == GraphMode.WARP:
-                     # capturing with WARP
-                     with wp.ScopedCapture() as capture:
-                         self.func(*arg_list)
-                     wp.capture_launch(capture.graph)
-                     # keep a reference to the capture object and reuse it with same buffers
-                     call_desc.captures[buffer_hash] = capture
-                 else:
-                     # not capturing
-                     self.func(*arg_list)
-
-         except Exception as e:
-             print(traceback.format_exc())
-             return create_ffi_error(
-                 call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
-             )
-
-         return None
-
-
- # Holders for the custom callbacks to keep them alive.
- _FFI_CALLABLE_REGISTRY: dict[str, FfiCallable] = {}
- _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
- _FFI_REGISTRY_LOCK = threading.Lock()
-
-
- def jax_kernel(
-     kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
- ):
-     """Create a JAX callback from a Warp kernel.
-
-     NOTE: This is an experimental feature under development.
-
-     Args:
-         kernel: The Warp kernel to launch.
-         num_outputs: Optional. Specify the number of output arguments if greater than 1.
-             This must include the number of ``in_out_arguments``.
-         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
-             This argument can also be specified for individual calls.
-         launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
-             dimensions are inferred from the shape of the first array argument.
-             This argument can also be specified for individual calls.
-         output_dims: Optional. Specify the default dimensions of output arrays. If None, output
-             dimensions are inferred from the launch dimensions.
-             This argument can also be specified for individual calls.
-         in_out_argnames: Optional. Names of input-output arguments.
-
-     Limitations:
-         - All kernel arguments must be contiguous arrays or scalars.
-         - Scalars must be static arguments in JAX.
-         - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
-         - There must be at least one output or input-output argument.
-         - Only the CUDA backend is supported.
-     """
-
-     check_jax_version()
-
-     key = (
-         kernel.func,
-         kernel.sig,
-         num_outputs,
-         vmap_method,
-         tuple(launch_dims) if launch_dims else launch_dims,
-         tuple(sorted(output_dims.items())) if output_dims else output_dims,
-     )
-
-     with _FFI_REGISTRY_LOCK:
-         if key not in _FFI_KERNEL_REGISTRY:
-             new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
-             _FFI_KERNEL_REGISTRY[key] = new_kernel
-
-     return _FFI_KERNEL_REGISTRY[key]
-
-
- def jax_callable(
-     func: Callable,
-     num_outputs: int = 1,
-     graph_compatible: Optional[bool] = None,  # deprecated
-     graph_mode: GraphMode = GraphMode.JAX,
-     vmap_method: Optional[str] = "broadcast_all",
-     output_dims=None,
-     in_out_argnames=None,
- ):
-     """Create a JAX callback from an annotated Python function.
-
-     The Python function arguments must have type annotations like Warp kernels.
-
-     NOTE: This is an experimental feature under development.
-
-     Args:
-         func: The Python function to call.
-         num_outputs: Optional. Specify the number of output arguments if greater than 1.
-             This must include the number of ``in_out_arguments``.
-         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
-             This argument is deprecated, use ``graph_mode`` instead.
-         graph_mode: Optional. CUDA graph capture mode.
-             ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
-             ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
-                 such as when the callable uses conditional graph nodes.
-             ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
-                 such as host synchronization.
-         vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
-             This argument can also be specified for individual calls.
-         output_dims: Optional. Specify the default dimensions of output arrays.
-             If ``None``, output dimensions are inferred from the launch dimensions.
-             This argument can also be specified for individual calls.
-         in_out_argnames: Optional. Names of input-output arguments.
-
-     Limitations:
-         - All kernel arguments must be contiguous arrays or scalars.
-         - Scalars must be static arguments in JAX.
-         - Input and input-output arguments must precede the output arguments in the ``func`` definition.
-         - There must be at least one output or input-output argument.
-         - Only the CUDA backend is supported.
-     """
-
-     check_jax_version()
-
-     if graph_compatible is not None:
-         wp.utils.warn(
-             "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
-             DeprecationWarning,
-             stacklevel=3,
-         )
-         if graph_compatible is False:
-             graph_mode = GraphMode.NONE
-
-     key = (
-         func,
-         num_outputs,
-         graph_mode,
-         vmap_method,
-         tuple(sorted(output_dims.items())) if output_dims else output_dims,
-     )
-
-     with _FFI_REGISTRY_LOCK:
-         if key not in _FFI_CALLABLE_REGISTRY:
-             new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
-             _FFI_CALLABLE_REGISTRY[key] = new_callable
-
-     return _FFI_CALLABLE_REGISTRY[key]
-
-
- ###############################################################################
- #
- # Generic FFI callbacks for Python functions of the form
- # func(inputs, outputs, attrs, ctx)
- #
- ###############################################################################
-
-
- def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
-     """Create a JAX callback from a Python function.
-
-     The Python function must have the form ``func(inputs, outputs, attrs, ctx)``.
-
-     NOTE: This is an experimental feature under development.
-
-     Args:
-         name: A unique FFI callback name.
-         func: The Python function to call.
-         graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
-     """
-
-     check_jax_version()
-
-     # TODO check that the name is not already registered
-
-     def ffi_callback(call_frame):
-         try:
-             extension = call_frame.contents.extension_start
-             # On the first call, XLA runtime will query the API version and traits
-             # metadata using the |extension| field. Let us respond to that query
-             # if the metadata extension is present.
-             if extension:
-                 # Try to set the version metadata.
-                 if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
-                     metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
-                     metadata_ext.contents.metadata.contents.api_version.major_version = 0
-                     metadata_ext.contents.metadata.contents.api_version.minor_version = 1
-                     if graph_compatible:
-                         # Turn on CUDA graphs for this handler.
-                         metadata_ext.contents.metadata.contents.traits = (
-                             XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
-                         )
-                 return None
-
-             attrs = decode_attrs(call_frame.contents.attrs)
-
-             input_count = call_frame.contents.args.size
-             inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
-             inputs = [FfiBuffer(inputs[i].contents) for i in range(input_count)]
-
-             output_count = call_frame.contents.rets.size
-             outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
-             outputs = [FfiBuffer(outputs[i].contents) for i in range(output_count)]
-
-             ctx = ExecutionContext(call_frame.contents)
-
-             func(inputs, outputs, attrs, ctx)
-         except Exception as e:
-             print(traceback.format_exc())
-             return create_ffi_error(
-                 call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
-             )
-
-         return None
-
-     FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
-     callback_func = FFI_CCALLFUNC(ffi_callback)
-     with _FFI_REGISTRY_LOCK:
-         _FFI_CALLABLE_REGISTRY[name] = callback_func
-     ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
-     ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
-     jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
-
-
- ###############################################################################
- #
- # Utilities
- #
- ###############################################################################
-
- # ensure unique FFI callback names
- ffi_name_counts = {}
-
-
- def generate_unique_name(func) -> str:
-     key = make_full_qualified_name(func)
-     unique_id = ffi_name_counts.get(key, 0)
-     ffi_name_counts[key] = unique_id + 1
-     return f"{key}_{unique_id}"
-
-
- def get_warp_shape(arg, dims):
-     if arg.dtype_ndim > 0:
-         # vector/matrix array
-         return dims[: arg.warp_ndim]
-     else:
-         # scalar array
-         return dims
-
-
- def get_jax_output_type(arg, dims):
-     if isinstance(dims, int):
-         dims = (dims,)
-
-     ndim = len(dims)
-
-     if arg.dtype_ndim > 0:
-         # vector/matrix array
-         if ndim == arg.warp_ndim:
-             return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)
-         elif ndim == arg.jax_ndim:
-             # make sure inner dimensions match
-             inner_dims = dims[-arg.dtype_ndim :]
-             for i in range(arg.dtype_ndim):
-                 if inner_dims[i] != arg.dtype_shape[i]:
-                     raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
-             return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
-         else:
-             raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
-     else:
-         # scalar array
-         if ndim != arg.warp_ndim:
-             raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
-         return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
+     return get_deprecated_api(_ffi, "wp.jax_experimental", name)
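The hunk above shows the pattern used throughout this release: the public module warp/jax_experimental/ffi.py is reduced to a thin shim that re-exports the supported entry points (jax_kernel, jax_callable, register_ffi_callback, GraphMode, the graph-cache helpers) from the private warp._src.jax_experimental.ffi module, while a module-level __getattr__ routes any other attribute access through get_deprecated_api. The sketch below illustrates that PEP 562 re-export-plus-deprecation pattern in isolation. It is a minimal, self-contained illustration, not the Warp implementation: the _impl stand-in, the warning text, and the assumption that get_deprecated_api warns and then forwards to the private module are all hypothetical.

    import types
    import warnings

    # Stand-in for the private implementation module
    # (warp._src.jax_experimental.ffi in the real package).
    _impl = types.SimpleNamespace(
        jax_kernel=lambda kernel, **kwargs: f"wrapped {kernel}",
        legacy_helper=lambda: "still reachable, but deprecated",
    )

    # Explicit re-export keeps the supported public name importable from the old path.
    jax_kernel = _impl.jax_kernel


    def __getattr__(name):
        # PEP 562 module-level __getattr__: called only for names not defined above,
        # so legacy attributes still resolve, but with a deprecation warning.
        if hasattr(_impl, name):
            warnings.warn(
                f"'{name}' is deprecated on this module; it now lives in the private implementation module",
                DeprecationWarning,
                stacklevel=2,
            )
            return getattr(_impl, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

In this scheme, existing user code that imports the documented names from warp.jax_experimental.ffi keeps working unchanged, while access to anything that moved into warp/_src/ is still possible but surfaces a deprecation warning instead of breaking outright.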