warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -13,612 +13,12 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import ctypes
17
- import enum
16
+ # TODO: Remove after cleaning up the public API.
18
17
 
19
- import jax.numpy as jnp
20
- import numpy as np
18
+ from warp._src.jax_experimental import xla_ffi as _xla_ffi
21
19
 
22
- import warp as wp
23
20
 
24
- #######################################################################
25
- # ctypes structures and enums for XLA's FFI API:
26
- # https://github.com/openxla/xla/blob/a1a5e62fbffa3a3b6c409d72607456cf5b353a22/xla/ffi/api/c_api.h
27
- #######################################################################
21
+ def __getattr__(name):
22
+ from warp._src.utils import get_deprecated_api
28
23
 
29
-
30
- # typedef enum {
31
- # XLA_FFI_Extension_Metadata = 1,
32
- # } XLA_FFI_Extension_Type;
33
- class XLA_FFI_Extension_Type(enum.IntEnum):
34
- Metadata = 1
35
-
36
-
37
- # typedef struct XLA_FFI_Extension_Base {
38
- # size_t struct_size;
39
- # XLA_FFI_Extension_Type type;
40
- # struct XLA_FFI_Extension_Base* next;
41
- # } XLA_FFI_Extension_Base;
42
- class XLA_FFI_Extension_Base(ctypes.Structure):
43
- pass
44
-
45
-
46
- XLA_FFI_Extension_Base._fields_ = [
47
- ("struct_size", ctypes.c_size_t),
48
- ("type", ctypes.c_int), # XLA_FFI_Extension_Type
49
- ("next", ctypes.POINTER(XLA_FFI_Extension_Base)),
50
- ]
51
-
52
-
53
- # typedef enum {
54
- # XLA_FFI_ExecutionStage_INSTANTIATE = 0,
55
- # XLA_FFI_ExecutionStage_PREPARE = 1,
56
- # XLA_FFI_ExecutionStage_INITIALIZE = 2,
57
- # XLA_FFI_ExecutionStage_EXECUTE = 3,
58
- # } XLA_FFI_ExecutionStage;
59
- class XLA_FFI_ExecutionStage(enum.IntEnum):
60
- INSTANTIATE = 0
61
- PREPARE = 1
62
- INITIALIZE = 2
63
- EXECUTE = 3
64
-
65
-
66
- # typedef enum {
67
- # XLA_FFI_DataType_INVALID = 0,
68
- # XLA_FFI_DataType_PRED = 1,
69
- # XLA_FFI_DataType_S8 = 2,
70
- # XLA_FFI_DataType_S16 = 3,
71
- # XLA_FFI_DataType_S32 = 4,
72
- # XLA_FFI_DataType_S64 = 5,
73
- # XLA_FFI_DataType_U8 = 6,
74
- # XLA_FFI_DataType_U16 = 7,
75
- # XLA_FFI_DataType_U32 = 8,
76
- # XLA_FFI_DataType_U64 = 9,
77
- # XLA_FFI_DataType_F16 = 10,
78
- # XLA_FFI_DataType_F32 = 11,
79
- # XLA_FFI_DataType_F64 = 12,
80
- # XLA_FFI_DataType_BF16 = 16,
81
- # XLA_FFI_DataType_C64 = 15,
82
- # XLA_FFI_DataType_C128 = 18,
83
- # XLA_FFI_DataType_TOKEN = 17,
84
- # XLA_FFI_DataType_F8E5M2 = 19,
85
- # XLA_FFI_DataType_F8E3M4 = 29,
86
- # XLA_FFI_DataType_F8E4M3 = 28,
87
- # XLA_FFI_DataType_F8E4M3FN = 20,
88
- # XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
89
- # XLA_FFI_DataType_F8E5M2FNUZ = 24,
90
- # XLA_FFI_DataType_F8E4M3FNUZ = 25,
91
- # XLA_FFI_DataType_F4E2M1FN = 32,
92
- # XLA_FFI_DataType_F8E8M0FNU = 33,
93
- # } XLA_FFI_DataType;
94
- class XLA_FFI_DataType(enum.IntEnum):
95
- INVALID = 0
96
- PRED = 1
97
- S8 = 2
98
- S16 = 3
99
- S32 = 4
100
- S64 = 5
101
- U8 = 6
102
- U16 = 7
103
- U32 = 8
104
- U64 = 9
105
- F16 = 10
106
- F32 = 11
107
- F64 = 12
108
- BF16 = 16
109
- C64 = 15
110
- C128 = 18
111
- TOKEN = 17
112
- F8E5M2 = 19
113
- F8E3M4 = 29
114
- F8E4M3 = 28
115
- F8E4M3FN = 20
116
- F8E4M3B11FNUZ = 23
117
- F8E5M2FNUZ = 24
118
- F8E4M3FNUZ = 25
119
- F4E2M1FN = 32
120
- F8E8M0FNU = 33
121
-
122
-
123
- # struct XLA_FFI_Buffer {
124
- # size_t struct_size;
125
- # XLA_FFI_Extension_Base* extension_start;
126
- #
127
- # XLA_FFI_DataType dtype;
128
- # void* data;
129
- # int64_t rank;
130
- # int64_t* dims; // length == rank
131
- # };
132
- class XLA_FFI_Buffer(ctypes.Structure):
133
- _fields_ = (
134
- ("struct_size", ctypes.c_size_t),
135
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
136
- ("dtype", ctypes.c_int), # XLA_FFI_DataType
137
- ("data", ctypes.c_void_p),
138
- ("rank", ctypes.c_int64),
139
- ("dims", ctypes.POINTER(ctypes.c_int64)),
140
- )
141
-
142
-
143
- # typedef enum {
144
- # XLA_FFI_ArgType_BUFFER = 1,
145
- # } XLA_FFI_ArgType;
146
- class XLA_FFI_ArgType(enum.IntEnum):
147
- BUFFER = 1
148
-
149
-
150
- # typedef enum {
151
- # XLA_FFI_RetType_BUFFER = 1,
152
- # } XLA_FFI_RetType;
153
- class XLA_FFI_RetType(enum.IntEnum):
154
- BUFFER = 1
155
-
156
-
157
- # struct XLA_FFI_Args {
158
- # size_t struct_size;
159
- # XLA_FFI_Extension_Base* extension_start;
160
- # int64_t size;
161
- # XLA_FFI_ArgType* types; // length == size
162
- # void** args; // length == size
163
- # };
164
- class XLA_FFI_Args(ctypes.Structure):
165
- _fields_ = (
166
- ("struct_size", ctypes.c_size_t),
167
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
168
- ("size", ctypes.c_int64),
169
- ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_ArgType*
170
- ("args", ctypes.POINTER(ctypes.c_void_p)),
171
- )
172
-
173
-
174
- # struct XLA_FFI_Rets {
175
- # size_t struct_size;
176
- # XLA_FFI_Extension_Base* extension_start;
177
- # int64_t size;
178
- # XLA_FFI_RetType* types; // length == size
179
- # void** rets; // length == size
180
- # };
181
- class XLA_FFI_Rets(ctypes.Structure):
182
- _fields_ = (
183
- ("struct_size", ctypes.c_size_t),
184
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
185
- ("size", ctypes.c_int64),
186
- ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_RetType*
187
- ("rets", ctypes.POINTER(ctypes.c_void_p)),
188
- )
189
-
190
-
191
- # typedef struct XLA_FFI_ByteSpan {
192
- # const char* ptr;
193
- # size_t len;
194
- # } XLA_FFI_ByteSpan;
195
- class XLA_FFI_ByteSpan(ctypes.Structure):
196
- _fields_ = (
197
- ("ptr", ctypes.POINTER(ctypes.c_char)),
198
- ("len", ctypes.c_size_t),
199
- )
200
-
201
-
202
- # typedef struct XLA_FFI_Scalar {
203
- # XLA_FFI_DataType dtype;
204
- # void* value;
205
- # } XLA_FFI_Scalar;
206
- class XLA_FFI_Scalar(ctypes.Structure):
207
- _fields_ = (
208
- ("dtype", ctypes.c_int),
209
- ("value", ctypes.c_void_p),
210
- )
211
-
212
-
213
- # typedef struct XLA_FFI_Array {
214
- # XLA_FFI_DataType dtype;
215
- # size_t size;
216
- # void* data;
217
- # } XLA_FFI_Array;
218
- class XLA_FFI_Array(ctypes.Structure):
219
- _fields_ = (
220
- ("dtype", ctypes.c_int),
221
- ("size", ctypes.c_size_t),
222
- ("data", ctypes.c_void_p),
223
- )
224
-
225
-
226
- # typedef enum {
227
- # XLA_FFI_AttrType_ARRAY = 1,
228
- # XLA_FFI_AttrType_DICTIONARY = 2,
229
- # XLA_FFI_AttrType_SCALAR = 3,
230
- # XLA_FFI_AttrType_STRING = 4,
231
- # } XLA_FFI_AttrType;
232
- class XLA_FFI_AttrType(enum.IntEnum):
233
- ARRAY = 1
234
- DICTIONARY = 2
235
- SCALAR = 3
236
- STRING = 4
237
-
238
-
239
- # struct XLA_FFI_Attrs {
240
- # size_t struct_size;
241
- # XLA_FFI_Extension_Base* extension_start;
242
- # int64_t size;
243
- # XLA_FFI_AttrType* types; // length == size
244
- # XLA_FFI_ByteSpan** names; // length == size
245
- # void** attrs; // length == size
246
- # };
247
- class XLA_FFI_Attrs(ctypes.Structure):
248
- _fields_ = (
249
- ("struct_size", ctypes.c_size_t),
250
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
251
- ("size", ctypes.c_int64),
252
- ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_AttrType*
253
- ("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
254
- ("attrs", ctypes.POINTER(ctypes.c_void_p)),
255
- )
256
-
257
-
258
- # struct XLA_FFI_Api_Version {
259
- # size_t struct_size;
260
- # XLA_FFI_Extension_Base* extension_start;
261
- # int major_version; // out
262
- # int minor_version; // out
263
- # };
264
- class XLA_FFI_Api_Version(ctypes.Structure):
265
- _fields_ = (
266
- ("struct_size", ctypes.c_size_t),
267
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
268
- ("major_version", ctypes.c_int),
269
- ("minor_version", ctypes.c_int),
270
- )
271
-
272
-
273
- # enum XLA_FFI_Handler_TraitsBits {
274
- # // Calls to FFI handler are safe to trace into the command buffer. It means
275
- # // that calls to FFI handler always launch exactly the same device operations
276
- # // (can depend on attribute values) that can be captured and then replayed.
277
- # XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
278
- # };
279
- class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
280
- COMMAND_BUFFER_COMPATIBLE = 1 << 0
281
-
282
-
283
- # struct XLA_FFI_Metadata {
284
- # size_t struct_size;
285
- # XLA_FFI_Api_Version api_version;
286
- # XLA_FFI_Handler_Traits traits;
287
- # };
288
- class XLA_FFI_Metadata(ctypes.Structure):
289
- _fields_ = (
290
- ("struct_size", ctypes.c_size_t),
291
- ("api_version", XLA_FFI_Api_Version), # XLA_FFI_Extension_Type
292
- ("traits", ctypes.c_uint32), # XLA_FFI_Handler_Traits
293
- )
294
-
295
-
296
- # struct XLA_FFI_Metadata_Extension {
297
- # XLA_FFI_Extension_Base extension_base;
298
- # XLA_FFI_Metadata* metadata;
299
- # };
300
- class XLA_FFI_Metadata_Extension(ctypes.Structure):
301
- _fields_ = (
302
- ("extension_base", XLA_FFI_Extension_Base),
303
- ("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
304
- )
305
-
306
-
307
- # typedef enum {
308
- # XLA_FFI_Error_Code_OK = 0,
309
- # XLA_FFI_Error_Code_CANCELLED = 1,
310
- # XLA_FFI_Error_Code_UNKNOWN = 2,
311
- # XLA_FFI_Error_Code_INVALID_ARGUMENT = 3,
312
- # XLA_FFI_Error_Code_DEADLINE_EXCEEDED = 4,
313
- # XLA_FFI_Error_Code_NOT_FOUND = 5,
314
- # XLA_FFI_Error_Code_ALREADY_EXISTS = 6,
315
- # XLA_FFI_Error_Code_PERMISSION_DENIED = 7,
316
- # XLA_FFI_Error_Code_RESOURCE_EXHAUSTED = 8,
317
- # XLA_FFI_Error_Code_FAILED_PRECONDITION = 9,
318
- # XLA_FFI_Error_Code_ABORTED = 10,
319
- # XLA_FFI_Error_Code_OUT_OF_RANGE = 11,
320
- # XLA_FFI_Error_Code_UNIMPLEMENTED = 12,
321
- # XLA_FFI_Error_Code_INTERNAL = 13,
322
- # XLA_FFI_Error_Code_UNAVAILABLE = 14,
323
- # XLA_FFI_Error_Code_DATA_LOSS = 15,
324
- # XLA_FFI_Error_Code_UNAUTHENTICATED = 16
325
- # } XLA_FFI_Error_Code;
326
- class XLA_FFI_Error_Code(enum.IntEnum):
327
- OK = 0
328
- CANCELLED = 1
329
- UNKNOWN = 2
330
- INVALID_ARGUMENT = 3
331
- DEADLINE_EXCEEDED = 4
332
- NOT_FOUND = 5
333
- ALREADY_EXISTS = 6
334
- PERMISSION_DENIED = 7
335
- RESOURCE_EXHAUSTED = 8
336
- FAILED_PRECONDITION = 9
337
- ABORTED = 10
338
- OUT_OF_RANGE = 11
339
- UNIMPLEMENTED = 12
340
- INTERNAL = 13
341
- UNAVAILABLE = 14
342
- DATA_LOSS = 15
343
- UNAUTHENTICATED = 16
344
-
345
-
346
- # struct XLA_FFI_Error_Create_Args {
347
- # size_t struct_size;
348
- # XLA_FFI_Extension_Base* extension_start;
349
- # const char* message;
350
- # XLA_FFI_Error_Code errc;
351
- # };
352
- class XLA_FFI_Error_Create_Args(ctypes.Structure):
353
- _fields_ = (
354
- ("struct_size", ctypes.c_size_t),
355
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
356
- ("message", ctypes.c_char_p),
357
- ("errc", ctypes.c_int),
358
- ) # XLA_FFI_Error_Code
359
-
360
-
361
- XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
362
-
363
-
364
- # struct XLA_FFI_Stream_Get_Args {
365
- # size_t struct_size;
366
- # XLA_FFI_Extension_Base* extension_start;
367
- # XLA_FFI_ExecutionContext* ctx;
368
- # void* stream; // out
369
- # };
370
- class XLA_FFI_Stream_Get_Args(ctypes.Structure):
371
- _fields_ = (
372
- ("struct_size", ctypes.c_size_t),
373
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
374
- ("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
375
- ("stream", ctypes.c_void_p),
376
- ) # // out
377
-
378
-
379
- XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
380
-
381
-
382
- # struct XLA_FFI_Api {
383
- # size_t struct_size;
384
- # XLA_FFI_Extension_Base* extension_start;
385
- #
386
- # XLA_FFI_Api_Version api_version;
387
- # XLA_FFI_InternalApi* internal_api;
388
- #
389
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create);
390
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_GetMessage);
391
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy);
392
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register);
393
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get);
394
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_TypeId_Register);
395
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ExecutionContext_Get);
396
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Set);
397
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Get);
398
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Allocate);
399
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Free);
400
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_Schedule);
401
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_NumThreads);
402
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_Create);
403
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetAvailable);
404
- # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
405
- # };
406
- class XLA_FFI_Api(ctypes.Structure):
407
- _fields_ = (
408
- ("struct_size", ctypes.c_size_t),
409
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
410
- ("api_version", XLA_FFI_Api_Version),
411
- ("internal_api", ctypes.c_void_p), # XLA_FFI_InternalApi*
412
- ("XLA_FFI_Error_Create", XLA_FFI_Error_Create), # XLA_FFI_Error_Create
413
- ("XLA_FFI_Error_GetMessage", ctypes.c_void_p), # XLA_FFI_Error_GetMessage
414
- ("XLA_FFI_Error_Destroy", ctypes.c_void_p), # XLA_FFI_Error_Destroy
415
- ("XLA_FFI_Handler_Register", ctypes.c_void_p), # XLA_FFI_Handler_Register
416
- ("XLA_FFI_Stream_Get", XLA_FFI_Stream_Get), # XLA_FFI_Stream_Get
417
- ("XLA_FFI_TypeId_Register", ctypes.c_void_p), # XLA_FFI_TypeId_Register
418
- ("XLA_FFI_ExecutionContext_Get", ctypes.c_void_p), # XLA_FFI_ExecutionContext_Get
419
- ("XLA_FFI_State_Set", ctypes.c_void_p), # XLA_FFI_State_Set
420
- ("XLA_FFI_State_Get", ctypes.c_void_p), # XLA_FFI_State_Get
421
- ("XLA_FFI_DeviceMemory_Allocate", ctypes.c_void_p), # XLA_FFI_DeviceMemory_Allocate
422
- ("XLA_FFI_DeviceMemory_Free", ctypes.c_void_p), # XLA_FFI_DeviceMemory_Free
423
- ("XLA_FFI_ThreadPool_Schedule", ctypes.c_void_p), # XLA_FFI_ThreadPool_Schedule
424
- ("XLA_FFI_ThreadPool_NumThreads", ctypes.c_void_p), # XLA_FFI_ThreadPool_NumThreads
425
- ("XLA_FFI_Future_Create", ctypes.c_void_p), # XLA_FFI_Future_Create
426
- ("XLA_FFI_Future_SetAvailable", ctypes.c_void_p), # XLA_FFI_Future_SetAvailable
427
- ("XLA_FFI_Future_SetError", ctypes.c_void_p), # XLA_FFI_Future_SetError
428
- )
429
-
430
-
431
- # struct XLA_FFI_CallFrame {
432
- # size_t struct_size;
433
- # XLA_FFI_Extension_Base* extension_start;
434
- # const XLA_FFI_Api* api;
435
- # XLA_FFI_ExecutionContext* ctx;
436
- # XLA_FFI_ExecutionStage stage;
437
- # XLA_FFI_Args args;
438
- # XLA_FFI_Rets rets;
439
- # XLA_FFI_Attrs attrs;
440
- #
441
- # // XLA FFI handler implementation can use `future` to signal a result of
442
- # // asynchronous computation to the XLA runtime. XLA runtime will keep all
443
- # // arguments, results and attributes alive until `future` is completed.
444
- # XLA_FFI_Future* future; // out
445
- # };
446
- class XLA_FFI_CallFrame(ctypes.Structure):
447
- _fields_ = (
448
- ("struct_size", ctypes.c_size_t),
449
- ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
450
- ("api", ctypes.POINTER(XLA_FFI_Api)),
451
- ("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
452
- ("stage", ctypes.c_int), # XLA_FFI_ExecutionStage
453
- ("args", XLA_FFI_Args),
454
- ("rets", XLA_FFI_Rets),
455
- ("attrs", XLA_FFI_Attrs),
456
- ("future", ctypes.c_void_p), # XLA_FFI_Future* // out
457
- )
458
-
459
-
460
- _xla_data_type_to_constructor = {
461
- # XLA_FFI_DataType.INVALID
462
- XLA_FFI_DataType.PRED: jnp.bool,
463
- XLA_FFI_DataType.S8: jnp.int8,
464
- XLA_FFI_DataType.S16: jnp.int16,
465
- XLA_FFI_DataType.S32: jnp.int32,
466
- XLA_FFI_DataType.S64: jnp.int64,
467
- XLA_FFI_DataType.U8: jnp.uint8,
468
- XLA_FFI_DataType.U16: jnp.uint16,
469
- XLA_FFI_DataType.U32: jnp.uint32,
470
- XLA_FFI_DataType.U64: jnp.uint64,
471
- XLA_FFI_DataType.F16: jnp.float16,
472
- XLA_FFI_DataType.F32: jnp.float32,
473
- XLA_FFI_DataType.F64: jnp.float64,
474
- XLA_FFI_DataType.BF16: jnp.bfloat16,
475
- XLA_FFI_DataType.C64: jnp.complex64,
476
- XLA_FFI_DataType.C128: jnp.complex128,
477
- # XLA_FFI_DataType.TOKEN
478
- # XLA_FFI_DataType.F4E2M1FN: jnp.float4_e2m1fn.dtype,
479
- # XLA_FFI_DataType.F8E8M0FNU: jnp.float8_e8m0fnu.dtype,
480
- }
481
-
482
- # newer types not supported by older versions
483
- if hasattr(jnp, "float8_e5m2"):
484
- _xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2] = jnp.float8_e5m2
485
- if hasattr(jnp, "float8_e3m4"):
486
- _xla_data_type_to_constructor[XLA_FFI_DataType.F8E3M4] = jnp.float8_e3m4
487
- if hasattr(jnp, "float8_e4m3"):
488
- _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3] = jnp.float8_e4m3
489
- if hasattr(jnp, "float8_e4m3fn"):
490
- _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FN] = jnp.float8_e4m3fn
491
- if hasattr(jnp, "float8_e4m3b11fnuz"):
492
- _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3B11FNUZ] = jnp.float8_e4m3b11fnuz
493
- if hasattr(jnp, "float8_e5m2fnuz"):
494
- _xla_data_type_to_constructor[XLA_FFI_DataType.F8E5M2FNUZ] = jnp.float8_e5m2fnuz
495
- if hasattr(jnp, "float8_e4m3fnuz"):
496
- _xla_data_type_to_constructor[XLA_FFI_DataType.F8E4M3FNUZ] = jnp.float8_e4m3fnuz
497
-
498
-
499
- ########################################################################
500
- # Helpers for translating between ctypes and python types
501
- #######################################################################
502
-
503
-
504
- def decode_bytespan(span: XLA_FFI_ByteSpan):
505
- len = span.len
506
- chars = ctypes.cast(span.ptr, ctypes.POINTER(ctypes.c_char * len))
507
- return chars.contents.value.decode("utf-8")
508
-
509
-
510
- def decode_scalar(scalar: XLA_FFI_Scalar):
511
- # TODO validate if dtype supported
512
- dtype = jnp.dtype(_xla_data_type_to_constructor[scalar.dtype])
513
- bytes = ctypes.string_at(scalar.value, dtype.itemsize)
514
- return np.frombuffer(bytes, dtype=dtype).reshape(())
515
-
516
-
517
- def decode_array(array: XLA_FFI_Array):
518
- # TODO validate if dtype supported
519
- dtype = jnp.dtype(_xla_data_type_to_constructor[array.dtype])
520
- bytes = ctypes.string_at(array.data, dtype.itemsize * array.size)
521
- return np.frombuffer(bytes, dtype=dtype)
522
-
523
-
524
- def decode_attrs(attrs: XLA_FFI_Attrs):
525
- result = {}
526
- for i in range(attrs.size):
527
- attr_name = decode_bytespan(attrs.names[i].contents)
528
- attr_type = attrs.types[i]
529
- if attr_type == XLA_FFI_AttrType.STRING:
530
- bytespan = ctypes.cast(attrs.attrs[i], ctypes.POINTER(XLA_FFI_ByteSpan))
531
- attr_value = decode_bytespan(bytespan.contents)
532
- elif attr_type == XLA_FFI_AttrType.SCALAR:
533
- attr_value = ctypes.cast(attrs.attrs[i], ctypes.POINTER(XLA_FFI_Scalar))
534
- attr_value = decode_scalar(attr_value.contents)
535
- elif attr_type == XLA_FFI_AttrType.ARRAY:
536
- attr_value = ctypes.cast(attrs.attrs[i], ctypes.POINTER(XLA_FFI_Array))
537
- attr_value = decode_array(attr_value.contents)
538
- elif attr_type == XLA_FFI_AttrType.DICTIONARY:
539
- attr_value = ctypes.cast(attrs.attrs[i], ctypes.POINTER(XLA_FFI_Attrs))
540
- attr_value = decode_attrs(attr_value.contents)
541
- else:
542
- raise Exception("Unexpected attr type")
543
- result[attr_name] = attr_value
544
- return result
545
-
546
-
547
- # error-string to XLA_FFI_Error
548
- def create_ffi_error(api, errc, message):
549
- create_args = XLA_FFI_Error_Create_Args(
550
- ctypes.sizeof(XLA_FFI_Error_Create_Args),
551
- ctypes.POINTER(XLA_FFI_Extension_Base)(),
552
- ctypes.c_char_p(message.encode("utf-8")),
553
- errc,
554
- )
555
- return api.contents.XLA_FFI_Error_Create(create_args)
556
-
557
-
558
- def create_invalid_argument_ffi_error(api, message):
559
- return create_ffi_error(api, XLA_FFI_Error_Code.INVALID_ARGUMENT, message)
560
-
561
-
562
- # Extract CUDA stream from XLA_FFI_CallFrame.
563
- def get_stream_from_callframe(call_frame):
564
- api = call_frame.api
565
- get_stream_args = XLA_FFI_Stream_Get_Args(
566
- ctypes.sizeof(XLA_FFI_Stream_Get_Args), ctypes.POINTER(XLA_FFI_Extension_Base)(), call_frame.ctx, None
567
- )
568
- api.contents.XLA_FFI_Stream_Get(get_stream_args)
569
- # TODO check result
570
- return get_stream_args.stream
571
-
572
-
573
- _dtype_from_ffi = {
574
- XLA_FFI_DataType.S8: wp.int8,
575
- XLA_FFI_DataType.S16: wp.int16,
576
- XLA_FFI_DataType.S32: wp.int32,
577
- XLA_FFI_DataType.S64: wp.int64,
578
- XLA_FFI_DataType.U8: wp.uint8,
579
- XLA_FFI_DataType.U16: wp.uint16,
580
- XLA_FFI_DataType.U32: wp.uint32,
581
- XLA_FFI_DataType.U64: wp.uint64,
582
- XLA_FFI_DataType.F16: wp.float16,
583
- XLA_FFI_DataType.F32: wp.float32,
584
- XLA_FFI_DataType.F64: wp.float64,
585
- }
586
-
587
-
588
- def dtype_from_ffi(ffi_dtype):
589
- return _dtype_from_ffi.get(ffi_dtype)
590
-
591
-
592
- def jax_dtype_from_ffi(ffi_dtype):
593
- return _xla_data_type_to_constructor.get(ffi_dtype)
594
-
595
-
596
- # Execution context (stream, stage)
597
- class ExecutionContext:
598
- stage: XLA_FFI_ExecutionStage
599
- stream: int
600
-
601
- def __init__(self, callframe: XLA_FFI_CallFrame):
602
- self.stage = XLA_FFI_ExecutionStage(callframe.stage)
603
- self.stream = get_stream_from_callframe(callframe)
604
-
605
-
606
- class FfiBuffer:
607
- dtype: str
608
- data: int
609
- shape: tuple[int]
610
-
611
- def __init__(self, xla_buffer):
612
- # TODO check if valid
613
- self.dtype = jnp.dtype(_xla_data_type_to_constructor[xla_buffer.dtype])
614
- self.shape = tuple(xla_buffer.dims[i] for i in range(xla_buffer.rank))
615
- self.data = xla_buffer.data
616
-
617
- @property
618
- def __cuda_array_interface__(self):
619
- return {
620
- "shape": self.shape,
621
- "typestr": self.dtype.char,
622
- "data": (self.data, False),
623
- "version": 2,
624
- }
24
+ return get_deprecated_api(_xla_ffi, "wp.jax_experimental", name)