warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2302 -307
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1546 -224
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +3 -3
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +581 -280
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +18 -17
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +580 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/_src/jax_experimental/xla_ffi.py (new file, +658 -0)
@@ -0,0 +1,658 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+ import enum
18
+
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+
24
+ _wp_module_name_ = "warp.jax_experimental.xla_ffi"
25
+
26
+ #######################################################################
27
+ # ctypes structures and enums for XLA's FFI API:
28
+ # https://github.com/openxla/xla/blob/a1a5e62fbffa3a3b6c409d72607456cf5b353a22/xla/ffi/api/c_api.h
29
+ #######################################################################
30
+
31
+
32
+ # typedef enum {
33
+ # XLA_FFI_Extension_Metadata = 1,
34
+ # } XLA_FFI_Extension_Type;
35
+ class XLA_FFI_Extension_Type(enum.IntEnum):
36
+ Metadata = 1
37
+
38
+
39
+ # typedef struct XLA_FFI_Extension_Base {
40
+ # size_t struct_size;
41
+ # XLA_FFI_Extension_Type type;
42
+ # struct XLA_FFI_Extension_Base* next;
43
+ # } XLA_FFI_Extension_Base;
44
+ class XLA_FFI_Extension_Base(ctypes.Structure):
45
+ pass
46
+
47
+
48
+ XLA_FFI_Extension_Base._fields_ = [
49
+ ("struct_size", ctypes.c_size_t),
50
+ ("type", ctypes.c_int), # XLA_FFI_Extension_Type
51
+ ("next", ctypes.POINTER(XLA_FFI_Extension_Base)),
52
+ ]
53
+
54
+
55
+ # typedef enum {
56
+ # XLA_FFI_ExecutionStage_INSTANTIATE = 0,
57
+ # XLA_FFI_ExecutionStage_PREPARE = 1,
58
+ # XLA_FFI_ExecutionStage_INITIALIZE = 2,
59
+ # XLA_FFI_ExecutionStage_EXECUTE = 3,
60
+ # } XLA_FFI_ExecutionStage;
61
+ class XLA_FFI_ExecutionStage(enum.IntEnum):
62
+ INSTANTIATE = 0
63
+ PREPARE = 1
64
+ INITIALIZE = 2
65
+ EXECUTE = 3
66
+
67
+
68
+ # typedef enum {
69
+ # XLA_FFI_DataType_INVALID = 0,
70
+ # XLA_FFI_DataType_PRED = 1,
71
+ # XLA_FFI_DataType_S8 = 2,
72
+ # XLA_FFI_DataType_S16 = 3,
73
+ # XLA_FFI_DataType_S32 = 4,
74
+ # XLA_FFI_DataType_S64 = 5,
75
+ # XLA_FFI_DataType_U8 = 6,
76
+ # XLA_FFI_DataType_U16 = 7,
77
+ # XLA_FFI_DataType_U32 = 8,
78
+ # XLA_FFI_DataType_U64 = 9,
79
+ # XLA_FFI_DataType_F16 = 10,
80
+ # XLA_FFI_DataType_F32 = 11,
81
+ # XLA_FFI_DataType_F64 = 12,
82
+ # XLA_FFI_DataType_BF16 = 16,
83
+ # XLA_FFI_DataType_C64 = 15,
84
+ # XLA_FFI_DataType_C128 = 18,
85
+ # XLA_FFI_DataType_TOKEN = 17,
86
+ # XLA_FFI_DataType_F8E5M2 = 19,
87
+ # XLA_FFI_DataType_F8E3M4 = 29,
88
+ # XLA_FFI_DataType_F8E4M3 = 28,
89
+ # XLA_FFI_DataType_F8E4M3FN = 20,
90
+ # XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
91
+ # XLA_FFI_DataType_F8E5M2FNUZ = 24,
92
+ # XLA_FFI_DataType_F8E4M3FNUZ = 25,
93
+ # XLA_FFI_DataType_F4E2M1FN = 32,
94
+ # XLA_FFI_DataType_F8E8M0FNU = 33,
95
+ # } XLA_FFI_DataType;
96
+ class XLA_FFI_DataType(enum.IntEnum):
97
+ INVALID = 0
98
+ PRED = 1
99
+ S8 = 2
100
+ S16 = 3
101
+ S32 = 4
102
+ S64 = 5
103
+ U8 = 6
104
+ U16 = 7
105
+ U32 = 8
106
+ U64 = 9
107
+ F16 = 10
108
+ F32 = 11
109
+ F64 = 12
110
+ BF16 = 16
111
+ C64 = 15
112
+ C128 = 18
113
+ TOKEN = 17
114
+ F8E5M2 = 19
115
+ F8E3M4 = 29
116
+ F8E4M3 = 28
117
+ F8E4M3FN = 20
118
+ F8E4M3B11FNUZ = 23
119
+ F8E5M2FNUZ = 24
120
+ F8E4M3FNUZ = 25
121
+ F4E2M1FN = 32
122
+ F8E8M0FNU = 33
123
+
124
+
125
+ # struct XLA_FFI_Buffer {
126
+ # size_t struct_size;
127
+ # XLA_FFI_Extension_Base* extension_start;
128
+ #
129
+ # XLA_FFI_DataType dtype;
130
+ # void* data;
131
+ # int64_t rank;
132
+ # int64_t* dims; // length == rank
133
+ # };
134
+ class XLA_FFI_Buffer(ctypes.Structure):
135
+ _fields_ = (
136
+ ("struct_size", ctypes.c_size_t),
137
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
138
+ ("dtype", ctypes.c_int), # XLA_FFI_DataType
139
+ ("data", ctypes.c_void_p),
140
+ ("rank", ctypes.c_int64),
141
+ ("dims", ctypes.POINTER(ctypes.c_int64)),
142
+ )
143
+
144
+
145
+ # typedef enum {
146
+ # XLA_FFI_ArgType_BUFFER = 1,
147
+ # } XLA_FFI_ArgType;
148
+ class XLA_FFI_ArgType(enum.IntEnum):
149
+ BUFFER = 1
150
+
151
+
152
+ # typedef enum {
153
+ # XLA_FFI_RetType_BUFFER = 1,
154
+ # } XLA_FFI_RetType;
155
+ class XLA_FFI_RetType(enum.IntEnum):
156
+ BUFFER = 1
157
+
158
+
159
+ # struct XLA_FFI_Args {
160
+ # size_t struct_size;
161
+ # XLA_FFI_Extension_Base* extension_start;
162
+ # int64_t size;
163
+ # XLA_FFI_ArgType* types; // length == size
164
+ # void** args; // length == size
165
+ # };
166
+ class XLA_FFI_Args(ctypes.Structure):
167
+ _fields_ = (
168
+ ("struct_size", ctypes.c_size_t),
169
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
170
+ ("size", ctypes.c_int64),
171
+ ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_ArgType*
172
+ ("args", ctypes.POINTER(ctypes.c_void_p)),
173
+ )
174
+
175
+
176
+ # struct XLA_FFI_Rets {
177
+ # size_t struct_size;
178
+ # XLA_FFI_Extension_Base* extension_start;
179
+ # int64_t size;
180
+ # XLA_FFI_RetType* types; // length == size
181
+ # void** rets; // length == size
182
+ # };
183
+ class XLA_FFI_Rets(ctypes.Structure):
184
+ _fields_ = (
185
+ ("struct_size", ctypes.c_size_t),
186
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
187
+ ("size", ctypes.c_int64),
188
+ ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_RetType*
189
+ ("rets", ctypes.POINTER(ctypes.c_void_p)),
190
+ )
191
+
192
+
193
+ # typedef struct XLA_FFI_ByteSpan {
194
+ # const char* ptr;
195
+ # size_t len;
196
+ # } XLA_FFI_ByteSpan;
197
+ class XLA_FFI_ByteSpan(ctypes.Structure):
198
+ _fields_ = (
199
+ ("ptr", ctypes.POINTER(ctypes.c_char)),
200
+ ("len", ctypes.c_size_t),
201
+ )
202
+
203
+
204
+ # typedef struct XLA_FFI_Scalar {
205
+ # XLA_FFI_DataType dtype;
206
+ # void* value;
207
+ # } XLA_FFI_Scalar;
208
+ class XLA_FFI_Scalar(ctypes.Structure):
209
+ _fields_ = (
210
+ ("dtype", ctypes.c_int),
211
+ ("value", ctypes.c_void_p),
212
+ )
213
+
214
+
215
+ # typedef struct XLA_FFI_Array {
216
+ # XLA_FFI_DataType dtype;
217
+ # size_t size;
218
+ # void* data;
219
+ # } XLA_FFI_Array;
220
+ class XLA_FFI_Array(ctypes.Structure):
221
+ _fields_ = (
222
+ ("dtype", ctypes.c_int),
223
+ ("size", ctypes.c_size_t),
224
+ ("data", ctypes.c_void_p),
225
+ )
226
+
227
+
228
+ # typedef enum {
229
+ # XLA_FFI_AttrType_ARRAY = 1,
230
+ # XLA_FFI_AttrType_DICTIONARY = 2,
231
+ # XLA_FFI_AttrType_SCALAR = 3,
232
+ # XLA_FFI_AttrType_STRING = 4,
233
+ # } XLA_FFI_AttrType;
234
+ class XLA_FFI_AttrType(enum.IntEnum):
235
+ ARRAY = 1
236
+ DICTIONARY = 2
237
+ SCALAR = 3
238
+ STRING = 4
239
+
240
+
241
+ # struct XLA_FFI_Attrs {
242
+ # size_t struct_size;
243
+ # XLA_FFI_Extension_Base* extension_start;
244
+ # int64_t size;
245
+ # XLA_FFI_AttrType* types; // length == size
246
+ # XLA_FFI_ByteSpan** names; // length == size
247
+ # void** attrs; // length == size
248
+ # };
249
+ class XLA_FFI_Attrs(ctypes.Structure):
250
+ _fields_ = (
251
+ ("struct_size", ctypes.c_size_t),
252
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
253
+ ("size", ctypes.c_int64),
254
+ ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_AttrType*
255
+ ("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
256
+ ("attrs", ctypes.POINTER(ctypes.c_void_p)),
257
+ )
258
+
259
+
260
+ # struct XLA_FFI_Api_Version {
261
+ # size_t struct_size;
262
+ # XLA_FFI_Extension_Base* extension_start;
263
+ # int major_version; // out
264
+ # int minor_version; // out
265
+ # };
266
+ class XLA_FFI_Api_Version(ctypes.Structure):
267
+ _fields_ = (
268
+ ("struct_size", ctypes.c_size_t),
269
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
270
+ ("major_version", ctypes.c_int),
271
+ ("minor_version", ctypes.c_int),
272
+ )
273
+
274
+
275
+ # enum XLA_FFI_Handler_TraitsBits {
276
+ # // Calls to FFI handler are safe to trace into the command buffer. It means
277
+ # // that calls to FFI handler always launch exactly the same device operations
278
+ # // (can depend on attribute values) that can be captured and then replayed.
279
+ # XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
280
+ # };
281
+ class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
282
+ COMMAND_BUFFER_COMPATIBLE = 1 << 0
283
+
284
+
285
+ # struct XLA_FFI_Metadata {
286
+ # size_t struct_size;
287
+ # XLA_FFI_Api_Version api_version;
288
+ # XLA_FFI_Handler_Traits traits;
289
+ # };
290
+ class XLA_FFI_Metadata(ctypes.Structure):
291
+ _fields_ = (
292
+ ("struct_size", ctypes.c_size_t),
293
+ ("api_version", XLA_FFI_Api_Version), # XLA_FFI_Api_Version
294
+ ("traits", ctypes.c_uint32), # XLA_FFI_Handler_Traits
295
+ )
296
+
297
+
298
+ # struct XLA_FFI_Metadata_Extension {
299
+ # XLA_FFI_Extension_Base extension_base;
300
+ # XLA_FFI_Metadata* metadata;
301
+ # };
302
+ class XLA_FFI_Metadata_Extension(ctypes.Structure):
303
+ _fields_ = (
304
+ ("extension_base", XLA_FFI_Extension_Base),
305
+ ("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
306
+ )
307
+
308
+
309
+ # typedef enum {
310
+ # XLA_FFI_Error_Code_OK = 0,
311
+ # XLA_FFI_Error_Code_CANCELLED = 1,
312
+ # XLA_FFI_Error_Code_UNKNOWN = 2,
313
+ # XLA_FFI_Error_Code_INVALID_ARGUMENT = 3,
314
+ # XLA_FFI_Error_Code_DEADLINE_EXCEEDED = 4,
315
+ # XLA_FFI_Error_Code_NOT_FOUND = 5,
316
+ # XLA_FFI_Error_Code_ALREADY_EXISTS = 6,
317
+ # XLA_FFI_Error_Code_PERMISSION_DENIED = 7,
318
+ # XLA_FFI_Error_Code_RESOURCE_EXHAUSTED = 8,
319
+ # XLA_FFI_Error_Code_FAILED_PRECONDITION = 9,
320
+ # XLA_FFI_Error_Code_ABORTED = 10,
321
+ # XLA_FFI_Error_Code_OUT_OF_RANGE = 11,
322
+ # XLA_FFI_Error_Code_UNIMPLEMENTED = 12,
323
+ # XLA_FFI_Error_Code_INTERNAL = 13,
324
+ # XLA_FFI_Error_Code_UNAVAILABLE = 14,
325
+ # XLA_FFI_Error_Code_DATA_LOSS = 15,
326
+ # XLA_FFI_Error_Code_UNAUTHENTICATED = 16
327
+ # } XLA_FFI_Error_Code;
328
+ class XLA_FFI_Error_Code(enum.IntEnum):
329
+ OK = 0
330
+ CANCELLED = 1
331
+ UNKNOWN = 2
332
+ INVALID_ARGUMENT = 3
333
+ DEADLINE_EXCEEDED = 4
334
+ NOT_FOUND = 5
335
+ ALREADY_EXISTS = 6
336
+ PERMISSION_DENIED = 7
337
+ RESOURCE_EXHAUSTED = 8
338
+ FAILED_PRECONDITION = 9
339
+ ABORTED = 10
340
+ OUT_OF_RANGE = 11
341
+ UNIMPLEMENTED = 12
342
+ INTERNAL = 13
343
+ UNAVAILABLE = 14
344
+ DATA_LOSS = 15
345
+ UNAUTHENTICATED = 16
346
+
347
+
348
+ # struct XLA_FFI_Error_Create_Args {
349
+ # size_t struct_size;
350
+ # XLA_FFI_Extension_Base* extension_start;
351
+ # const char* message;
352
+ # XLA_FFI_Error_Code errc;
353
+ # };
354
+ class XLA_FFI_Error_Create_Args(ctypes.Structure):
355
+ _fields_ = (
356
+ ("struct_size", ctypes.c_size_t),
357
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
358
+ ("message", ctypes.c_char_p),
359
+ ("errc", ctypes.c_int), # XLA_FFI_Error_Code
360
+ )
361
+
362
+
363
+ XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
364
+
365
+
366
+ # struct XLA_FFI_Stream_Get_Args {
367
+ # size_t struct_size;
368
+ # XLA_FFI_Extension_Base* extension_start;
369
+ # XLA_FFI_ExecutionContext* ctx;
370
+ # void* stream; // out
371
+ # };
372
+ class XLA_FFI_Stream_Get_Args(ctypes.Structure):
373
+ _fields_ = (
374
+ ("struct_size", ctypes.c_size_t),
375
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
376
+ ("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
377
+ ("stream", ctypes.c_void_p), # out
378
+ )
379
+
380
+
381
+ XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
382
+
383
+
384
+ # struct XLA_FFI_DeviceOrdinal_Get_Args {
385
+ # size_t struct_size;
386
+ # XLA_FFI_Extension_Base* extension_start;
387
+ # XLA_FFI_ExecutionContext* ctx;
388
+ # int32_t device_ordinal; // out
389
+ # };
390
+ class XLA_FFI_DeviceOrdinal_Get_Args(ctypes.Structure):
391
+ _fields_ = (
392
+ ("struct_size", ctypes.c_size_t),
393
+ ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
394
+ ("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
395
+ ("device_ordinal", ctypes.c_int32), # out
396
+ )
397
+
398
+
399
+ XLA_FFI_DeviceOrdinal_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_DeviceOrdinal_Get_Args))
400
+
401
+
402
+ # struct XLA_FFI_Api {
403
+ # size_t struct_size;
404
+ # XLA_FFI_Extension_Base* extension_start;
405
+ #
406
+ # XLA_FFI_Api_Version api_version;
407
+ # XLA_FFI_InternalApi* internal_api;
408
+ #
409
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create);
410
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_GetMessage);
411
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy);
412
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register);
413
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get);
414
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_TypeId_Register);
415
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ExecutionContext_Get);
416
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Set);
417
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Get);
418
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Allocate);
419
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Free);
420
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_Schedule);
421
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_NumThreads);
422
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_Create);
423
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetAvailable);
424
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
425
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_RunId_Get);
426
+ # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceOrdinal_Get);
427
+ # };
428
class XLA_FFI_Api(ctypes.Structure):
    """ctypes mirror of the C ``XLA_FFI_Api`` struct.

    ctypes derives the memory layout from the order of ``_fields_``, so the
    entries below must stay in the same order as the C declaration.  Only the
    entry points this module has a prototype for (``XLA_FFI_Error_Create``,
    ``XLA_FFI_Stream_Get``, ``XLA_FFI_DeviceOrdinal_Get``) are given a typed
    slot; every other function pointer is an opaque ``c_void_p`` placeholder.
    """

    _fields_ = (
        # --- struct header ---
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("api_version", XLA_FFI_Api_Version),
        ("internal_api", ctypes.c_void_p),  # XLA_FFI_InternalApi*
        # --- error handling ---
        ("XLA_FFI_Error_Create", XLA_FFI_Error_Create),
        ("XLA_FFI_Error_GetMessage", ctypes.c_void_p),
        ("XLA_FFI_Error_Destroy", ctypes.c_void_p),
        # --- handler registration / stream access ---
        ("XLA_FFI_Handler_Register", ctypes.c_void_p),
        ("XLA_FFI_Stream_Get", XLA_FFI_Stream_Get),
        ("XLA_FFI_TypeId_Register", ctypes.c_void_p),
        # --- execution context and state ---
        ("XLA_FFI_ExecutionContext_Get", ctypes.c_void_p),
        ("XLA_FFI_State_Set", ctypes.c_void_p),
        ("XLA_FFI_State_Get", ctypes.c_void_p),
        # --- device memory ---
        ("XLA_FFI_DeviceMemory_Allocate", ctypes.c_void_p),
        ("XLA_FFI_DeviceMemory_Free", ctypes.c_void_p),
        # --- thread pool ---
        ("XLA_FFI_ThreadPool_Schedule", ctypes.c_void_p),
        ("XLA_FFI_ThreadPool_NumThreads", ctypes.c_void_p),
        # --- futures ---
        ("XLA_FFI_Future_Create", ctypes.c_void_p),
        ("XLA_FFI_Future_SetAvailable", ctypes.c_void_p),
        ("XLA_FFI_Future_SetError", ctypes.c_void_p),
        # --- run / device identification ---
        # TODO(chaserileyroberts): Make this return the correct value and not a c_void_p.
        ("XLA_FFI_RunId_Get", ctypes.c_void_p),
        ("XLA_FFI_DeviceOrdinal_Get", XLA_FFI_DeviceOrdinal_Get),
    )
454
+
455
+
456
+ # struct XLA_FFI_CallFrame {
457
+ # size_t struct_size;
458
+ # XLA_FFI_Extension_Base* extension_start;
459
+ # const XLA_FFI_Api* api;
460
+ # XLA_FFI_ExecutionContext* ctx;
461
+ # XLA_FFI_ExecutionStage stage;
462
+ # XLA_FFI_Args args;
463
+ # XLA_FFI_Rets rets;
464
+ # XLA_FFI_Attrs attrs;
465
+ #
466
+ # // XLA FFI handler implementation can use `future` to signal a result of
467
+ # // asynchronous computation to the XLA runtime. XLA runtime will keep all
468
+ # // arguments, results and attributes alive until `future` is completed.
469
+ # XLA_FFI_Future* future; // out
470
+ # };
471
class XLA_FFI_CallFrame(ctypes.Structure):
    """ctypes mirror of the C ``XLA_FFI_CallFrame`` struct.

    The entries of ``_fields_`` follow the order of the C declaration, which
    ctypes uses to compute the memory layout.  Pointers to types this module
    does not model (the execution context and the future) are kept as opaque
    ``c_void_p`` values.
    """

    _fields_ = (
        # Header shared by the XLA FFI structs.
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        # API table and opaque execution context for this call.
        ("api", ctypes.POINTER(XLA_FFI_Api)),
        ("ctx", ctypes.c_void_p),  # XLA_FFI_ExecutionContext*
        ("stage", ctypes.c_int),  # XLA_FFI_ExecutionStage
        # Call payload: arguments, results and attributes.
        ("args", XLA_FFI_Args),
        ("rets", XLA_FFI_Rets),
        ("attrs", XLA_FFI_Attrs),
        # Output slot for asynchronous completion.
        ("future", ctypes.c_void_p),  # XLA_FFI_Future* // out
    )
483
+
484
+
485
# Maps XLA FFI element types to the corresponding jax.numpy scalar type.
# Not mapped: XLA_FFI_DataType.INVALID, XLA_FFI_DataType.TOKEN,
# XLA_FFI_DataType.F4E2M1FN (jnp.float4_e2m1fn.dtype) and
# XLA_FFI_DataType.F8E8M0FNU (jnp.float8_e8m0fnu.dtype).
_xla_data_type_to_constructor = {
    # boolean
    XLA_FFI_DataType.PRED: jnp.bool,
    # signed integers
    XLA_FFI_DataType.S8: jnp.int8,
    XLA_FFI_DataType.S16: jnp.int16,
    XLA_FFI_DataType.S32: jnp.int32,
    XLA_FFI_DataType.S64: jnp.int64,
    # unsigned integers
    XLA_FFI_DataType.U8: jnp.uint8,
    XLA_FFI_DataType.U16: jnp.uint16,
    XLA_FFI_DataType.U32: jnp.uint32,
    XLA_FFI_DataType.U64: jnp.uint64,
    # floating point
    XLA_FFI_DataType.F16: jnp.float16,
    XLA_FFI_DataType.F32: jnp.float32,
    XLA_FFI_DataType.F64: jnp.float64,
    XLA_FFI_DataType.BF16: jnp.bfloat16,
    # complex
    XLA_FFI_DataType.C64: jnp.complex64,
    XLA_FFI_DataType.C128: jnp.complex128,
}

# newer types not supported by older versions: register each float8 variant
# only when the installed jax.numpy exposes it.  Both names are looked up
# lazily, so nothing is touched for a variant that jax.numpy lacks.
_xla_data_type_to_constructor.update(
    {
        getattr(XLA_FFI_DataType, ffi_name): getattr(jnp, jnp_name)
        for ffi_name, jnp_name in (
            ("F8E5M2", "float8_e5m2"),
            ("F8E3M4", "float8_e3m4"),
            ("F8E4M3", "float8_e4m3"),
            ("F8E4M3FN", "float8_e4m3fn"),
            ("F8E4M3B11FNUZ", "float8_e4m3b11fnuz"),
            ("F8E5M2FNUZ", "float8_e5m2fnuz"),
            ("F8E4M3FNUZ", "float8_e4m3fnuz"),
        )
        if hasattr(jnp, jnp_name)
    }
)
522
+
523
+
524
+ ########################################################################
525
+ # Helpers for translating between ctypes and python types
526
+ #######################################################################
527
+
528
+
529
def decode_bytespan(span: XLA_FFI_ByteSpan):
    """Decode an ``XLA_FFI_ByteSpan`` into a Python string.

    Args:
        span: Object exposing ``ptr`` (address of the first byte) and ``len``
            (number of bytes in the span).

    Returns:
        The ``len`` bytes starting at ``ptr``, decoded as UTF-8.

    Raises:
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    size = span.len
    if not size:
        # Nothing to read; avoid dereferencing a possibly-NULL pointer.
        return ""
    # The span is length-delimited, so read exactly ``size`` bytes instead of
    # stopping at the first NUL byte (which ``c_char`` array ``.value`` does).
    raw = ctypes.string_at(span.ptr, size)
    return raw.decode("utf-8")
533
+
534
+
535
def decode_scalar(scalar: XLA_FFI_Scalar):
    """Decode an ``XLA_FFI_Scalar`` into a zero-dimensional NumPy array.

    Args:
        scalar: Object exposing ``dtype`` (an XLA FFI data type) and ``value``
            (address of the scalar's bytes).

    Returns:
        A 0-d array built from one element's worth of bytes copied from
        ``scalar.value``.

    Raises:
        KeyError: If ``scalar.dtype`` has no entry in the dtype table.
    """
    # TODO validate if dtype supported
    element_type = jnp.dtype(_xla_data_type_to_constructor[scalar.dtype])
    # Copy exactly one element out of native memory.
    payload = ctypes.string_at(scalar.value, element_type.itemsize)
    one_element = np.frombuffer(payload, dtype=element_type)
    return one_element.reshape(())
540
+
541
+
542
def decode_array(array: XLA_FFI_Array):
    """Decode an ``XLA_FFI_Array`` into a one-dimensional NumPy array.

    Args:
        array: Object exposing ``dtype`` (an XLA FFI data type), ``size``
            (element count) and ``data`` (address of the first element).

    Returns:
        A 1-d array over a copy of the ``size`` elements found at ``data``.

    Raises:
        KeyError: If ``array.dtype`` has no entry in the dtype table.
    """
    # TODO validate if dtype supported
    element_type = jnp.dtype(_xla_data_type_to_constructor[array.dtype])
    byte_count = element_type.itemsize * array.size
    payload = ctypes.string_at(array.data, byte_count)
    return np.frombuffer(payload, dtype=element_type)
547
+
548
+
549
def decode_attrs(attrs: XLA_FFI_Attrs):
    """Decode an ``XLA_FFI_Attrs`` structure into a Python dictionary.

    Each attribute name is decoded with ``decode_bytespan``; the value is
    decoded according to its attribute type (string, scalar, array, or a
    nested dictionary, which is decoded recursively).

    Args:
        attrs: Object exposing ``size`` plus the parallel ``names``, ``types``
            and ``attrs`` sequences.

    Returns:
        A dict mapping each attribute name to its decoded value.

    Raises:
        Exception: If an attribute has a type other than the four handled
            here.
    """
    # (attribute type tag, ctypes struct behind the pointer, decoder), checked in order.
    decoders = (
        (XLA_FFI_AttrType.STRING, XLA_FFI_ByteSpan, decode_bytespan),
        (XLA_FFI_AttrType.SCALAR, XLA_FFI_Scalar, decode_scalar),
        (XLA_FFI_AttrType.ARRAY, XLA_FFI_Array, decode_array),
        (XLA_FFI_AttrType.DICTIONARY, XLA_FFI_Attrs, decode_attrs),
    )

    decoded = {}
    for index in range(attrs.size):
        key = decode_bytespan(attrs.names[index].contents)
        kind = attrs.types[index]
        for tag, struct_type, decoder in decoders:
            if kind == tag:
                typed_ptr = ctypes.cast(attrs.attrs[index], ctypes.POINTER(struct_type))
                decoded[key] = decoder(typed_ptr.contents)
                break
        else:
            raise Exception("Unexpected attr type")
    return decoded
570
+
571
+
572
+ # error-string to XLA_FFI_Error
573
def create_ffi_error(api, errc, message):
    """Create an XLA FFI error through the API's ``XLA_FFI_Error_Create`` entry.

    Args:
        api: Pointer to the ``XLA_FFI_Api`` table (dereferenced via ``.contents``).
        errc: Error code stored in the create-args struct.
        message: Error text; encoded as UTF-8 before being handed over.

    Returns:
        Whatever ``XLA_FFI_Error_Create`` returns for the populated arguments.
    """
    no_extension = ctypes.POINTER(XLA_FFI_Extension_Base)()  # NULL extension pointer
    encoded_message = ctypes.c_char_p(message.encode("utf-8"))
    create_args = XLA_FFI_Error_Create_Args(
        ctypes.sizeof(XLA_FFI_Error_Create_Args),
        no_extension,
        encoded_message,
        errc,
    )
    error_create = api.contents.XLA_FFI_Error_Create
    return error_create(create_args)
581
+
582
+
583
def create_invalid_argument_ffi_error(api, message):
    """Create an XLA FFI error carrying the ``INVALID_ARGUMENT`` error code.

    Args:
        api: Pointer to the ``XLA_FFI_Api`` table, forwarded to ``create_ffi_error``.
        message: Error text.

    Returns:
        The result of ``create_ffi_error`` for this code and message.
    """
    errc = XLA_FFI_Error_Code.INVALID_ARGUMENT
    return create_ffi_error(api, errc, message)
585
+
586
+
587
+ # Extract CUDA stream from XLA_FFI_CallFrame.
588
def get_stream_from_callframe(call_frame):
    """Query the stream associated with a call frame's execution context.

    Args:
        call_frame: Object exposing ``api`` (pointer to the ``XLA_FFI_Api``
            table) and ``ctx`` (the execution context handle).

    Returns:
        The ``stream`` field of the args struct after ``XLA_FFI_Stream_Get``
        has been invoked on it.
    """
    no_extension = ctypes.POINTER(XLA_FFI_Extension_Base)()  # NULL extension pointer
    # The last positional field is initialised to None before the call; the
    # value returned below is read from ``stream`` afterwards.
    stream_args = XLA_FFI_Stream_Get_Args(
        ctypes.sizeof(XLA_FFI_Stream_Get_Args),
        no_extension,
        call_frame.ctx,
        None,
    )
    stream_get = call_frame.api.contents.XLA_FFI_Stream_Get
    stream_get(stream_args)
    # TODO check result
    return stream_args.stream
596
+
597
+
598
def get_device_ordinal_from_callframe(call_frame):
    """Query the device ordinal associated with a call frame's execution context.

    Args:
        call_frame: Object exposing ``api`` (pointer to the ``XLA_FFI_Api``
            table) and ``ctx`` (the execution context handle).

    Returns:
        The ``device_ordinal`` field of the args struct after
        ``XLA_FFI_DeviceOrdinal_Get`` has been invoked on it.
    """
    no_extension = ctypes.POINTER(XLA_FFI_Extension_Base)()  # NULL extension pointer
    # The last positional field starts at 0; the value returned below is read
    # from ``device_ordinal`` after the call.
    ordinal_args = XLA_FFI_DeviceOrdinal_Get_Args(
        ctypes.sizeof(XLA_FFI_DeviceOrdinal_Get_Args),
        no_extension,
        call_frame.ctx,
        0,
    )
    device_ordinal_get = call_frame.api.contents.XLA_FFI_DeviceOrdinal_Get
    device_ordinal_get(ordinal_args)
    return ordinal_args.device_ordinal
605
+
606
+
607
# Maps XLA FFI element types to Warp scalar types.  Only the fixed-width
# integer and floating-point types listed here have an entry.
_dtype_from_ffi = {
    getattr(XLA_FFI_DataType, ffi_name): getattr(wp, warp_name)
    for ffi_name, warp_name in (
        # signed integers
        ("S8", "int8"),
        ("S16", "int16"),
        ("S32", "int32"),
        ("S64", "int64"),
        # unsigned integers
        ("U8", "uint8"),
        ("U16", "uint16"),
        ("U32", "uint32"),
        ("U64", "uint64"),
        # floating point
        ("F16", "float16"),
        ("F32", "float32"),
        ("F64", "float64"),
    )
}
620
+
621
+
622
def dtype_from_ffi(ffi_dtype):
    """Return the Warp dtype mapped to ``ffi_dtype``, or ``None`` if there is no entry."""
    warp_dtype = _dtype_from_ffi.get(ffi_dtype, None)
    return warp_dtype
624
+
625
+
626
def jax_dtype_from_ffi(ffi_dtype):
    """Return the jax.numpy type mapped to ``ffi_dtype``, or ``None`` if there is no entry."""
    jax_dtype = _xla_data_type_to_constructor.get(ffi_dtype, None)
    return jax_dtype
628
+
629
+
630
+ # Execution context (stream, stage)
631
class ExecutionContext:
    """Execution stage and stream captured from an XLA FFI call frame."""

    stage: XLA_FFI_ExecutionStage
    # Value returned by get_stream_from_callframe().
    stream: int

    def __init__(self, callframe: XLA_FFI_CallFrame):
        """Read the stage from ``callframe`` and query its stream.

        Args:
            callframe: The call frame to inspect.
        """
        # Convert the raw stage value first, then make the stream query.
        raw_stage = callframe.stage
        self.stage = XLA_FFI_ExecutionStage(raw_stage)
        self.stream = get_stream_from_callframe(callframe)
638
+
639
+
640
class FfiBuffer:
    """Lightweight view of an XLA FFI buffer exposing ``__cuda_array_interface__``.

    Stores the buffer's dtype, shape and data address so that consumers of
    the CUDA array interface can wrap the memory.
    """

    # NOTE(review): __init__ assigns the result of jnp.dtype(...) to ``dtype``
    # and a tuple with ``rank`` entries to ``shape``; the annotations below are
    # kept as declared.
    dtype: str
    data: int
    shape: tuple[int]

    def __init__(self, xla_buffer):
        """Capture dtype, shape and data address from ``xla_buffer``.

        Args:
            xla_buffer: Object exposing ``dtype``, ``rank``, ``dims`` and ``data``.

        Raises:
            KeyError: If ``xla_buffer.dtype`` has no entry in the dtype table.
        """
        # TODO check if valid
        constructor = _xla_data_type_to_constructor[xla_buffer.dtype]
        self.dtype = jnp.dtype(constructor)
        rank = xla_buffer.rank
        self.shape = tuple(xla_buffer.dims[axis] for axis in range(rank))
        self.data = xla_buffer.data

    @property
    def __cuda_array_interface__(self):
        """Dictionary describing this buffer (interface version 2).

        NOTE(review): ``typestr`` is the dtype's single-character code
        (``dtype.char``) — confirm that consumers accept this form.
        """
        return dict(
            shape=self.shape,
            typestr=self.dtype.char,
            data=(self.data, False),
            version=2,
        )
658
+ }