warp-lang 1.9.0-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of warp-lang might be problematic.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
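
The bulk of this release is a layout change: implementation modules move under warp/_src/ (entries 3–89 above), and the old public modules shrink to thin re-export shims that keep the documented API importable from its previous location. Below is a minimal sketch of that shim pattern, using a hypothetical module pair warp/foo.py and warp/_src/foo.py and a simplified stand-in for warp._src.utils.get_deprecated_api (whose exact behavior is not shown in this diff):

# warp/foo.py -- public shim over a hypothetical private module warp/_src/foo.py.
# Explicit re-exports keep the supported names importable from the old location.
from warp._src.foo import bar as bar

from warp._src import foo as _impl


def __getattr__(name):
    # PEP 562 module-level __getattr__: any other attribute access falls through
    # to the private module with a deprecation warning (simplified stand-in for
    # get_deprecated_api).
    import warnings

    attr = getattr(_impl, name)  # raises AttributeError for unknown names
    warnings.warn(
        f"warp.foo.{name} is deprecated; use the documented public API instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return attr

The hunk below, which matches the +17/-853 counts of warp/jax_experimental/ffi.py (entry 136), shows the actual version of this pattern:
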
@@ -13,863 +13,27 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import ctypes
- import threading
- import traceback
- from enum import IntEnum
- from typing import Callable, Optional
+ # isort: skip_file

- import jax
+ from warp._src.jax_experimental.ffi import GraphMode as GraphMode
+ from warp._src.jax_experimental.ffi import jax_kernel as jax_kernel
+ from warp._src.jax_experimental.ffi import jax_callable as jax_callable
+ from warp._src.jax_experimental.ffi import register_ffi_callback as register_ffi_callback

- import warp as wp
- from warp.codegen import get_full_arg_spec, make_full_qualified_name
- from warp.jax import get_jax_device
- from warp.types import array_t, launch_bounds_t, strides_from_shape, type_to_warp
+ from warp._src.jax_experimental.ffi import (
+     get_jax_callable_default_graph_cache_max as get_jax_callable_default_graph_cache_max,
+ )
+ from warp._src.jax_experimental.ffi import (
+     set_jax_callable_default_graph_cache_max as set_jax_callable_default_graph_cache_max,
+ )
+ from warp._src.jax_experimental.ffi import clear_jax_callable_graph_cache as clear_jax_callable_graph_cache

- from .xla_ffi import *
+ # TODO: Remove after cleaning up the public API.

+ from warp._src.jax_experimental import ffi as _ffi

- class GraphMode(IntEnum):
-     NONE = 0 # don't capture a graph
-     JAX = 1 # let JAX capture a graph
-     WARP = 2 # let Warp capture a graph

+ def __getattr__(name):
+     from warp._src.utils import get_deprecated_api

38
- class FfiArg:
39
- def __init__(self, name, type, in_out=False):
40
- self.name = name
41
- self.type = type
42
- self.in_out = in_out
43
- self.is_array = isinstance(type, wp.array)
44
-
45
- if self.is_array:
46
- if hasattr(type.dtype, "_wp_scalar_type_"):
47
- self.dtype_shape = type.dtype._shape_
48
- self.dtype_ndim = len(self.dtype_shape)
49
- self.jax_scalar_type = wp.dtype_to_jax(type.dtype._wp_scalar_type_)
50
- self.jax_ndim = type.ndim + self.dtype_ndim
51
- elif type.dtype in wp.types.value_types:
52
- self.dtype_ndim = 0
53
- self.dtype_shape = ()
54
- self.jax_scalar_type = wp.dtype_to_jax(type.dtype)
55
- self.jax_ndim = type.ndim
56
- else:
57
- raise TypeError(f"Invalid data type for array argument '{name}', expected scalar, vector, or matrix")
58
- self.warp_ndim = type.ndim
59
- elif type in wp.types.value_types:
60
- self.dtype_ndim = 0
61
- self.dtype_shape = ()
62
- self.jax_scalar_type = wp.dtype_to_jax(type_to_warp(type))
63
- self.jax_ndim = 0
64
- self.warp_ndim = 0
65
- else:
66
- raise TypeError(f"Invalid type for argument '{name}', expected array or scalar, got {type}")
67
-
68
-
69
- class FfiLaunchDesc:
70
- def __init__(self, static_inputs, launch_dims):
71
- self.static_inputs = static_inputs
72
- self.launch_dims = launch_dims
73
-
74
-
75
- class FfiKernel:
76
- def __init__(self, kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames):
77
- self.kernel = kernel
78
- self.name = generate_unique_name(kernel.func)
79
- self.num_outputs = num_outputs
80
- self.vmap_method = vmap_method
81
- self.launch_dims = launch_dims
82
- self.output_dims = output_dims
83
- self.first_array_arg = None
84
- self.launch_id = 0
85
- self.launch_descriptors = {}
86
-
87
- in_out_argnames_list = in_out_argnames or []
88
- in_out_argnames = set(in_out_argnames_list)
89
- if len(in_out_argnames_list) != len(in_out_argnames):
90
- raise AssertionError("in_out_argnames must not contain duplicate names")
91
-
92
- self.num_kernel_args = len(kernel.adj.args)
93
- self.num_in_out = len(in_out_argnames)
94
- self.num_inputs = self.num_kernel_args - num_outputs + self.num_in_out
95
- if self.num_outputs < 1:
96
- raise ValueError("At least one output is required")
97
- if self.num_outputs > self.num_kernel_args:
98
- raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
99
- if self.num_outputs < self.num_in_out:
100
- raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
101
-
102
- # process input args
103
- self.input_args = []
104
- for i in range(self.num_inputs):
105
- arg_name = kernel.adj.args[i].label
106
- arg = FfiArg(arg_name, kernel.adj.args[i].type, arg_name in in_out_argnames)
107
- if arg_name in in_out_argnames:
108
- in_out_argnames.remove(arg_name)
109
- if arg.is_array:
110
- # keep track of the first input array argument
111
- if self.first_array_arg is None:
112
- self.first_array_arg = i
113
- self.input_args.append(arg)
114
-
115
- # process output args
116
- self.output_args = []
117
- for i in range(self.num_inputs, self.num_kernel_args):
118
- arg_name = kernel.adj.args[i].label
119
- if arg_name in in_out_argnames:
120
- raise AssertionError(
121
- f"Expected an output-only argument for argument {arg_name}."
122
- " in_out arguments should be placed before output-only arguments."
123
- )
124
- arg = FfiArg(arg_name, kernel.adj.args[i].type, False)
125
- if not arg.is_array:
126
- raise TypeError("All output arguments must be arrays")
127
- self.output_args.append(arg)
128
-
129
- if in_out_argnames:
130
- raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
131
-
132
- # Build input output aliases.
133
- out_id = 0
134
- input_output_aliases = {}
135
- for in_id, arg in enumerate(self.input_args):
136
- if not arg.in_out:
137
- continue
138
- input_output_aliases[in_id] = out_id
139
- out_id += 1
140
- self.input_output_aliases = input_output_aliases
141
-
142
- # register the callback
143
- FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
144
- self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
145
- ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
146
- ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
147
- jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
148
-
149
- def __call__(self, *args, output_dims=None, launch_dims=None, vmap_method=None):
150
- num_inputs = len(args)
151
- if num_inputs != self.num_inputs:
152
- raise ValueError(f"Expected {self.num_inputs} inputs, but got {num_inputs}")
153
-
154
- # default argument fallback
155
- if launch_dims is None:
156
- launch_dims = self.launch_dims
157
- if output_dims is None:
158
- output_dims = self.output_dims
159
- if vmap_method is None:
160
- vmap_method = self.vmap_method
161
-
162
- # output types
163
- out_types = []
164
-
165
- # process inputs
166
- static_inputs = {}
167
- for i in range(num_inputs):
168
- input_arg = self.input_args[i]
169
- input_value = args[i]
170
- if input_arg.is_array:
171
- # check dtype
172
- if input_value.dtype != input_arg.jax_scalar_type:
173
- raise TypeError(
174
- f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
175
- )
176
- # check ndim
177
- if input_value.ndim != input_arg.jax_ndim:
178
- raise TypeError(
179
- f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
180
- )
181
- # check inner dims
182
- for d in range(input_arg.dtype_ndim):
183
- if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
184
- raise TypeError(
185
- f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
186
- )
187
- else:
188
- # make sure scalar is not a traced variable, should be static
189
- if isinstance(input_value, jax.core.Tracer):
190
- raise ValueError(f"Argument '{input_arg.name}' must be a static value")
191
- # stash the value to be retrieved by callback
192
- static_inputs[input_arg.name] = input_arg.type(input_value)
193
-
194
- # append in-out arg to output types
195
- if input_arg.in_out:
196
- out_types.append(get_jax_output_type(input_arg, input_value.shape))
197
-
198
- # launch dimensions
199
- if launch_dims is None:
200
- # use the shape of the first input array
201
- if self.first_array_arg is not None:
202
- launch_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
203
- else:
204
- raise RuntimeError("Failed to determine launch dimensions")
205
- elif isinstance(launch_dims, int):
206
- launch_dims = (launch_dims,)
207
- else:
208
- launch_dims = tuple(launch_dims)
209
-
210
- # output shapes
211
- if isinstance(output_dims, dict):
212
- # assume a dictionary of shapes keyed on argument name
213
- for output_arg in self.output_args:
214
- dims = output_dims.get(output_arg.name)
215
- if dims is None:
216
- raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
217
- out_types.append(get_jax_output_type(output_arg, dims))
218
- else:
219
- if output_dims is None:
220
- # use launch dimensions
221
- output_dims = launch_dims
222
- elif isinstance(output_dims, int):
223
- output_dims = (output_dims,)
224
- # assume same dimensions for all outputs
225
- for output_arg in self.output_args:
226
- out_types.append(get_jax_output_type(output_arg, output_dims))
227
-
228
- call = jax.ffi.ffi_call(
229
- self.name,
230
- out_types,
231
- vmap_method=vmap_method,
232
- input_output_aliases=self.input_output_aliases,
233
- )
234
-
235
- # ensure the kernel module is loaded before the callback, otherwise graph capture may fail
236
- device = wp.device_from_jax(get_jax_device())
237
- self.kernel.module.load(device)
238
-
239
- # save launch data to be retrieved by callback
240
- launch_id = self.launch_id
241
- self.launch_descriptors[launch_id] = FfiLaunchDesc(static_inputs, launch_dims)
242
- self.launch_id += 1
243
-
244
- return call(*args, launch_id=launch_id)
245
-
246
- def ffi_callback(self, call_frame):
247
- try:
248
- # On the first call, XLA runtime will query the API version and traits
249
- # metadata using the |extension| field. Let us respond to that query
250
- # if the metadata extension is present.
251
- extension = call_frame.contents.extension_start
252
- if extension:
253
- # Try to set the version metadata.
254
- if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
255
- metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
256
- metadata_ext.contents.metadata.contents.api_version.major_version = 0
257
- metadata_ext.contents.metadata.contents.api_version.minor_version = 1
258
- # Turn on CUDA graphs for this handler.
259
- metadata_ext.contents.metadata.contents.traits = (
260
- XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
261
- )
262
- return None
263
-
264
- # retrieve call info
265
- attrs = decode_attrs(call_frame.contents.attrs)
266
- launch_id = int(attrs["launch_id"])
267
- launch_desc = self.launch_descriptors[launch_id]
268
-
269
- num_inputs = call_frame.contents.args.size
270
- inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
271
-
272
- num_outputs = call_frame.contents.rets.size
273
- outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
274
-
275
- assert num_inputs == self.num_inputs
276
- assert num_outputs == self.num_outputs
277
-
278
- launch_bounds = launch_bounds_t(launch_desc.launch_dims)
279
-
280
- # first kernel param is the launch bounds
281
- kernel_params = (ctypes.c_void_p * (1 + self.num_kernel_args))()
282
- kernel_params[0] = ctypes.addressof(launch_bounds)
283
-
284
- arg_refs = []
285
-
286
- # input and in-out args
287
- for i, input_arg in enumerate(self.input_args):
288
- if input_arg.is_array:
289
- buffer = inputs[i].contents
290
- shape = buffer.dims[: input_arg.type.ndim]
291
- strides = strides_from_shape(shape, input_arg.type.dtype)
292
- arg = array_t(buffer.data, 0, input_arg.type.ndim, shape, strides)
293
- kernel_params[i + 1] = ctypes.addressof(arg)
294
- arg_refs.append(arg) # keep a reference
295
- else:
296
- # scalar argument, get stashed value
297
- value = launch_desc.static_inputs[input_arg.name]
298
- arg = input_arg.type._type_(value)
299
- kernel_params[i + 1] = ctypes.addressof(arg)
300
- arg_refs.append(arg) # keep a reference
301
-
302
- # pure output args (skip in-out FFI buffers)
303
- for i, output_arg in enumerate(self.output_args):
304
- buffer = outputs[i + self.num_in_out].contents
305
- shape = buffer.dims[: output_arg.type.ndim]
306
- strides = strides_from_shape(shape, output_arg.type.dtype)
307
- arg = array_t(buffer.data, 0, output_arg.type.ndim, shape, strides)
308
- kernel_params[num_inputs + i + 1] = ctypes.addressof(arg)
309
- arg_refs.append(arg) # keep a reference
310
-
311
- # get device and stream
312
- device = wp.device_from_jax(get_jax_device())
313
- stream = get_stream_from_callframe(call_frame.contents)
314
-
315
- # get kernel hooks
316
- hooks = self.kernel.module.get_kernel_hooks(self.kernel, device)
317
- assert hooks.forward, "Failed to find kernel entry point"
318
-
319
- # launch the kernel
320
- wp.context.runtime.core.wp_cuda_launch_kernel(
321
- device.context,
322
- hooks.forward,
323
- launch_bounds.size,
324
- 0,
325
- 256,
326
- hooks.forward_smem_bytes,
327
- kernel_params,
328
- stream,
329
- )
330
-
331
- except Exception as e:
332
- print(traceback.format_exc())
333
- return create_ffi_error(
334
- call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
335
- )
336
-
337
-
338
- class FfiCallDesc:
339
- def __init__(self, static_inputs):
340
- self.static_inputs = static_inputs
341
- self.captures = {}
342
-
343
-
344
- class FfiCallable:
345
- def __init__(self, func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames):
346
- self.func = func
347
- self.name = generate_unique_name(func)
348
- self.num_outputs = num_outputs
349
- self.vmap_method = vmap_method
350
- self.graph_mode = graph_mode
351
- self.output_dims = output_dims
352
- self.first_array_arg = None
353
- self.call_id = 0
354
- self.call_descriptors = {}
355
-
356
- in_out_argnames_list = in_out_argnames or []
357
- in_out_argnames = set(in_out_argnames_list)
358
- if len(in_out_argnames_list) != len(in_out_argnames):
359
- raise AssertionError("in_out_argnames must not contain duplicate names")
360
-
361
- # get arguments and annotations
362
- argspec = get_full_arg_spec(func)
363
-
364
- num_args = len(argspec.args)
365
- self.num_in_out = len(in_out_argnames)
366
- self.num_inputs = num_args - num_outputs + self.num_in_out
367
- if self.num_outputs < 1:
368
- raise ValueError("At least one output is required")
369
- if self.num_outputs > num_args:
370
- raise ValueError("Number of outputs cannot be greater than the number of kernel arguments")
371
- if self.num_outputs < self.num_in_out:
372
- raise ValueError("Number of outputs cannot be smaller than the number of in_out_argnames")
373
-
374
- if len(argspec.annotations) < num_args:
375
- raise RuntimeError(f"Incomplete argument annotations on function {self.name}")
376
-
377
- # parse type annotations
378
- self.args = []
379
- arg_idx = 0
380
- for arg_name, arg_type in argspec.annotations.items():
381
- if arg_name == "return":
382
- if arg_type is not None:
383
- raise TypeError("Function must not return a value")
384
- continue
385
- else:
386
- arg = FfiArg(arg_name, arg_type, arg_name in in_out_argnames)
387
- if arg_name in in_out_argnames:
388
- in_out_argnames.remove(arg_name)
389
- if arg.is_array:
390
- if arg_idx < self.num_inputs and self.first_array_arg is None:
391
- self.first_array_arg = arg_idx
392
- self.args.append(arg)
393
-
394
- if arg.in_out and arg_idx >= self.num_inputs:
395
- raise AssertionError(
396
- f"Expected an output-only argument for argument {arg_name}."
397
- " in_out arguments should be placed before output-only arguments."
398
- )
399
-
400
- arg_idx += 1
401
-
402
- if in_out_argnames:
403
- raise ValueError(f"in_out_argnames: '{in_out_argnames}' did not match any function argument names.")
404
-
405
- self.input_args = self.args[: self.num_inputs] # includes in-out args
406
- self.output_args = self.args[self.num_inputs :] # pure output args
407
-
408
- # Buffer indices for array arguments in callback.
409
- # In-out buffers are the same pointers in the XLA call frame,
410
- # so we only include them for inputs and skip them for outputs.
411
- self.array_input_indices = [i for i, arg in enumerate(self.input_args) if arg.is_array]
412
- self.array_output_indices = list(range(self.num_in_out, self.num_outputs))
413
-
414
- # Build input output aliases.
415
- out_id = 0
416
- input_output_aliases = {}
417
- for in_id, arg in enumerate(self.input_args):
418
- if not arg.in_out:
419
- continue
420
- input_output_aliases[in_id] = out_id
421
- out_id += 1
422
- self.input_output_aliases = input_output_aliases
423
-
424
- # register the callback
425
- FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
426
- self.callback_func = FFI_CCALLFUNC(lambda call_frame: self.ffi_callback(call_frame))
427
- ffi_ccall_address = ctypes.cast(self.callback_func, ctypes.c_void_p)
428
- ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
429
- jax.ffi.register_ffi_target(self.name, ffi_capsule, platform="CUDA")
430
-
431
- def __call__(self, *args, output_dims=None, vmap_method=None):
432
- num_inputs = len(args)
433
- if num_inputs != self.num_inputs:
434
- input_names = ", ".join(arg.name for arg in self.input_args)
435
- s = "" if self.num_inputs == 1 else "s"
436
- raise ValueError(f"Expected {self.num_inputs} input{s} ({input_names}), but got {num_inputs}")
437
-
438
- # default argument fallback
439
- if vmap_method is None:
440
- vmap_method = self.vmap_method
441
- if output_dims is None:
442
- output_dims = self.output_dims
443
-
444
- # output types
445
- out_types = []
446
-
447
- # process inputs
448
- static_inputs = {}
449
- for i in range(num_inputs):
450
- input_arg = self.input_args[i]
451
- input_value = args[i]
452
- if input_arg.is_array:
453
- # check dtype
454
- if input_value.dtype != input_arg.jax_scalar_type:
455
- raise TypeError(
456
- f"Invalid data type for array argument '{input_arg.name}', expected {input_arg.jax_scalar_type}, got {input_value.dtype}"
457
- )
458
- # check ndim
459
- if input_value.ndim != input_arg.jax_ndim:
460
- raise TypeError(
461
- f"Invalid dimensionality for array argument '{input_arg.name}', expected {input_arg.jax_ndim} dimensions, got {input_value.ndim}"
462
- )
463
- # check inner dims
464
- for d in range(input_arg.dtype_ndim):
465
- if input_value.shape[input_arg.type.ndim + d] != input_arg.dtype_shape[d]:
466
- raise TypeError(
467
- f"Invalid inner dimensions for array argument '{input_arg.name}', expected {input_arg.dtype_shape}, got {input_value.shape[-input_arg.dtype_ndim :]}"
468
- )
469
- else:
470
- # make sure scalar is not a traced variable, should be static
471
- if isinstance(input_value, jax.core.Tracer):
472
- raise ValueError(f"Argument '{input_arg.name}' must be a static value")
473
- # stash the value to be retrieved by callback
474
- static_inputs[input_arg.name] = input_arg.type(input_value)
475
-
476
- # append in-out arg to output types
477
- if input_arg.in_out:
478
- out_types.append(get_jax_output_type(input_arg, input_value.shape))
479
-
480
- # output shapes
481
- if isinstance(output_dims, dict):
482
- # assume a dictionary of shapes keyed on argument name
483
- for output_arg in self.output_args:
484
- dims = output_dims.get(output_arg.name)
485
- if dims is None:
486
- raise ValueError(f"Missing output dimensions for argument '{output_arg.name}'")
487
- out_types.append(get_jax_output_type(output_arg, dims))
488
- else:
489
- if output_dims is None:
490
- if self.first_array_arg is None:
491
- raise ValueError("Unable to determine output dimensions")
492
- output_dims = get_warp_shape(self.input_args[self.first_array_arg], args[self.first_array_arg].shape)
493
- elif isinstance(output_dims, int):
494
- output_dims = (output_dims,)
495
- # assume same dimensions for all outputs
496
- for output_arg in self.output_args:
497
- out_types.append(get_jax_output_type(output_arg, output_dims))
498
-
499
- call = jax.ffi.ffi_call(
500
- self.name,
501
- out_types,
502
- vmap_method=vmap_method,
503
- input_output_aliases=self.input_output_aliases,
504
- # has_side_effect=True, # force this function to execute even if outputs aren't used
505
- )
506
-
507
- # load the module
508
- # NOTE: if the target function uses kernels from different modules, they will not be loaded here
509
- device = wp.device_from_jax(get_jax_device())
510
- module = wp.get_module(self.func.__module__)
511
- module.load(device)
512
-
513
- # save call data to be retrieved by callback
514
- call_id = self.call_id
515
- self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
516
- self.call_id += 1
517
- return call(*args, call_id=call_id)
518
-
519
- def ffi_callback(self, call_frame):
520
- try:
521
- # On the first call, XLA runtime will query the API version and traits
522
- # metadata using the |extension| field. Let us respond to that query
523
- # if the metadata extension is present.
524
- extension = call_frame.contents.extension_start
525
- if extension:
526
- # Try to set the version metadata.
527
- if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
528
- metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
529
- metadata_ext.contents.metadata.contents.api_version.major_version = 0
530
- metadata_ext.contents.metadata.contents.api_version.minor_version = 1
531
- # Turn on CUDA graphs for this handler.
532
- if self.graph_mode is GraphMode.JAX:
533
- metadata_ext.contents.metadata.contents.traits = (
534
- XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
535
- )
536
- return None
537
-
538
- # retrieve call info
539
- # NOTE: this assumes that there's only one attribute - call_id (int64).
540
- # A more general but slower approach is this:
541
- # attrs = decode_attrs(call_frame.contents.attrs)
542
- # call_id = int(attrs["call_id"])
543
- attr = ctypes.cast(call_frame.contents.attrs.attrs[0], ctypes.POINTER(XLA_FFI_Scalar)).contents
544
- call_id = ctypes.cast(attr.value, ctypes.POINTER(ctypes.c_int64)).contents.value
545
- call_desc = self.call_descriptors[call_id]
546
-
547
- num_inputs = call_frame.contents.args.size
548
- inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
549
-
550
- num_outputs = call_frame.contents.rets.size
551
- outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
552
-
553
- assert num_inputs == self.num_inputs
554
- assert num_outputs == self.num_outputs
555
-
556
- cuda_stream = get_stream_from_callframe(call_frame.contents)
557
-
558
- if self.graph_mode == GraphMode.WARP:
559
- # check if we already captured an identical call
560
- ip = [inputs[i].contents.data for i in self.array_input_indices]
561
- op = [outputs[i].contents.data for i in self.array_output_indices]
562
- buffer_hash = hash((*ip, *op))
563
- capture = call_desc.captures.get(buffer_hash)
564
-
565
- # launch existing graph
566
- if capture is not None:
567
- # NOTE: We use the native graph API to avoid overhead with obtaining Stream and Device objects in Python.
568
- # This code should match wp.capture_launch().
569
- graph = capture.graph
570
- if graph.graph_exec is None:
571
- g = ctypes.c_void_p()
572
- if not wp.context.runtime.core.wp_cuda_graph_create_exec(
573
- graph.device.context, cuda_stream, graph.graph, ctypes.byref(g)
574
- ):
575
- raise RuntimeError(f"Graph creation error: {wp.context.runtime.get_error_string()}")
576
- graph.graph_exec = g
577
-
578
- if not wp.context.runtime.core.wp_cuda_graph_launch(graph.graph_exec, cuda_stream):
579
- raise RuntimeError(f"Graph launch error: {wp.context.runtime.get_error_string()}")
580
-
581
- # early out
582
- return
583
-
584
- device = wp.device_from_jax(get_jax_device())
585
- stream = wp.Stream(device, cuda_stream=cuda_stream)
586
-
587
- # reconstruct the argument list
588
- arg_list = []
589
-
590
- # input and in-out args
591
- for i, arg in enumerate(self.input_args):
592
- if arg.is_array:
593
- buffer = inputs[i].contents
594
- shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
595
- arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
596
- arg_list.append(arr)
597
- else:
598
- # scalar argument, get stashed value
599
- value = call_desc.static_inputs[arg.name]
600
- arg_list.append(value)
601
-
602
- # pure output args (skip in-out FFI buffers)
603
- for i, arg in enumerate(self.output_args):
604
- buffer = outputs[i + self.num_in_out].contents
605
- shape = buffer.dims[: buffer.rank - arg.dtype_ndim]
606
- arr = wp.array(ptr=buffer.data, dtype=arg.type.dtype, shape=shape, device=device)
607
- arg_list.append(arr)
608
-
609
- # call the Python function with reconstructed arguments
610
- with wp.ScopedStream(stream, sync_enter=False):
611
- if stream.is_capturing:
612
- # capturing with JAX
613
- with wp.ScopedCapture(external=True) as capture:
614
- self.func(*arg_list)
615
- # keep a reference to the capture object to prevent required modules getting unloaded
616
- call_desc.capture = capture
617
- elif self.graph_mode == GraphMode.WARP:
618
- # capturing with WARP
619
- with wp.ScopedCapture() as capture:
620
- self.func(*arg_list)
621
- wp.capture_launch(capture.graph)
622
- # keep a reference to the capture object and reuse it with same buffers
623
- call_desc.captures[buffer_hash] = capture
624
- else:
625
- # not capturing
626
- self.func(*arg_list)
627
-
628
- except Exception as e:
629
- print(traceback.format_exc())
630
- return create_ffi_error(
631
- call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
632
- )
633
-
634
- return None
635
-
636
-
637
- # Holders for the custom callbacks to keep them alive.
638
- _FFI_CALLABLE_REGISTRY: dict[str, FfiCallable] = {}
639
- _FFI_KERNEL_REGISTRY: dict[str, FfiKernel] = {}
640
- _FFI_REGISTRY_LOCK = threading.Lock()
641
-
642
-
643
- def jax_kernel(
644
- kernel, num_outputs=1, vmap_method="broadcast_all", launch_dims=None, output_dims=None, in_out_argnames=None
645
- ):
646
- """Create a JAX callback from a Warp kernel.
647
-
648
- NOTE: This is an experimental feature under development.
649
-
650
- Args:
651
- kernel: The Warp kernel to launch.
652
- num_outputs: Optional. Specify the number of output arguments if greater than 1.
653
- This must include the number of ``in_out_arguments``.
654
- vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
655
- This argument can also be specified for individual calls.
656
- launch_dims: Optional. Specify the default kernel launch dimensions. If None, launch
657
- dimensions are inferred from the shape of the first array argument.
658
- This argument can also be specified for individual calls.
659
- output_dims: Optional. Specify the default dimensions of output arrays. If None, output
660
- dimensions are inferred from the launch dimensions.
661
- This argument can also be specified for individual calls.
662
- in_out_argnames: Optional. Names of input-output arguments.
663
-
664
- Limitations:
665
- - All kernel arguments must be contiguous arrays or scalars.
666
- - Scalars must be static arguments in JAX.
667
- - Input and input-output arguments must precede the output arguments in the ``kernel`` definition.
668
- - There must be at least one output or input-output argument.
669
- - Only the CUDA backend is supported.
670
- """
671
- key = (
672
- kernel.func,
673
- num_outputs,
674
- vmap_method,
675
- tuple(launch_dims) if launch_dims else launch_dims,
676
- tuple(sorted(output_dims.items())) if output_dims else output_dims,
677
- )
678
-
679
- with _FFI_REGISTRY_LOCK:
680
- if key not in _FFI_KERNEL_REGISTRY:
681
- new_kernel = FfiKernel(kernel, num_outputs, vmap_method, launch_dims, output_dims, in_out_argnames)
682
- _FFI_KERNEL_REGISTRY[key] = new_kernel
683
-
684
- return _FFI_KERNEL_REGISTRY[key]
685
-
686
-
687
- def jax_callable(
688
- func: Callable,
689
- num_outputs: int = 1,
690
- graph_compatible: Optional[bool] = None, # deprecated
691
- graph_mode: GraphMode = GraphMode.JAX,
692
- vmap_method: Optional[str] = "broadcast_all",
693
- output_dims=None,
694
- in_out_argnames=None,
695
- ):
696
- """Create a JAX callback from an annotated Python function.
697
-
698
- The Python function arguments must have type annotations like Warp kernels.
699
-
700
- NOTE: This is an experimental feature under development.
701
-
702
- Args:
703
- func: The Python function to call.
704
- num_outputs: Optional. Specify the number of output arguments if greater than 1.
705
- This must include the number of ``in_out_arguments``.
706
- graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
707
- This argument is deprecated, use ``graph_mode`` instead.
708
- graph_mode: Optional. CUDA graph capture mode.
709
- ``GraphMode.JAX`` (default): Let JAX capture the graph, which may be used as a subgraph in an enclosing capture.
710
- ``GraphMode.WARP``: Let Warp capture the graph. Use this mode when the callable cannot be used as a subraph,
711
- such as when the callable uses conditional graph nodes.
712
- ``GraphMode.NONE``: Disable graph capture. Use when the callable performs operations that are not legal in a graph,
713
- such as host synchronization.
714
- vmap_method: Optional. String specifying how the callback transforms under ``vmap()``.
715
- This argument can also be specified for individual calls.
716
- output_dims: Optional. Specify the default dimensions of output arrays.
717
- If ``None``, output dimensions are inferred from the launch dimensions.
718
- This argument can also be specified for individual calls.
719
- in_out_argnames: Optional. Names of input-output arguments.
720
-
721
- Limitations:
722
- - All kernel arguments must be contiguous arrays or scalars.
723
- - Scalars must be static arguments in JAX.
724
- - Input and input-output arguments must precede the output arguments in the ``func`` definition.
725
- - There must be at least one output or input-output argument.
726
- - Only the CUDA backend is supported.
727
- """
728
-
729
- if graph_compatible is not None:
730
- wp.utils.warn(
731
- "The `graph_compatible` argument is deprecated, use `graph_mode` instead.",
732
- DeprecationWarning,
733
- stacklevel=3,
734
- )
735
- if graph_compatible is False:
736
- graph_mode = GraphMode.NONE
737
-
738
- key = (
739
- func,
740
- num_outputs,
741
- graph_mode,
742
- vmap_method,
743
- tuple(sorted(output_dims.items())) if output_dims else output_dims,
744
- )
745
-
746
- with _FFI_REGISTRY_LOCK:
747
- if key not in _FFI_CALLABLE_REGISTRY:
748
- new_callable = FfiCallable(func, num_outputs, graph_mode, vmap_method, output_dims, in_out_argnames)
749
- _FFI_CALLABLE_REGISTRY[key] = new_callable
750
-
751
- return _FFI_CALLABLE_REGISTRY[key]
752
-
753
-
754
- ###############################################################################
755
- #
756
- # Generic FFI callbacks for Python functions of the form
757
- # func(inputs, outputs, attrs, ctx)
758
- #
759
- ###############################################################################
760
-
761
-
762
- def register_ffi_callback(name: str, func: Callable, graph_compatible: bool = True) -> None:
763
- """Create a JAX callback from a Python function.
764
-
765
- The Python function must have the form ``func(inputs, outputs, attrs, ctx)``.
766
-
767
- NOTE: This is an experimental feature under development.
768
-
769
- Args:
770
- name: A unique FFI callback name.
771
- func: The Python function to call.
772
- graph_compatible: Optional. Whether the function can be called during CUDA graph capture.
773
- """
774
-
775
- # TODO check that the name is not already registered
776
-
777
- def ffi_callback(call_frame):
778
- try:
779
- extension = call_frame.contents.extension_start
780
- # On the first call, XLA runtime will query the API version and traits
781
- # metadata using the |extension| field. Let us respond to that query
782
- # if the metadata extension is present.
783
- if extension:
784
- # Try to set the version metadata.
785
- if extension.contents.type == XLA_FFI_Extension_Type.Metadata:
786
- metadata_ext = ctypes.cast(extension, ctypes.POINTER(XLA_FFI_Metadata_Extension))
787
- metadata_ext.contents.metadata.contents.api_version.major_version = 0
788
- metadata_ext.contents.metadata.contents.api_version.minor_version = 1
789
- if graph_compatible:
790
- # Turn on CUDA graphs for this handler.
791
- metadata_ext.contents.metadata.contents.traits = (
792
- XLA_FFI_Handler_TraitsBits.COMMAND_BUFFER_COMPATIBLE
793
- )
794
- return None
795
-
796
- attrs = decode_attrs(call_frame.contents.attrs)
797
-
798
- input_count = call_frame.contents.args.size
799
- inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
800
- inputs = [FfiBuffer(inputs[i].contents) for i in range(input_count)]
801
-
802
- output_count = call_frame.contents.rets.size
803
- outputs = ctypes.cast(call_frame.contents.rets.rets, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
804
- outputs = [FfiBuffer(outputs[i].contents) for i in range(output_count)]
805
-
806
- ctx = ExecutionContext(call_frame.contents)
807
-
808
- func(inputs, outputs, attrs, ctx)
809
- except Exception as e:
810
- print(traceback.format_exc())
811
- return create_ffi_error(
812
- call_frame.contents.api, XLA_FFI_Error_Code.UNKNOWN, f"FFI callback error: {type(e).__name__}: {e}"
813
- )
814
-
815
- return None
816
-
817
- FFI_CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_CallFrame))
818
- callback_func = FFI_CCALLFUNC(ffi_callback)
819
- with _FFI_REGISTRY_LOCK:
820
- _FFI_CALLABLE_REGISTRY[name] = callback_func
821
- ffi_ccall_address = ctypes.cast(callback_func, ctypes.c_void_p)
822
- ffi_capsule = jax.ffi.pycapsule(ffi_ccall_address.value)
823
- jax.ffi.register_ffi_target(name, ffi_capsule, platform="CUDA")
824
-
825
-
826
- ###############################################################################
827
- #
828
- # Utilities
829
- #
830
- ###############################################################################
831
-
832
- # ensure unique FFI callback names
833
- ffi_name_counts = {}
834
-
835
-
836
- def generate_unique_name(func) -> str:
837
- key = make_full_qualified_name(func)
838
- unique_id = ffi_name_counts.get(key, 0)
839
- ffi_name_counts[key] = unique_id + 1
840
- return f"{key}_{unique_id}"
841
-
842
-
843
- def get_warp_shape(arg, dims):
844
- if arg.dtype_ndim > 0:
845
- # vector/matrix array
846
- return dims[: arg.warp_ndim]
847
- else:
848
- # scalar array
849
- return dims
850
-
851
-
852
- def get_jax_output_type(arg, dims):
853
- if isinstance(dims, int):
854
- dims = (dims,)
855
-
856
- ndim = len(dims)
857
-
858
- if arg.dtype_ndim > 0:
859
- # vector/matrix array
860
- if ndim == arg.warp_ndim:
861
- return jax.ShapeDtypeStruct((*dims, *arg.dtype_shape), arg.jax_scalar_type)
862
- elif ndim == arg.jax_ndim:
863
- # make sure inner dimensions match
864
- inner_dims = dims[-arg.dtype_ndim :]
865
- for i in range(arg.dtype_ndim):
866
- if inner_dims[i] != arg.dtype_shape[i]:
867
- raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
868
- return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
869
- else:
870
- raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
871
- else:
872
- # scalar array
873
- if ndim != arg.warp_ndim:
874
- raise ValueError(f"Invalid output dimensions for argument '{arg.name}': {dims}")
875
- return jax.ShapeDtypeStruct(dims, arg.jax_scalar_type)
+     return get_deprecated_api(_ffi, "wp.jax_experimental", name)
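
For reference, a minimal usage sketch of the re-exported entry point, based on the jax_kernel docstring preserved in the removed code above. It assumes the 1.10.0 re-export keeps the documented 1.9.0 behavior (input arguments precede outputs in the kernel signature, launch and output dims are inferred from the first array argument, CUDA backend only); the doubling kernel and the shapes are illustrative, not taken from the package:

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel  # now re-exported from warp._src.jax_experimental.ffi


@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = 2.0 * x[i]


# With the default num_outputs=1, x is the input and y the output.
jax_scale = jax_kernel(scale)


@jax.jit
def f(x):
    # The wrapper returns a sequence with one entry per output argument.
    return jax_scale(x)[0]


# Requires a CUDA device (only the CUDA backend is supported per the docstring above):
# y = f(jnp.arange(16, dtype=jnp.float32))
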