warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,602 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+ import enum
18
+
19
+ import jax.numpy as jnp
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+
24
+ #######################################################################
25
+ # ctypes structures and enums for XLA's FFI API:
26
+ # https://github.com/openxla/xla/blob/a1a5e62fbffa3a3b6c409d72607456cf5b353a22/xla/ffi/api/c_api.h
27
+ #######################################################################
28
+
29
+
30
# typedef enum {
#   XLA_FFI_Extension_Metadata = 1,
# } XLA_FFI_Extension_Type;
class XLA_FFI_Extension_Type(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_Extension_Type``.

    Values match the C declaration quoted above; the ``XLA_FFI_Extension_``
    prefix is dropped from the member names.
    """

    Metadata = 1
35
+
36
+
37
# typedef struct XLA_FFI_Extension_Base {
#   size_t struct_size;
#   XLA_FFI_Extension_Type type;
#   struct XLA_FFI_Extension_Base* next;
# } XLA_FFI_Extension_Base;
class XLA_FFI_Extension_Base(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Extension_Base``.

    The structure is self-referential (``next`` points to another
    ``XLA_FFI_Extension_Base``), so ``_fields_`` is assigned after the class
    statement, once the class name exists and can be used in ``POINTER``.
    """


XLA_FFI_Extension_Base._fields_ = [
    ("struct_size", ctypes.c_size_t),
    ("type", ctypes.c_int),  # XLA_FFI_Extension_Type
    ("next", ctypes.POINTER(XLA_FFI_Extension_Base)),
]
51
+
52
+
53
# typedef enum {
#   XLA_FFI_ExecutionStage_INSTANTIATE = 0,
#   XLA_FFI_ExecutionStage_PREPARE = 1,
#   XLA_FFI_ExecutionStage_INITIALIZE = 2,
#   XLA_FFI_ExecutionStage_EXECUTE = 3,
# } XLA_FFI_ExecutionStage;
class XLA_FFI_ExecutionStage(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_ExecutionStage``.

    Values match the C declaration quoted above; the
    ``XLA_FFI_ExecutionStage_`` prefix is dropped from the member names.
    """

    INSTANTIATE = 0
    PREPARE = 1
    INITIALIZE = 2
    EXECUTE = 3
64
+
65
+
66
# typedef enum {
#   XLA_FFI_DataType_INVALID = 0,
#   XLA_FFI_DataType_PRED = 1,
#   XLA_FFI_DataType_S8 = 2,
#   XLA_FFI_DataType_S16 = 3,
#   XLA_FFI_DataType_S32 = 4,
#   XLA_FFI_DataType_S64 = 5,
#   XLA_FFI_DataType_U8 = 6,
#   XLA_FFI_DataType_U16 = 7,
#   XLA_FFI_DataType_U32 = 8,
#   XLA_FFI_DataType_U64 = 9,
#   XLA_FFI_DataType_F16 = 10,
#   XLA_FFI_DataType_F32 = 11,
#   XLA_FFI_DataType_F64 = 12,
#   XLA_FFI_DataType_BF16 = 16,
#   XLA_FFI_DataType_C64 = 15,
#   XLA_FFI_DataType_C128 = 18,
#   XLA_FFI_DataType_TOKEN = 17,
#   XLA_FFI_DataType_F8E5M2 = 19,
#   XLA_FFI_DataType_F8E3M4 = 29,
#   XLA_FFI_DataType_F8E4M3 = 28,
#   XLA_FFI_DataType_F8E4M3FN = 20,
#   XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
#   XLA_FFI_DataType_F8E5M2FNUZ = 24,
#   XLA_FFI_DataType_F8E4M3FNUZ = 25,
#   XLA_FFI_DataType_F4E2M1FN = 32,
#   XLA_FFI_DataType_F8E8M0FNU = 33,
# } XLA_FFI_DataType;
class XLA_FFI_DataType(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_DataType``.

    Members are listed in the same order as the C declaration quoted above,
    which is not numerically sorted (e.g. ``BF16 = 16`` precedes ``C64 = 15``),
    and the numbering has gaps. The ``XLA_FFI_DataType_`` prefix is dropped
    from the member names.
    """

    INVALID = 0
    PRED = 1
    S8 = 2
    S16 = 3
    S32 = 4
    S64 = 5
    U8 = 6
    U16 = 7
    U32 = 8
    U64 = 9
    F16 = 10
    F32 = 11
    F64 = 12
    BF16 = 16
    C64 = 15
    C128 = 18
    TOKEN = 17
    F8E5M2 = 19
    F8E3M4 = 29
    F8E4M3 = 28
    F8E4M3FN = 20
    F8E4M3B11FNUZ = 23
    F8E5M2FNUZ = 24
    F8E4M3FNUZ = 25
    F4E2M1FN = 32
    F8E8M0FNU = 33
121
+
122
+
123
# struct XLA_FFI_Buffer {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#
#   XLA_FFI_DataType dtype;
#   void* data;
#   int64_t rank;
#   int64_t* dims;  // length == rank
# };
class XLA_FFI_Buffer(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Buffer``.

    Field order and types follow the C declaration quoted above. ``dtype`` is
    stored as a plain C ``int`` holding an ``XLA_FFI_DataType`` value, and
    ``dims`` points to ``rank`` 64-bit integers (per the C comment).
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("dtype", ctypes.c_int),  # XLA_FFI_DataType
        ("data", ctypes.c_void_p),
        ("rank", ctypes.c_int64),
        ("dims", ctypes.POINTER(ctypes.c_int64)),  # length == rank
    ]
141
+
142
+
143
# typedef enum {
#   XLA_FFI_ArgType_BUFFER = 1,
# } XLA_FFI_ArgType;
class XLA_FFI_ArgType(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_ArgType``.

    The value matches the C declaration quoted above; the ``XLA_FFI_ArgType_``
    prefix is dropped from the member name.
    """

    BUFFER = 1
148
+
149
+
150
# typedef enum {
#   XLA_FFI_RetType_BUFFER = 1,
# } XLA_FFI_RetType;
class XLA_FFI_RetType(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_RetType``.

    The value matches the C declaration quoted above; the ``XLA_FFI_RetType_``
    prefix is dropped from the member name.
    """

    BUFFER = 1
155
+
156
+
157
# struct XLA_FFI_Args {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   int64_t size;
#   XLA_FFI_ArgType* types;  // length == size
#   void** args;             // length == size
# };
class XLA_FFI_Args(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Args``.

    ``types`` and ``args`` are parallel arrays of ``size`` entries (per the C
    comments above). Each type entry is stored as a plain C ``int`` holding an
    ``XLA_FFI_ArgType`` value.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("size", ctypes.c_int64),
        ("types", ctypes.POINTER(ctypes.c_int)),  # XLA_FFI_ArgType*
        ("args", ctypes.POINTER(ctypes.c_void_p)),
    ]
172
+
173
+
174
# struct XLA_FFI_Rets {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   int64_t size;
#   XLA_FFI_RetType* types;  // length == size
#   void** rets;             // length == size
# };
class XLA_FFI_Rets(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Rets``.

    ``types`` and ``rets`` are parallel arrays of ``size`` entries (per the C
    comments above). Each type entry is stored as a plain C ``int`` holding an
    ``XLA_FFI_RetType`` value.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("size", ctypes.c_int64),
        ("types", ctypes.POINTER(ctypes.c_int)),  # XLA_FFI_RetType*
        ("rets", ctypes.POINTER(ctypes.c_void_p)),
    ]
189
+
190
+
191
# typedef struct XLA_FFI_ByteSpan {
#   const char* ptr;
#   size_t len;
# } XLA_FFI_ByteSpan;
class XLA_FFI_ByteSpan(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_ByteSpan``.

    ``ptr`` is declared as ``POINTER(c_char)`` rather than ``c_char_p``, so
    reading the field yields a pointer object that can be sliced with an
    explicit length (``span.ptr[:span.len]``) instead of an automatically
    converted NUL-terminated ``bytes`` value.
    """

    _fields_ = [
        ("ptr", ctypes.POINTER(ctypes.c_char)),
        ("len", ctypes.c_size_t),
    ]
197
+
198
+
199
# typedef struct XLA_FFI_Scalar {
#   XLA_FFI_DataType dtype;
#   void* value;
# } XLA_FFI_Scalar;
class XLA_FFI_Scalar(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Scalar``.

    ``dtype`` is stored as a plain C ``int`` holding an ``XLA_FFI_DataType``
    value; ``value`` is an untyped pointer.
    """

    _fields_ = [
        ("dtype", ctypes.c_int),  # XLA_FFI_DataType
        ("value", ctypes.c_void_p),
    ]
205
+
206
+
207
# typedef struct XLA_FFI_Array {
#   XLA_FFI_DataType dtype;
#   size_t size;
#   void* data;
# } XLA_FFI_Array;
class XLA_FFI_Array(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Array``.

    ``dtype`` is stored as a plain C ``int`` holding an ``XLA_FFI_DataType``
    value; ``data`` is an untyped pointer.
    """

    _fields_ = [
        ("dtype", ctypes.c_int),  # XLA_FFI_DataType
        ("size", ctypes.c_size_t),
        ("data", ctypes.c_void_p),
    ]
214
+
215
+
216
# typedef enum {
#   XLA_FFI_AttrType_ARRAY = 1,
#   XLA_FFI_AttrType_DICTIONARY = 2,
#   XLA_FFI_AttrType_SCALAR = 3,
#   XLA_FFI_AttrType_STRING = 4,
# } XLA_FFI_AttrType;
class XLA_FFI_AttrType(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_AttrType``.

    Values match the C declaration quoted above; the ``XLA_FFI_AttrType_``
    prefix is dropped from the member names.
    """

    ARRAY = 1
    DICTIONARY = 2
    SCALAR = 3
    STRING = 4
227
+
228
+
229
# struct XLA_FFI_Attrs {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   int64_t size;
#   XLA_FFI_AttrType* types;   // length == size
#   XLA_FFI_ByteSpan** names;  // length == size
#   void** attrs;              // length == size
# };
class XLA_FFI_Attrs(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Attrs``.

    ``types``, ``names`` and ``attrs`` are parallel arrays of ``size`` entries
    (per the C comments above). Each type entry is stored as a plain C ``int``
    holding an ``XLA_FFI_AttrType`` value, and each name entry is a pointer to
    an ``XLA_FFI_ByteSpan``.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("size", ctypes.c_int64),
        ("types", ctypes.POINTER(ctypes.c_int)),  # XLA_FFI_AttrType*
        ("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
        ("attrs", ctypes.POINTER(ctypes.c_void_p)),
    ]
246
+
247
+
248
# struct XLA_FFI_Api_Version {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   int major_version;  // out
#   int minor_version;  // out
# };
class XLA_FFI_Api_Version(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Api_Version``.

    The C declaration marks ``major_version`` and ``minor_version`` as output
    fields.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("major_version", ctypes.c_int),  # out
        ("minor_version", ctypes.c_int),  # out
    ]
261
+
262
+
263
# enum XLA_FFI_Handler_TraitsBits {
#   // Calls to FFI handler are safe to trace into the command buffer. It means
#   // that calls to FFI handler always launch exactly the same device operations
#   // (can depend on attribute values) that can be captured and then replayed.
#   XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
# };
class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_Handler_TraitsBits``.

    Each member is a single bit, as in the C declaration quoted above; the
    ``XLA_FFI_HANDLER_TRAITS_`` prefix is dropped from the member names.
    """

    # Handler calls can be traced into a command buffer and replayed
    # (see the C comment above).
    COMMAND_BUFFER_COMPATIBLE = 1 << 0
271
+
272
+
273
# struct XLA_FFI_Metadata {
#   size_t struct_size;
#   XLA_FFI_Api_Version api_version;
#   XLA_FFI_Handler_Traits traits;
# };
class XLA_FFI_Metadata(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Metadata``.

    ``api_version`` is embedded by value (not a pointer), matching the C
    declaration quoted above. ``traits`` is stored as a 32-bit unsigned
    integer; presumably a bitwise OR of ``XLA_FFI_Handler_TraitsBits`` values
    -- confirm against the XLA header.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("api_version", XLA_FFI_Api_Version),  # XLA_FFI_Api_Version (by value)
        ("traits", ctypes.c_uint32),  # XLA_FFI_Handler_Traits
    ]
284
+
285
+
286
# struct XLA_FFI_Metadata_Extension {
#   XLA_FFI_Extension_Base extension_base;
#   XLA_FFI_Metadata* metadata;
# };
class XLA_FFI_Metadata_Extension(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Metadata_Extension``.

    ``extension_base`` is embedded by value as the first member, and
    ``metadata`` is a pointer to an ``XLA_FFI_Metadata`` structure.
    """

    _fields_ = [
        ("extension_base", XLA_FFI_Extension_Base),
        ("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
    ]
292
+
293
+
294
# typedef enum {
#   XLA_FFI_Error_Code_OK = 0,
#   XLA_FFI_Error_Code_CANCELLED = 1,
#   XLA_FFI_Error_Code_UNKNOWN = 2,
#   XLA_FFI_Error_Code_INVALID_ARGUMENT = 3,
#   XLA_FFI_Error_Code_DEADLINE_EXCEEDED = 4,
#   XLA_FFI_Error_Code_NOT_FOUND = 5,
#   XLA_FFI_Error_Code_ALREADY_EXISTS = 6,
#   XLA_FFI_Error_Code_PERMISSION_DENIED = 7,
#   XLA_FFI_Error_Code_RESOURCE_EXHAUSTED = 8,
#   XLA_FFI_Error_Code_FAILED_PRECONDITION = 9,
#   XLA_FFI_Error_Code_ABORTED = 10,
#   XLA_FFI_Error_Code_OUT_OF_RANGE = 11,
#   XLA_FFI_Error_Code_UNIMPLEMENTED = 12,
#   XLA_FFI_Error_Code_INTERNAL = 13,
#   XLA_FFI_Error_Code_UNAVAILABLE = 14,
#   XLA_FFI_Error_Code_DATA_LOSS = 15,
#   XLA_FFI_Error_Code_UNAUTHENTICATED = 16
# } XLA_FFI_Error_Code;
class XLA_FFI_Error_Code(enum.IntEnum):
    """Python mirror of the C enum ``XLA_FFI_Error_Code``.

    Values run contiguously from 0 (``OK``) to 16 (``UNAUTHENTICATED``) and
    match the C declaration quoted above; the ``XLA_FFI_Error_Code_`` prefix is
    dropped from the member names.
    """

    OK = 0
    CANCELLED = 1
    UNKNOWN = 2
    INVALID_ARGUMENT = 3
    DEADLINE_EXCEEDED = 4
    NOT_FOUND = 5
    ALREADY_EXISTS = 6
    PERMISSION_DENIED = 7
    RESOURCE_EXHAUSTED = 8
    FAILED_PRECONDITION = 9
    ABORTED = 10
    OUT_OF_RANGE = 11
    UNIMPLEMENTED = 12
    INTERNAL = 13
    UNAVAILABLE = 14
    DATA_LOSS = 15
    UNAUTHENTICATED = 16
331
+
332
+
333
# struct XLA_FFI_Error_Create_Args {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   const char* message;
#   XLA_FFI_Error_Code errc;
# };
class XLA_FFI_Error_Create_Args(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Error_Create_Args``.

    ``message`` is a C string pointer and ``errc`` is stored as a plain C
    ``int`` holding an ``XLA_FFI_Error_Code`` value.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("message", ctypes.c_char_p),
        ("errc", ctypes.c_int),  # XLA_FFI_Error_Code
    ]
346
+
347
+
348
# Function-pointer type used for the `XLA_FFI_Error_Create` slot of the
# XLA_FFI_Api table: takes a pointer to XLA_FFI_Error_Create_Args and returns
# an opaque pointer (presumably XLA_FFI_Error* -- confirm against c_api.h).
XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
349
+
350
+
351
# struct XLA_FFI_Stream_Get_Args {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   XLA_FFI_ExecutionContext* ctx;
#   void* stream;  // out
# };
class XLA_FFI_Stream_Get_Args(ctypes.Structure):
    """ctypes mirror of the C struct ``XLA_FFI_Stream_Get_Args``.

    ``ctx`` is kept as an untyped pointer because ``XLA_FFI_ExecutionContext``
    has no ctypes definition here. The C declaration marks ``stream`` as an
    output field.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("ctx", ctypes.c_void_p),  # XLA_FFI_ExecutionContext*
        ("stream", ctypes.c_void_p),  # out
    ]
364
+
365
+
366
# Function-pointer type used for the `XLA_FFI_Stream_Get` slot of the
# XLA_FFI_Api table: takes a pointer to XLA_FFI_Stream_Get_Args and returns
# an opaque pointer (presumably XLA_FFI_Error* -- confirm against c_api.h).
XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
367
+
368
+
369
# struct XLA_FFI_Api {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#
#   XLA_FFI_Api_Version api_version;
#   XLA_FFI_InternalApi* internal_api;
#
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_GetMessage);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_TypeId_Register);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ExecutionContext_Get);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Set);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Get);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Allocate);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Free);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_Schedule);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_NumThreads);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_Create);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetAvailable);
#   _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
# };
class XLA_FFI_Api(ctypes.Structure):
    """ctypes layout of the C struct ``XLA_FFI_Api`` (the API function table).

    Only the ``XLA_FFI_Error_Create`` and ``XLA_FFI_Stream_Get`` slots carry a
    typed prototype. Every other slot is declared as an untyped pointer so the
    structure keeps the same size and field offsets as the C declaration.
    """

    _fields_ = [
        # Header.
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("api_version", XLA_FFI_Api_Version),
        ("internal_api", ctypes.c_void_p),  # XLA_FFI_InternalApi*
        # Function table; each field is named after the C function it holds.
        ("XLA_FFI_Error_Create", XLA_FFI_Error_Create),
        ("XLA_FFI_Error_GetMessage", ctypes.c_void_p),
        ("XLA_FFI_Error_Destroy", ctypes.c_void_p),
        ("XLA_FFI_Handler_Register", ctypes.c_void_p),
        ("XLA_FFI_Stream_Get", XLA_FFI_Stream_Get),
        ("XLA_FFI_TypeId_Register", ctypes.c_void_p),
        ("XLA_FFI_ExecutionContext_Get", ctypes.c_void_p),
        ("XLA_FFI_State_Set", ctypes.c_void_p),
        ("XLA_FFI_State_Get", ctypes.c_void_p),
        ("XLA_FFI_DeviceMemory_Allocate", ctypes.c_void_p),
        ("XLA_FFI_DeviceMemory_Free", ctypes.c_void_p),
        ("XLA_FFI_ThreadPool_Schedule", ctypes.c_void_p),
        ("XLA_FFI_ThreadPool_NumThreads", ctypes.c_void_p),
        ("XLA_FFI_Future_Create", ctypes.c_void_p),
        ("XLA_FFI_Future_SetAvailable", ctypes.c_void_p),
        ("XLA_FFI_Future_SetError", ctypes.c_void_p),
    ]
416
+
417
+
418
# struct XLA_FFI_CallFrame {
#   size_t struct_size;
#   XLA_FFI_Extension_Base* extension_start;
#   const XLA_FFI_Api* api;
#   XLA_FFI_ExecutionContext* ctx;
#   XLA_FFI_ExecutionStage stage;
#   XLA_FFI_Args args;
#   XLA_FFI_Rets rets;
#   XLA_FFI_Attrs attrs;
#
#   // XLA FFI handler implementation can use `future` to signal a result of
#   // asynchronous computation to the XLA runtime. XLA runtime will keep all
#   // arguments, results and attributes alive until `future` is completed.
#   XLA_FFI_Future* future;  // out
# };
class XLA_FFI_CallFrame(ctypes.Structure):
    """ctypes layout of the C struct ``XLA_FFI_CallFrame``.

    ``args``, ``rets`` and ``attrs`` are embedded by value; ``api`` is a typed
    pointer to the function table, while ``ctx`` and ``future`` are kept as
    opaque pointers.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("api", ctypes.POINTER(XLA_FFI_Api)),  # const XLA_FFI_Api*
        ("ctx", ctypes.c_void_p),  # XLA_FFI_ExecutionContext*
        ("stage", ctypes.c_int),  # XLA_FFI_ExecutionStage
        ("args", XLA_FFI_Args),
        ("rets", XLA_FFI_Rets),
        ("attrs", XLA_FFI_Attrs),
        ("future", ctypes.c_void_p),  # XLA_FFI_Future*, out
    ]
445
+
446
+
447
# Maps XLA FFI element types to the matching jax.numpy scalar constructors.
# INVALID and TOKEN have no entry.
#
# float8_e3m4 and float8_e4m3 are looked up with getattr() because jax.numpy
# does not expose them in every JAX / ml_dtypes combination; referencing them
# directly would make importing this module fail with AttributeError. When a
# type is missing its entry is simply left out, so the lookup helpers report
# it as unsupported. With all attributes present the resulting dict has the
# same keys, values and order as a plain literal.
_xla_data_type_to_constructor = {
    ffi_type: constructor
    for ffi_type, constructor in (
        # XLA_FFI_DataType.INVALID
        (XLA_FFI_DataType.PRED, jnp.bool),
        (XLA_FFI_DataType.S8, jnp.int8),
        (XLA_FFI_DataType.S16, jnp.int16),
        (XLA_FFI_DataType.S32, jnp.int32),
        (XLA_FFI_DataType.S64, jnp.int64),
        (XLA_FFI_DataType.U8, jnp.uint8),
        (XLA_FFI_DataType.U16, jnp.uint16),
        (XLA_FFI_DataType.U32, jnp.uint32),
        (XLA_FFI_DataType.U64, jnp.uint64),
        (XLA_FFI_DataType.F16, jnp.float16),
        (XLA_FFI_DataType.F32, jnp.float32),
        (XLA_FFI_DataType.F64, jnp.float64),
        (XLA_FFI_DataType.BF16, jnp.bfloat16),
        (XLA_FFI_DataType.C64, jnp.complex64),
        (XLA_FFI_DataType.C128, jnp.complex128),
        # XLA_FFI_DataType.TOKEN
        (XLA_FFI_DataType.F8E5M2, jnp.float8_e5m2),
        (XLA_FFI_DataType.F8E3M4, getattr(jnp, "float8_e3m4", None)),
        (XLA_FFI_DataType.F8E4M3, getattr(jnp, "float8_e4m3", None)),
        (XLA_FFI_DataType.F8E4M3FN, jnp.float8_e4m3fn),
        (XLA_FFI_DataType.F8E4M3B11FNUZ, jnp.float8_e4m3b11fnuz),
        (XLA_FFI_DataType.F8E5M2FNUZ, jnp.float8_e5m2fnuz),
        (XLA_FFI_DataType.F8E4M3FNUZ, jnp.float8_e4m3fnuz),
        # XLA_FFI_DataType.F4E2M1FN: jnp.float4_e2m1fn.dtype,
        # XLA_FFI_DataType.F8E8M0FNU: jnp.float8_e8m0fnu.dtype,
    )
    if constructor is not None
}
475
+
476
+
477
+ ########################################################################
478
+ # Helpers for translating between ctypes and python types
479
+ #######################################################################
480
+
481
+
482
def decode_bytespan(span: XLA_FFI_ByteSpan):
    """Decode a length-delimited byte span into a Python string.

    Args:
        span: Object exposing ``ptr`` (address of the first byte) and ``len``
            (number of bytes), such as an ``XLA_FFI_ByteSpan``.

    Returns:
        The ``len`` bytes starting at ``ptr`` decoded as UTF-8. An empty span
        yields ``""`` without dereferencing ``ptr``.

    Raises:
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    length = span.len
    if length == 0:
        # Nothing to read; the pointer of an empty span may be NULL.
        return ""
    # Copy exactly `length` bytes. The span carries its own length, so the
    # data must not be cut short at an embedded NUL byte.
    return ctypes.string_at(span.ptr, length).decode("utf-8")
486
+
487
+
488
def decode_scalar(scalar: XLA_FFI_Scalar):
    """Copy an FFI scalar into a zero-dimensional NumPy array.

    Args:
        scalar: Object exposing ``dtype`` (an XLA FFI data type) and ``value``
            (address of the scalar's bytes).

    Returns:
        A 0-d array whose dtype is derived from ``scalar.dtype`` through
        ``_xla_data_type_to_constructor``.

    Raises:
        KeyError: If ``scalar.dtype`` has no entry in the dtype table.
    """
    # TODO validate if dtype supported
    element_dtype = jnp.dtype(_xla_data_type_to_constructor[scalar.dtype])
    # Take a private copy of the element's bytes before wrapping them.
    raw = ctypes.string_at(scalar.value, element_dtype.itemsize)
    return np.frombuffer(raw, dtype=element_dtype).reshape(())
493
+
494
+
495
def decode_array(array: XLA_FFI_Array):
    """Copy an FFI array into a one-dimensional NumPy array.

    Args:
        array: Object exposing ``dtype`` (an XLA FFI data type), ``size``
            (element count) and ``data`` (address of the first element).

    Returns:
        A flat array of ``array.size`` elements whose dtype is derived from
        ``array.dtype`` through ``_xla_data_type_to_constructor``.

    Raises:
        KeyError: If ``array.dtype`` has no entry in the dtype table.
    """
    # TODO validate if dtype supported
    element_dtype = jnp.dtype(_xla_data_type_to_constructor[array.dtype])
    byte_count = element_dtype.itemsize * array.size
    raw = ctypes.string_at(array.data, byte_count)
    return np.frombuffer(raw, dtype=element_dtype)
500
+
501
+
502
def decode_attrs(attrs: XLA_FFI_Attrs):
    """Translate an ``XLA_FFI_Attrs`` record into a Python dictionary.

    Each of the ``attrs.size`` entries is decoded according to its type tag:
    strings via ``decode_bytespan``, scalars via ``decode_scalar``, arrays via
    ``decode_array`` and nested dictionaries by recursing into this function.

    Returns:
        A dict mapping each decoded attribute name to its decoded value.

    Raises:
        Exception: If an entry carries a type tag other than the four above.
    """
    decoded = {}
    for idx in range(attrs.size):
        name = decode_bytespan(attrs.names[idx].contents)
        kind = attrs.types[idx]

        # Choose the C struct the opaque attribute pointer refers to, together
        # with the function that converts that struct to a Python value.
        if kind == XLA_FFI_AttrType.STRING:
            struct_type, decode = XLA_FFI_ByteSpan, decode_bytespan
        elif kind == XLA_FFI_AttrType.SCALAR:
            struct_type, decode = XLA_FFI_Scalar, decode_scalar
        elif kind == XLA_FFI_AttrType.ARRAY:
            struct_type, decode = XLA_FFI_Array, decode_array
        elif kind == XLA_FFI_AttrType.DICTIONARY:
            struct_type, decode = XLA_FFI_Attrs, decode_attrs
        else:
            raise Exception("Unexpected attr type")

        typed_ptr = ctypes.cast(attrs.attrs[idx], ctypes.POINTER(struct_type))
        decoded[name] = decode(typed_ptr.contents)
    return decoded
523
+
524
+
525
# error-string to XLA_FFI_Error
def create_ffi_error(api, errc, message):
    """Create an error object through the ``XLA_FFI_Error_Create`` API entry.

    Args:
        api: Pointer to an ``XLA_FFI_Api``; it is dereferenced with ``.contents``.
        errc: Error code stored in the ``errc`` field, e.g. a member of
            ``XLA_FFI_Error_Code``.
        message: Error text as ``str``; it is passed on UTF-8 encoded.

    Returns:
        Whatever ``XLA_FFI_Error_Create`` returns (declared as ``c_void_p``).
    """
    no_extension = ctypes.POINTER(XLA_FFI_Extension_Base)()  # NULL pointer
    encoded_message = ctypes.c_char_p(message.encode("utf-8"))
    request = XLA_FFI_Error_Create_Args(
        struct_size=ctypes.sizeof(XLA_FFI_Error_Create_Args),
        extension_start=no_extension,
        message=encoded_message,
        errc=errc,
    )
    return api.contents.XLA_FFI_Error_Create(request)
534
+
535
+
536
def create_invalid_argument_ffi_error(api, message):
    """Create an FFI error with code ``INVALID_ARGUMENT`` carrying ``message``.

    Thin wrapper around ``create_ffi_error``; returns its result unchanged.
    """
    invalid_argument = XLA_FFI_Error_Code.INVALID_ARGUMENT
    return create_ffi_error(api, invalid_argument, message)
538
+
539
+
540
# Extract CUDA stream from XLA_FFI_CallFrame.
def get_stream_from_callframe(call_frame):
    """Query the stream associated with ``call_frame``'s execution context.

    Calls the ``XLA_FFI_Stream_Get`` entry of ``call_frame.api`` with a request
    whose ``stream`` slot starts out empty, then returns whatever the callee
    stored in that slot.

    Returns:
        The value of the ``stream`` field after the call. The call's own return
        value is currently ignored.
    """
    request = XLA_FFI_Stream_Get_Args(
        struct_size=ctypes.sizeof(XLA_FFI_Stream_Get_Args),
        extension_start=ctypes.POINTER(XLA_FFI_Extension_Base)(),  # NULL pointer
        ctx=call_frame.ctx,
        stream=None,  # out slot, filled in by the callee
    )
    call_frame.api.contents.XLA_FFI_Stream_Get(request)
    # TODO check result
    return request.stream
549
+
550
+
551
# XLA FFI element type -> Warp scalar type. Only the types listed here have a
# Warp counterpart in this table.
_dtype_from_ffi = {
    # Signed integers.
    XLA_FFI_DataType.S8: wp.int8,
    XLA_FFI_DataType.S16: wp.int16,
    XLA_FFI_DataType.S32: wp.int32,
    XLA_FFI_DataType.S64: wp.int64,
    # Unsigned integers.
    XLA_FFI_DataType.U8: wp.uint8,
    XLA_FFI_DataType.U16: wp.uint16,
    XLA_FFI_DataType.U32: wp.uint32,
    XLA_FFI_DataType.U64: wp.uint64,
    # Floating point.
    XLA_FFI_DataType.F16: wp.float16,
    XLA_FFI_DataType.F32: wp.float32,
    XLA_FFI_DataType.F64: wp.float64,
}
564
+
565
+
566
def dtype_from_ffi(ffi_dtype):
    """Return the Warp type registered for ``ffi_dtype`` in ``_dtype_from_ffi``.

    Returns ``None`` when the table has no entry for ``ffi_dtype``.
    """
    return _dtype_from_ffi.get(ffi_dtype, None)
568
+
569
+
570
def jax_dtype_from_ffi(ffi_dtype):
    """Return the jax.numpy constructor registered for ``ffi_dtype``.

    The lookup uses ``_xla_data_type_to_constructor`` and yields ``None`` when
    the table has no entry for ``ffi_dtype``.
    """
    return _xla_data_type_to_constructor.get(ffi_dtype, None)
572
+
573
+
574
# Execution context (stream, stage)
class ExecutionContext:
    """Snapshot of the stage and stream taken from an ``XLA_FFI_CallFrame``."""

    # Stage of the call, converted from the frame's integer ``stage`` field.
    stage: XLA_FFI_ExecutionStage
    # Stream handle as returned by ``get_stream_from_callframe``.
    stream: int

    def __init__(self, callframe: XLA_FFI_CallFrame):
        """Read ``stage`` and ``stream`` out of ``callframe``."""
        self.stage = XLA_FFI_ExecutionStage(callframe.stage)
        self.stream = get_stream_from_callframe(callframe)
582
+
583
+
584
class FfiBuffer:
    """Lightweight view of an XLA FFI buffer exposing ``__cuda_array_interface__``.

    The constructor copies the element type, shape and data address out of an
    XLA buffer record; it does not copy the buffer contents.
    """

    # NumPy dtype derived from the XLA element type.
    dtype: np.dtype
    # Address stored in the XLA buffer's ``data`` field.
    data: int
    # One extent per dimension (``rank`` entries).
    shape: tuple[int, ...]

    def __init__(self, xla_buffer):
        """Initialise the view from ``xla_buffer``.

        Args:
            xla_buffer: Object exposing ``dtype``, ``rank``, ``dims`` and
                ``data``.

        Raises:
            KeyError: If ``xla_buffer.dtype`` has no entry in
                ``_xla_data_type_to_constructor``.
        """
        # TODO check if valid
        self.dtype = jnp.dtype(_xla_data_type_to_constructor[xla_buffer.dtype])
        self.shape = tuple(xla_buffer.dims[i] for i in range(xla_buffer.rank))
        self.data = xla_buffer.data

    @property
    def __cuda_array_interface__(self):
        """Describe the buffer as a CUDA Array Interface (version 2) dict."""
        return {
            "shape": self.shape,
            # The interface expects the byte-order/kind/size form (e.g. "<f4"),
            # not the single-character type code.
            "typestr": self.dtype.str,
            "data": (self.data, False),  # (address, read-only flag)
            "version": 2,
        }
602
+ }