warp-lang 1.9.0-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
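The dominant pattern in the list above is a source relayout: the implementation of nearly every public module moves into a new private warp/_src/ package, while the public modules shrink to a few lines each (for example, warp/tape.py drops to 6 added lines while warp/_src/tape.py appears with 1,206). The shim contents themselves are not shown in this excerpt; a minimal sketch of the usual re-export pattern, with an illustrative name only:

    # warp/tape.py, hypothetical shim: the implementation now lives in warp/_src/tape.py,
    # and the public module re-exports the same names so existing imports keep working.
    from warp._src.tape import Tape  # noqa: F401

Also visible above: the warp.sim package, its tests, and the examples/sim and examples/optim scripts are removed outright in 1.10.0rc2, and warp/__init__.pyi gains roughly 2,200 lines of type stubs. The hunks that follow appear to come from warp/tests/interop/test_jax.py (+1382/-79), which reworks the JAX interop tests around the new warp.jax_experimental.ffi module.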
@@ -15,13 +15,18 @@
15
15
 
16
16
  import os
17
17
  import unittest
18
+ from functools import partial
18
19
  from typing import Any
19
20
 
20
21
  import numpy as np
21
22
 
22
23
  import warp as wp
24
+ from warp._src.jax import get_jax_device
23
25
  from warp.tests.unittest_utils import *
24
26
 
27
+ # default array size for tests
28
+ ARRAY_SIZE = 1024 * 1024
29
+
25
30
 
26
31
  # basic kernel with one input and output
27
32
  @wp.kernel
@@ -44,6 +49,18 @@ def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)
44
49
  output[tid] = input.dtype.dtype(3) * input[tid]
45
50
 
46
51
 
52
+ @wp.kernel
53
+ def inc_1d_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
54
+ tid = wp.tid()
55
+ y[tid] = x[tid] + 1.0
56
+
57
+
58
+ @wp.kernel
59
+ def inc_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
60
+ i, j = wp.tid()
61
+ y[i, j] = x[i, j] + 1.0
62
+
63
+
47
64
  # kernel with multiple inputs and outputs
48
65
  @wp.kernel
49
66
  def multiarg_kernel(
@@ -61,7 +78,7 @@ def multiarg_kernel(
61
78
 
62
79
 
63
80
  # various types for testing
64
- scalar_types = wp.types.scalar_types
81
+ scalar_types = wp._src.types.scalar_types
65
82
  vector_types = []
66
83
  matrix_types = []
67
84
  for dim in [2, 3, 4]:
@@ -132,15 +149,19 @@ def test_device_conversion(test, device):
132
149
  test.assertEqual(warp_device, device)
133
150
 
134
151
 
135
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
136
- def test_jax_kernel_basic(test, device):
152
+ def test_jax_kernel_basic(test, device, use_ffi=False):
137
153
  import jax.numpy as jp
138
154
 
139
- from warp.jax_experimental import jax_kernel
155
+ if use_ffi:
156
+ from warp.jax_experimental.ffi import jax_kernel
140
157
 
141
- n = 64
158
+ jax_triple = jax_kernel(triple_kernel)
159
+ else:
160
+ from warp.jax_experimental.custom_call import jax_kernel
161
+
162
+ jax_triple = jax_kernel(triple_kernel, quiet=True) # suppress deprecation warnings
142
163
 
143
- jax_triple = jax_kernel(triple_kernel)
164
+ n = ARRAY_SIZE
144
165
 
145
166
  @jax.jit
146
167
  def f():
@@ -151,18 +172,27 @@ def test_jax_kernel_basic(test, device):
151
172
  with jax.default_device(wp.device_to_jax(device)):
152
173
  y = f()
153
174
 
175
+ wp.synchronize_device(device)
176
+
154
177
  result = np.asarray(y).reshape((n,))
155
178
  expected = 3 * np.arange(n, dtype=np.float32)
156
179
 
157
180
  assert_np_equal(result, expected)
158
181
 
159
182
 
160
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
161
- def test_jax_kernel_scalar(test, device):
183
+ def test_jax_kernel_scalar(test, device, use_ffi=False):
162
184
  import jax.numpy as jp
163
185
 
164
- from warp.jax_experimental import jax_kernel
186
+ if use_ffi:
187
+ from warp.jax_experimental.ffi import jax_kernel
188
+
189
+ kwargs = {}
190
+ else:
191
+ from warp.jax_experimental.custom_call import jax_kernel
165
192
 
193
+ kwargs = {"quiet": True}
194
+
195
+ # use a smallish size to ensure arange * 3 doesn't overflow
166
196
  n = 64
167
197
 
168
198
  for T in scalar_types:
@@ -173,7 +203,7 @@ def test_jax_kernel_scalar(test, device):
173
203
  # get the concrete overload
174
204
  kernel_instance = triple_kernel_scalar.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
175
205
 
176
- jax_triple = jax_kernel(kernel_instance)
206
+ jax_triple = jax_kernel(kernel_instance, **kwargs)
177
207
 
178
208
  @jax.jit
179
209
  def f(jax_triple=jax_triple, jp_dtype=jp_dtype):
@@ -184,22 +214,31 @@ def test_jax_kernel_scalar(test, device):
184
214
  with jax.default_device(wp.device_to_jax(device)):
185
215
  y = f()
186
216
 
217
+ wp.synchronize_device(device)
218
+
187
219
  result = np.asarray(y).reshape((n,))
188
220
  expected = 3 * np.arange(n, dtype=np_dtype)
189
221
 
190
222
  assert_np_equal(result, expected)
191
223
 
192
224
 
193
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
194
- def test_jax_kernel_vecmat(test, device):
225
+ def test_jax_kernel_vecmat(test, device, use_ffi=False):
195
226
  import jax.numpy as jp
196
227
 
197
- from warp.jax_experimental import jax_kernel
228
+ if use_ffi:
229
+ from warp.jax_experimental.ffi import jax_kernel
230
+
231
+ kwargs = {}
232
+ else:
233
+ from warp.jax_experimental.custom_call import jax_kernel
234
+
235
+ kwargs = {"quiet": True}
198
236
 
199
237
  for T in [*vector_types, *matrix_types]:
200
238
  jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
201
239
  np_dtype = wp.dtype_to_numpy(T._wp_scalar_type_)
202
240
 
241
+ # use a smallish size to ensure arange * 3 doesn't overflow
203
242
  n = 64 // T._length_
204
243
  scalar_shape = (n, *T._shape_)
205
244
  scalar_len = n * T._length_
@@ -208,7 +247,7 @@ def test_jax_kernel_vecmat(test, device):
208
247
  # get the concrete overload
209
248
  kernel_instance = triple_kernel_vecmat.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
210
249
 
211
- jax_triple = jax_kernel(kernel_instance)
250
+ jax_triple = jax_kernel(kernel_instance, **kwargs)
212
251
 
213
252
  @jax.jit
214
253
  def f(jax_triple=jax_triple, jp_dtype=jp_dtype, scalar_len=scalar_len, scalar_shape=scalar_shape):
@@ -219,21 +258,27 @@ def test_jax_kernel_vecmat(test, device):
219
258
  with jax.default_device(wp.device_to_jax(device)):
220
259
  y = f()
221
260
 
261
+ wp.synchronize_device(device)
262
+
222
263
  result = np.asarray(y).reshape(scalar_shape)
223
264
  expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
224
265
 
225
266
  assert_np_equal(result, expected)
226
267
 
227
268
 
228
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
229
- def test_jax_kernel_multiarg(test, device):
269
+ def test_jax_kernel_multiarg(test, device, use_ffi=False):
230
270
  import jax.numpy as jp
231
271
 
232
- from warp.jax_experimental import jax_kernel
272
+ if use_ffi:
273
+ from warp.jax_experimental.ffi import jax_kernel
233
274
 
234
- n = 64
275
+ jax_multiarg = jax_kernel(multiarg_kernel, num_outputs=2)
276
+ else:
277
+ from warp.jax_experimental.custom_call import jax_kernel
235
278
 
236
- jax_multiarg = jax_kernel(multiarg_kernel)
279
+ jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)
280
+
281
+ n = ARRAY_SIZE
237
282
 
238
283
  @jax.jit
239
284
  def f():
@@ -246,6 +291,8 @@ def test_jax_kernel_multiarg(test, device):
246
291
  with jax.default_device(wp.device_to_jax(device)):
247
292
  x, y = f()
248
293
 
294
+ wp.synchronize_device(device)
295
+
249
296
  result_x, result_y = np.asarray(x), np.asarray(y)
250
297
  expected_x = np.full(n, 3, dtype=np.float32)
251
298
  expected_y = np.full(n, 5, dtype=np.float32)
@@ -254,50 +301,48 @@ def test_jax_kernel_multiarg(test, device):
254
301
  assert_np_equal(result_y, expected_y)
255
302
 
256
303
 
257
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
258
- def test_jax_kernel_launch_dims(test, device):
304
+ def test_jax_kernel_launch_dims(test, device, use_ffi=False):
259
305
  import jax.numpy as jp
260
306
 
261
- from warp.jax_experimental import jax_kernel
307
+ if use_ffi:
308
+ from warp.jax_experimental.ffi import jax_kernel
309
+
310
+ kwargs = {}
311
+ else:
312
+ from warp.jax_experimental.custom_call import jax_kernel
313
+
314
+ kwargs = {"quiet": True}
262
315
 
263
316
  n = 64
264
317
  m = 32
265
318
 
266
319
  # Test with 1D launch dims
267
- @wp.kernel
268
- def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
269
- tid = wp.tid()
270
- y[tid] = x[tid] + 1.0
271
-
272
- jax_add_one = jax_kernel(
273
- add_one_kernel, launch_dims=(n - 2,)
320
+ jax_inc_1d = jax_kernel(
321
+ inc_1d_kernel, launch_dims=(n - 2,), **kwargs
274
322
  ) # Intentionally not the same as the first dimension of the input
275
323
 
276
324
  @jax.jit
277
325
  def f_1d():
278
326
  x = jp.arange(n, dtype=jp.float32)
279
- return jax_add_one(x)
327
+ return jax_inc_1d(x)
280
328
 
281
329
  # Test with 2D launch dims
282
- @wp.kernel
283
- def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
284
- i, j = wp.tid()
285
- y[i, j] = x[i, j] + 1.0
286
-
287
- jax_add_one_2d = jax_kernel(
288
- add_one_2d_kernel, launch_dims=(n - 2, m - 2)
330
+ jax_inc_2d = jax_kernel(
331
+ inc_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
289
332
  ) # Intentionally not the same as the first dimension of the input
290
333
 
291
334
  @jax.jit
292
335
  def f_2d():
293
336
  x = jp.zeros((n, m), dtype=jp.float32) + 3.0
294
- return jax_add_one_2d(x)
337
+ return jax_inc_2d(x)
295
338
 
296
339
  # run on the given device
297
340
  with jax.default_device(wp.device_to_jax(device)):
298
341
  y_1d = f_1d()
299
342
  y_2d = f_2d()
300
343
 
344
+ wp.synchronize_device(device)
345
+
301
346
  result_1d = np.asarray(y_1d).reshape((n - 2,))
302
347
  expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0
303
348
 
@@ -308,57 +353,1315 @@ def test_jax_kernel_launch_dims(test, device):
308
353
  assert_np_equal(result_2d, expected_2d)
309
354
 
310
355
 
311
- class TestJax(unittest.TestCase):
312
- pass
356
+ # =========================================================================================================
357
+ # JAX FFI
358
+ # =========================================================================================================
313
359
 
314
360
 
315
- # try adding Jax tests if Jax is installed correctly
316
- try:
317
- # prevent Jax from gobbling up GPU memory
318
- os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
319
- os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
361
+ @wp.kernel
362
+ def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), output: wp.array(dtype=float)):
363
+ tid = wp.tid()
364
+ output[tid] = a[tid] + b[tid]
320
365
 
321
- import jax
322
366
 
323
- # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
324
- jax.config.update("jax_enable_x64", True)
367
+ @wp.kernel
368
+ def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
369
+ tid = wp.tid()
370
+ out[tid] = alpha * x[tid] + y[tid]
325
371
 
326
- # check which Warp devices work with Jax
327
- # CUDA devices may fail if Jax cannot find a CUDA Toolkit
328
- test_devices = get_test_devices()
329
- jax_compatible_devices = []
330
- jax_compatible_cuda_devices = []
331
- for d in test_devices:
332
- try:
333
- with jax.default_device(wp.device_to_jax(d)):
334
- j = jax.numpy.arange(10, dtype=jax.numpy.float32)
335
- j += 1
336
- jax_compatible_devices.append(d)
337
- if d.is_cuda:
338
- jax_compatible_cuda_devices.append(d)
339
- except Exception as e:
340
- print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
341
372
 
342
- add_function_test(TestJax, "test_dtype_from_jax", test_dtype_from_jax, devices=None)
343
- add_function_test(TestJax, "test_dtype_to_jax", test_dtype_to_jax, devices=None)
373
+ @wp.kernel
374
+ def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
375
+ tid = wp.tid()
376
+ sin_out[tid] = wp.sin(angle[tid])
377
+ cos_out[tid] = wp.cos(angle[tid])
344
378
 
345
- if jax_compatible_devices:
346
- add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
347
379
 
348
- if jax_compatible_cuda_devices:
349
- add_function_test(TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices)
350
- add_function_test(
351
- TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
352
- )
353
- add_function_test(
354
- TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
355
- )
356
- add_function_test(
357
- TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
358
- )
380
+ @wp.kernel
381
+ def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
382
+ tid = wp.tid()
383
+ d = float(tid + 1)
384
+ output[tid] = wp.mat33(d, 0.0, 0.0, 0.0, d * 2.0, 0.0, 0.0, 0.0, d * 3.0)
385
+
386
+
387
+ @wp.kernel
388
+ def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
389
+ tid = wp.tid()
390
+ output[tid] = a[tid] * s
391
+
392
+
393
+ @wp.kernel
394
+ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
395
+ tid = wp.tid()
396
+ output[tid] = a[tid] * s
397
+
398
+
399
+ @wp.kernel
400
+ def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
401
+ tid = wp.tid()
402
+ b[tid] += a[tid]
403
+
404
+
405
+ @wp.kernel
406
+ def matmul_kernel(
407
+ a: wp.array2d(dtype=float), # NxK
408
+ b: wp.array2d(dtype=float), # KxM
409
+ c: wp.array2d(dtype=float), # NxM
410
+ ):
411
+ # launch dims should be (N, M)
412
+ i, j = wp.tid()
413
+ N = a.shape[0]
414
+ K = a.shape[1]
415
+ M = b.shape[1]
416
+ if i < N and j < M:
417
+ s = wp.float32(0)
418
+ for k in range(K):
419
+ s += a[i, k] * b[k, j]
420
+ c[i, j] = s
421
+
422
+
423
+ @wp.kernel
424
+ def in_out_kernel(
425
+ a: wp.array(dtype=float), # input only
426
+ b: wp.array(dtype=float), # input and output
427
+ c: wp.array(dtype=float), # output only
428
+ ):
429
+ tid = wp.tid()
430
+ b[tid] += a[tid]
431
+ c[tid] = 2.0 * a[tid]
432
+
433
+
434
+ @wp.kernel
435
+ def multi_out_kernel(
436
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
437
+ ):
438
+ tid = wp.tid()
439
+ c[tid] = a[tid] + b[tid]
440
+ d[tid] = s * a[tid]
441
+
442
+
443
+ @wp.kernel
444
+ def multi_out_kernel_v2(
445
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
446
+ ):
447
+ tid = wp.tid()
448
+ c[tid] = a[tid] * a[tid]
449
+ d[tid] = a[tid] * b[tid] * s
450
+
451
+
452
+ @wp.kernel
453
+ def multi_out_kernel_v3(
454
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
455
+ ):
456
+ tid = wp.tid()
457
+ c[tid] = a[tid] ** 2.0
458
+ d[tid] = a[tid] * b[tid] * s
459
+
460
+
461
+ @wp.kernel
462
+ def scale_sum_square_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float)):
463
+ tid = wp.tid()
464
+ c[tid] = (a[tid] * s + b[tid]) ** 2.0
465
+
466
+
467
+ # The Python function to call.
468
+ # Note the argument annotations, just like Warp kernels.
469
+ def scale_func(
470
+ # inputs
471
+ a: wp.array(dtype=float),
472
+ b: wp.array(dtype=wp.vec2),
473
+ s: float,
474
+ # outputs
475
+ c: wp.array(dtype=float),
476
+ d: wp.array(dtype=wp.vec2),
477
+ ):
478
+ wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
479
+ wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])
480
+
481
+
482
+ def in_out_func(
483
+ a: wp.array(dtype=float), # input only
484
+ b: wp.array(dtype=float), # input and output
485
+ c: wp.array(dtype=float), # output only
486
+ ):
487
+ wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
488
+ wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
489
+
490
+
491
+ def double_func(
492
+ # inputs
493
+ a: wp.array(dtype=float),
494
+ # outputs
495
+ b: wp.array(dtype=float),
496
+ ):
497
+ wp.launch(scale_kernel, dim=a.shape, inputs=[a, 2.0], outputs=[b])
498
+
499
+
500
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
501
+ def test_ffi_jax_kernel_add(test, device):
502
+ # two inputs and one output
503
+ import jax.numpy as jp
504
+
505
+ from warp.jax_experimental.ffi import jax_kernel
506
+
507
+ jax_add = jax_kernel(add_kernel)
508
+
509
+ @jax.jit
510
+ def f():
511
+ n = ARRAY_SIZE
512
+ a = jp.arange(n, dtype=jp.float32)
513
+ b = jp.ones(n, dtype=jp.float32)
514
+ return jax_add(a, b)
515
+
516
+ with jax.default_device(wp.device_to_jax(device)):
517
+ (y,) = f()
518
+
519
+ wp.synchronize_device(device)
520
+
521
+ result = np.asarray(y)
522
+ expected = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)
523
+
524
+ assert_np_equal(result, expected)
525
+
526
+
527
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
528
+ def test_ffi_jax_kernel_sincos(test, device):
529
+ # one input and two outputs
530
+ import jax.numpy as jp
531
+
532
+ from warp.jax_experimental.ffi import jax_kernel
533
+
534
+ jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)
535
+
536
+ n = ARRAY_SIZE
537
+
538
+ @jax.jit
539
+ def f():
540
+ a = jp.linspace(0, 2 * jp.pi, n, dtype=jp.float32)
541
+ return jax_sincos(a)
542
+
543
+ with jax.default_device(wp.device_to_jax(device)):
544
+ s, c = f()
545
+
546
+ wp.synchronize_device(device)
547
+
548
+ result_s = np.asarray(s)
549
+ result_c = np.asarray(c)
550
+
551
+ a = np.linspace(0, 2 * np.pi, n, dtype=np.float32)
552
+ expected_s = np.sin(a)
553
+ expected_c = np.cos(a)
554
+
555
+ assert_np_equal(result_s, expected_s, tol=1e-4)
556
+ assert_np_equal(result_c, expected_c, tol=1e-4)
557
+
558
+
559
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
560
+ def test_ffi_jax_kernel_diagonal(test, device):
561
+ # no inputs and one output
562
+ from warp.jax_experimental.ffi import jax_kernel
563
+
564
+ jax_diagonal = jax_kernel(diagonal_kernel)
565
+
566
+ @jax.jit
567
+ def f():
568
+ # launch dimensions determine output size
569
+ return jax_diagonal(launch_dims=4)
570
+
571
+ wp.synchronize_device(device)
572
+
573
+ with jax.default_device(wp.device_to_jax(device)):
574
+ (d,) = f()
575
+
576
+ result = np.asarray(d)
577
+ expected = np.array(
578
+ [
579
+ [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
580
+ [[2.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 6.0]],
581
+ [[3.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 9.0]],
582
+ [[4.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 12.0]],
583
+ ],
584
+ dtype=np.float32,
585
+ )
586
+
587
+ assert_np_equal(result, expected)
588
+
589
+
590
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
591
+ def test_ffi_jax_kernel_in_out(test, device):
592
+ # in-out args
593
+ import jax.numpy as jp
594
+
595
+ from warp.jax_experimental.ffi import jax_kernel
596
+
597
+ jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
598
+
599
+ f = jax.jit(jax_func)
600
+
601
+ with jax.default_device(wp.device_to_jax(device)):
602
+ a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
603
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
604
+ b, c = f(a, b)
605
+
606
+ wp.synchronize_device(device)
607
+
608
+ assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
609
+ assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))
610
+
611
+
612
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
613
+ def test_ffi_jax_kernel_scale_vec_constant(test, device):
614
+ # multiply vectors by scalar (constant)
615
+ import jax.numpy as jp
616
+
617
+ from warp.jax_experimental.ffi import jax_kernel
618
+
619
+ jax_scale_vec = jax_kernel(scale_vec_kernel)
620
+
621
+ @jax.jit
622
+ def f():
623
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # array of vec2
624
+ s = 2.0
625
+ return jax_scale_vec(a, s)
626
+
627
+ with jax.default_device(wp.device_to_jax(device)):
628
+ (b,) = f()
629
+
630
+ wp.synchronize_device(device)
631
+
632
+ expected = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
633
+
634
+ assert_np_equal(b, expected)
635
+
636
+
637
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
638
+ def test_ffi_jax_kernel_scale_vec_static(test, device):
639
+ # multiply vectors by scalar (static arg)
640
+ import jax.numpy as jp
641
+
642
+ from warp.jax_experimental.ffi import jax_kernel
643
+
644
+ jax_scale_vec = jax_kernel(scale_vec_kernel)
645
+
646
+ # NOTE: scalar arguments must be static compile-time constants
647
+ @partial(jax.jit, static_argnames=["s"])
648
+ def f(a, s):
649
+ return jax_scale_vec(a, s)
650
+
651
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # array of vec2
652
+ s = 3.0
653
+
654
+ with jax.default_device(wp.device_to_jax(device)):
655
+ (b,) = f(a, s)
656
+
657
+ wp.synchronize_device(device)
658
+
659
+ expected = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
660
+
661
+ assert_np_equal(b, expected)
662
+
663
+
664
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
665
+ def test_ffi_jax_kernel_launch_dims_default(test, device):
666
+ # specify default launch dims
667
+ import jax.numpy as jp
668
+
669
+ from warp.jax_experimental.ffi import jax_kernel
670
+
671
+ N, M, K = 3, 4, 2
672
+
673
+ jax_matmul = jax_kernel(matmul_kernel, launch_dims=(N, M))
674
+
675
+ @jax.jit
676
+ def f():
677
+ a = jp.full((N, K), 2, dtype=jp.float32)
678
+ b = jp.full((K, M), 3, dtype=jp.float32)
679
+
680
+ # use default launch dims
681
+ return jax_matmul(a, b)
682
+
683
+ with jax.default_device(wp.device_to_jax(device)):
684
+ (result,) = f()
685
+
686
+ wp.synchronize_device(device)
687
+
688
+ expected = np.full((3, 4), 12, dtype=np.float32)
689
+
690
+ test.assertEqual(result.shape, expected.shape)
691
+ assert_np_equal(result, expected)
692
+
693
+
694
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
695
+ def test_ffi_jax_kernel_launch_dims_custom(test, device):
696
+ # specify custom launch dims per call
697
+ import jax.numpy as jp
698
+
699
+ from warp.jax_experimental.ffi import jax_kernel
700
+
701
+ jax_matmul = jax_kernel(matmul_kernel)
702
+
703
+ @jax.jit
704
+ def f():
705
+ N1, M1, K1 = 3, 4, 2
706
+ a1 = jp.full((N1, K1), 2, dtype=jp.float32)
707
+ b1 = jp.full((K1, M1), 3, dtype=jp.float32)
708
+
709
+ # use custom launch dims
710
+ result1 = jax_matmul(a1, b1, launch_dims=(N1, M1))
711
+
712
+ N2, M2, K2 = 4, 3, 2
713
+ a2 = jp.full((N2, K2), 2, dtype=jp.float32)
714
+ b2 = jp.full((K2, M2), 3, dtype=jp.float32)
715
+
716
+ # use different custom launch dims
717
+ result2 = jax_matmul(a2, b2, launch_dims=(N2, M2))
718
+
719
+ return result1[0], result2[0]
720
+
721
+ with jax.default_device(wp.device_to_jax(device)):
722
+ result1, result2 = f()
723
+
724
+ wp.synchronize_device(device)
725
+
726
+ expected1 = np.full((3, 4), 12, dtype=np.float32)
727
+ expected2 = np.full((4, 3), 12, dtype=np.float32)
728
+
729
+ test.assertEqual(result1.shape, expected1.shape)
730
+ test.assertEqual(result2.shape, expected2.shape)
731
+ assert_np_equal(result1, expected1)
732
+ assert_np_equal(result2, expected2)
733
+
734
+
735
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
736
+ def test_ffi_jax_callable_scale_constant(test, device):
737
+ # scale two arrays using a constant
738
+ import jax.numpy as jp
739
+
740
+ from warp.jax_experimental.ffi import jax_callable
741
+
742
+ jax_func = jax_callable(scale_func, num_outputs=2)
743
+
744
+ @jax.jit
745
+ def f():
746
+ # inputs
747
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
748
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # wp.vec2
749
+ s = 2.0
750
+
751
+ # output shapes
752
+ output_dims = {"c": a.shape, "d": b.shape}
753
+
754
+ c, d = jax_func(a, b, s, output_dims=output_dims)
755
+
756
+ return c, d
757
+
758
+ with jax.default_device(wp.device_to_jax(device)):
759
+ result1, result2 = f()
760
+
761
+ wp.synchronize_device(device)
762
+
763
+ expected1 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32)
764
+ expected2 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
765
+
766
+ assert_np_equal(result1, expected1)
767
+ assert_np_equal(result2, expected2)
768
+
769
+
770
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
771
+ def test_ffi_jax_callable_scale_static(test, device):
772
+ # scale two arrays using a static arg
773
+ import jax.numpy as jp
359
774
 
775
+ from warp.jax_experimental.ffi import jax_callable
776
+
777
+ jax_func = jax_callable(scale_func, num_outputs=2)
778
+
779
+ # NOTE: scalar arguments must be static compile-time constants
780
+ @partial(jax.jit, static_argnames=["s"])
781
+ def f(a, b, s):
782
+ # output shapes
783
+ output_dims = {"c": a.shape, "d": b.shape}
784
+
785
+ c, d = jax_func(a, b, s, output_dims=output_dims)
786
+
787
+ return c, d
788
+
789
+ with jax.default_device(wp.device_to_jax(device)):
790
+ # inputs
791
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
792
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # wp.vec2
793
+ s = 3.0
794
+ result1, result2 = f(a, b, s)
795
+
796
+ wp.synchronize_device(device)
797
+
798
+ expected1 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32)
799
+ expected2 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
800
+
801
+ assert_np_equal(result1, expected1)
802
+ assert_np_equal(result2, expected2)
803
+
804
+
805
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
806
+ def test_ffi_jax_callable_in_out(test, device):
807
+ # in-out arguments
808
+ import jax.numpy as jp
809
+
810
+ from warp.jax_experimental.ffi import jax_callable
811
+
812
+ jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
813
+
814
+ f = jax.jit(jax_func)
815
+
816
+ with jax.default_device(wp.device_to_jax(device)):
817
+ a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
818
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
819
+ b, c = f(a, b)
820
+
821
+ wp.synchronize_device(device)
822
+
823
+ assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
824
+ assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))
825
+
826
+
827
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
828
+ def test_ffi_jax_callable_graph_cache(test, device):
829
+ # test graph caching limits
830
+ import jax
831
+ import jax.numpy as jp
832
+
833
+ from warp.jax_experimental.ffi import (
834
+ GraphMode,
835
+ clear_jax_callable_graph_cache,
836
+ get_jax_callable_default_graph_cache_max,
837
+ jax_callable,
838
+ set_jax_callable_default_graph_cache_max,
839
+ )
840
+
841
+ # --- test with default cache settings ---
842
+
843
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
844
+ f = jax.jit(jax_double)
845
+ arrays = []
846
+
847
+ test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
848
+
849
+ with jax.default_device(wp.device_to_jax(device)):
850
+ for i in range(10):
851
+ n = 10 + i
852
+ a = jp.arange(n, dtype=jp.float32)
853
+ (b,) = f(a)
854
+
855
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
856
+
857
+ # ensure graph cache is always growing
858
+ test.assertEqual(jax_double.graph_cache_size, i + 1)
859
+
860
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture each time
861
+ arrays.append(a)
862
+
863
+ # --- test clearing one callable's cache ---
864
+
865
+ clear_jax_callable_graph_cache(jax_double)
866
+
867
+ test.assertEqual(jax_double.graph_cache_size, 0)
868
+
869
+ # --- test with a custom cache limit ---
870
+
871
+ graph_cache_max = 5
872
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP, graph_cache_max=graph_cache_max)
873
+ f = jax.jit(jax_double)
874
+ arrays = []
875
+
876
+ test.assertEqual(jax_double.graph_cache_max, graph_cache_max)
877
+
878
+ with jax.default_device(wp.device_to_jax(device)):
879
+ for i in range(10):
880
+ n = 10 + i
881
+ a = jp.arange(n, dtype=jp.float32)
882
+ (b,) = f(a)
883
+
884
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
885
+
886
+ # ensure graph cache size is capped
887
+ test.assertEqual(jax_double.graph_cache_size, min(i + 1, graph_cache_max))
888
+
889
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
890
+ arrays.append(a)
891
+
892
+ # --- test clearing all callables' caches ---
893
+
894
+ clear_jax_callable_graph_cache()
895
+
896
+ with wp.jax_experimental.ffi._FFI_REGISTRY_LOCK:
897
+ for c in wp.jax_experimental.ffi._FFI_CALLABLE_REGISTRY.values():
898
+ test.assertEqual(c.graph_cache_size, 0)
899
+
900
+ # --- test with a custom default cache limit ---
901
+
902
+ saved_max = get_jax_callable_default_graph_cache_max()
903
+ try:
904
+ set_jax_callable_default_graph_cache_max(5)
905
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
906
+ f = jax.jit(jax_double)
907
+ arrays = []
908
+
909
+ test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
910
+
911
+ with jax.default_device(wp.device_to_jax(device)):
912
+ for i in range(10):
913
+ n = 10 + i
914
+ a = jp.arange(n, dtype=jp.float32)
915
+ (b,) = f(a)
916
+
917
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
918
+
919
+ # ensure graph cache size is capped
920
+ test.assertEqual(
921
+ jax_double.graph_cache_size,
922
+ min(i + 1, get_jax_callable_default_graph_cache_max()),
923
+ )
924
+
925
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
926
+ arrays.append(a)
927
+
928
+ clear_jax_callable_graph_cache()
929
+
930
+ finally:
931
+ set_jax_callable_default_graph_cache_max(saved_max)
932
+
933
+
934
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
935
+ def test_ffi_jax_callable_pmap_mul(test, device):
936
+ import jax
937
+ import jax.numpy as jp
938
+
939
+ from warp.jax_experimental.ffi import jax_callable
940
+
941
+ j = jax_callable(double_func, num_outputs=1)
942
+
943
+ ndev = jax.local_device_count()
944
+ per_device = max(ARRAY_SIZE // ndev, 64)
945
+ x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
946
+
947
+ def per_device_func(v):
948
+ (y,) = j(v)
949
+ return y
950
+
951
+ y = jax.pmap(per_device_func)(x)
952
+
953
+ wp.synchronize()
954
+
955
+ assert_np_equal(np.asarray(y), 2 * np.asarray(x))
956
+
957
+
958
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
959
+ def test_ffi_jax_callable_pmap_multi_output(test, device):
960
+ import jax
961
+ import jax.numpy as jp
962
+
963
+ from warp.jax_experimental.ffi import jax_callable
964
+
965
+ def multi_out_py(
966
+ a: wp.array(dtype=float),
967
+ b: wp.array(dtype=float),
968
+ s: float,
969
+ c: wp.array(dtype=float),
970
+ d: wp.array(dtype=float),
971
+ ):
972
+ wp.launch(multi_out_kernel, dim=a.shape, inputs=[a, b, s], outputs=[c, d])
973
+
974
+ j = jax_callable(multi_out_py, num_outputs=2)
975
+
976
+ ndev = jax.local_device_count()
977
+ per_device = max(ARRAY_SIZE // ndev, 64)
978
+ a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
979
+ b = jp.ones((ndev, per_device), dtype=jp.float32)
980
+ s = 3.0
981
+
982
+ def per_device_func(aa, bb):
983
+ c, d = j(aa, bb, s)
984
+ return c + d # simple combine to exercise both outputs
985
+
986
+ out = jax.pmap(per_device_func)(a, b)
987
+
988
+ wp.synchronize()
989
+
990
+ a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
991
+ b_np = np.ones((ndev, per_device), dtype=np.float32)
992
+ ref = (a_np + b_np) + s * a_np
993
+ assert_np_equal(np.asarray(out), ref)
994
+
995
+
996
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
997
+ def test_ffi_jax_callable_pmap_multi_stage(test, device):
998
+ import jax
999
+ import jax.numpy as jp
1000
+
1001
+ from warp.jax_experimental.ffi import jax_callable
1002
+
1003
+ def multi_stage_py(
1004
+ a: wp.array(dtype=float),
1005
+ b: wp.array(dtype=float),
1006
+ alpha: float,
1007
+ tmp: wp.array(dtype=float),
1008
+ out: wp.array(dtype=float),
1009
+ ):
1010
+ wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
1011
+ wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])
1012
+
1013
+ j = jax_callable(multi_stage_py, num_outputs=2)
1014
+
1015
+ ndev = jax.local_device_count()
1016
+ per_device = max(ARRAY_SIZE // ndev, 64)
1017
+ a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
1018
+ b = jp.ones((ndev, per_device), dtype=jp.float32)
1019
+ alpha = 2.5
1020
+
1021
+ def per_device_func(aa, bb):
1022
+ tmp, out = j(aa, bb, alpha)
1023
+ return tmp + out
1024
+
1025
+ combined = jax.pmap(per_device_func)(a, b)
1026
+
1027
+ wp.synchronize()
1028
+
1029
+ a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
1030
+ b_np = np.ones((ndev, per_device), dtype=np.float32)
1031
+ tmp_ref = a_np + b_np
1032
+ out_ref = alpha * (a_np + b_np) + b_np
1033
+ ref = tmp_ref + out_ref
1034
+ assert_np_equal(np.asarray(combined), ref)
1035
+
1036
+
1037
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1038
+ def test_ffi_callback(test, device):
1039
+ # in-out arguments
1040
+ import jax.numpy as jp
1041
+
1042
+ from warp.jax_experimental.ffi import register_ffi_callback
1043
+
1044
+ # the Python function to call
1045
+ def warp_func(inputs, outputs, attrs, ctx):
1046
+ # input arrays
1047
+ a = inputs[0]
1048
+ b = inputs[1]
1049
+
1050
+ # scalar attributes
1051
+ s = attrs["scale"]
1052
+
1053
+ # output arrays
1054
+ c = outputs[0]
1055
+ d = outputs[1]
1056
+
1057
+ device = wp.device_from_jax(get_jax_device())
1058
+ stream = wp.Stream(device, cuda_stream=ctx.stream)
1059
+
1060
+ with wp.ScopedStream(stream):
1061
+ # launch with arrays of scalars
1062
+ wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
1063
+
1064
+ # launch with arrays of vec2
1065
+ # NOTE: the input shapes are from JAX arrays, we need to strip the inner dimension for vec2 arrays
1066
+ wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])
1067
+
1068
+ # register callback
1069
+ register_ffi_callback("warp_func", warp_func)
1070
+
1071
+ n = ARRAY_SIZE
1072
+
1073
+ with jax.default_device(wp.device_to_jax(device)):
1074
+ # inputs
1075
+ a = jp.arange(n, dtype=jp.float32)
1076
+ b = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2)) # array of wp.vec2
1077
+ s = 2.0
1078
+
1079
+ # set up call
1080
+ out_types = [
1081
+ jax.ShapeDtypeStruct(a.shape, jp.float32),
1082
+ jax.ShapeDtypeStruct(b.shape, jp.float32), # array of wp.vec2
1083
+ ]
1084
+ call = jax.ffi.ffi_call("warp_func", out_types)
1085
+
1086
+ # call it
1087
+ c, d = call(a, b, scale=s)
1088
+
1089
+ wp.synchronize_device(device)
1090
+
1091
+ assert_np_equal(c, 2 * np.arange(ARRAY_SIZE, dtype=np.float32))
1092
+ assert_np_equal(d, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2)))
1093
+
1094
+
1095
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1096
+ def test_ffi_jax_kernel_autodiff_simple(test, device):
1097
+ import jax
1098
+ import jax.numpy as jp
1099
+
1100
+ from warp.jax_experimental.ffi import jax_kernel
1101
+
1102
+ jax_func = jax_kernel(
1103
+ scale_sum_square_kernel,
1104
+ num_outputs=1,
1105
+ enable_backward=True,
1106
+ )
1107
+
1108
+ from functools import partial
1109
+
1110
+ @partial(jax.jit, static_argnames=["s"])
1111
+ def loss(a, b, s):
1112
+ out = jax_func(a, b, s)[0]
1113
+ return jp.sum(out)
1114
+
1115
+ n = ARRAY_SIZE
1116
+ a = jp.arange(n, dtype=jp.float32)
1117
+ b = jp.ones(n, dtype=jp.float32)
1118
+ s = 2.0
1119
+
1120
+ with jax.default_device(wp.device_to_jax(device)):
1121
+ da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
1122
+
1123
+ wp.synchronize_device(device)
1124
+
1125
+ # reference gradients
1126
+ # d/da sum((a*s + b)^2) = sum(2*(a*s + b) * s)
1127
+ # d/db sum((a*s + b)^2) = sum(2*(a*s + b))
1128
+ a_np = np.arange(n, dtype=np.float32)
1129
+ b_np = np.ones(n, dtype=np.float32)
1130
+ ref_da = 2.0 * (a_np * s + b_np) * s
1131
+ ref_db = 2.0 * (a_np * s + b_np)
1132
+
1133
+ assert_np_equal(np.asarray(da), ref_da)
1134
+ assert_np_equal(np.asarray(db), ref_db)
1135
+
1136
+
1137
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1138
+ def test_ffi_jax_kernel_autodiff_jit_of_grad_simple(test, device):
1139
+ import jax
1140
+ import jax.numpy as jp
1141
+
1142
+ from warp.jax_experimental.ffi import jax_kernel
1143
+
1144
+ jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
1145
+
1146
+ def loss(a, b, s):
1147
+ out = jax_func(a, b, s)[0]
1148
+ return jp.sum(out)
1149
+
1150
+ grad_fn = jax.grad(loss, argnums=(0, 1))
1151
+
1152
+ # more typical: jit(grad(...)) with static scalar
1153
+ jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
1154
+
1155
+ n = ARRAY_SIZE
1156
+ a = jp.arange(n, dtype=jp.float32)
1157
+ b = jp.ones(n, dtype=jp.float32)
1158
+ s = 2.0
1159
+
1160
+ with jax.default_device(wp.device_to_jax(device)):
1161
+ da, db = jitted_grad(a, b, s)
1162
+
1163
+ wp.synchronize_device(device)
1164
+
1165
+ a_np = np.arange(n, dtype=np.float32)
1166
+ b_np = np.ones(n, dtype=np.float32)
1167
+ ref_da = 2.0 * (a_np * s + b_np) * s
1168
+ ref_db = 2.0 * (a_np * s + b_np)
1169
+
1170
+ assert_np_equal(np.asarray(da), ref_da)
1171
+ assert_np_equal(np.asarray(db), ref_db)
1172
+
1173
+
1174
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1175
+ def test_ffi_jax_kernel_autodiff_multi_output(test, device):
1176
+ import jax
1177
+ import jax.numpy as jp
1178
+
1179
+ from warp.jax_experimental.ffi import jax_kernel
1180
+
1181
+ jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
1182
+
1183
+ def caller(fn, a, b, s):
1184
+ c, d = fn(a, b, s)
1185
+ return jp.sum(c + d)
1186
+
1187
+ @jax.jit
1188
+ def grads(a, b, s):
1189
+ # mark s as static in the inner call via partial to avoid hashing
1190
+ def _inner(a, b, s):
1191
+ return caller(jax_func, a, b, s)
1192
+
1193
+ return jax.grad(lambda a, b: _inner(a, b, 2.0), argnums=(0, 1))(a, b)
1194
+
1195
+ n = ARRAY_SIZE
1196
+ a = jp.arange(n, dtype=jp.float32)
1197
+ b = jp.ones(n, dtype=jp.float32)
1198
+ s = 2.0
1199
+
1200
+ with jax.default_device(wp.device_to_jax(device)):
1201
+ da, db = grads(a, b, s)
1202
+
1203
+ wp.synchronize_device(device)
1204
+
1205
+ a_np = np.arange(n, dtype=np.float32)
1206
+ b_np = np.ones(n, dtype=np.float32)
1207
+ # d/da sum(c+d) = 2*a + b*s
1208
+ ref_da = 2.0 * a_np + b_np * s
1209
+ # d/db sum(c+d) = a*s
1210
+ ref_db = a_np * s
1211
+
1212
+ assert_np_equal(np.asarray(da), ref_da)
1213
+ assert_np_equal(np.asarray(db), ref_db)
1214
+
1215
+
1216
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1217
+ def test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output(test, device):
1218
+ import jax
1219
+ import jax.numpy as jp
1220
+
1221
+ from warp.jax_experimental.ffi import jax_kernel
1222
+
1223
+ jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
1224
+
1225
+ def loss(a, b, s):
1226
+ c, d = jax_func(a, b, s)
1227
+ return jp.sum(c + d)
1228
+
1229
+ grad_fn = jax.grad(loss, argnums=(0, 1))
1230
+ jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
1231
+
1232
+ n = ARRAY_SIZE
1233
+ a = jp.arange(n, dtype=jp.float32)
1234
+ b = jp.ones(n, dtype=jp.float32)
1235
+ s = 2.0
1236
+
1237
+ with jax.default_device(wp.device_to_jax(device)):
1238
+ da, db = jitted_grad(a, b, s)
1239
+
1240
+ wp.synchronize_device(device)
1241
+
1242
+ a_np = np.arange(n, dtype=np.float32)
1243
+ b_np = np.ones(n, dtype=np.float32)
1244
+ ref_da = 2.0 * a_np + b_np * s
1245
+ ref_db = a_np * s
1246
+
1247
+ assert_np_equal(np.asarray(da), ref_da)
1248
+ assert_np_equal(np.asarray(db), ref_db)
1249
+
1250
+
1251
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_2d(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(inc_2d_kernel, num_outputs=1, enable_backward=True)
+
+     @jax.jit
+     def loss(a):
+         out = jax_func(a)[0]
+         return jp.sum(out)
+
+     n, m = 8, 6
+     a = jp.arange(n * m, dtype=jp.float32).reshape((n, m))
+
+     with jax.default_device(wp.device_to_jax(device)):
+         (da,) = jax.grad(loss, argnums=(0,))(a)
+
+     wp.synchronize_device(device)
+
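+     # inc_2d_kernel adds a constant to each element (as its name suggests), so the
+     # gradient of sum(out) with respect to a is a matrix of ones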
+     ref = np.ones((n, m), dtype=np.float32)
+     assert_np_equal(np.asarray(da), ref)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_vec2(test, device):
+     from functools import partial
+
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_func = jax_kernel(scale_vec_kernel, num_outputs=1, enable_backward=True)
+
+     @partial(jax.jit, static_argnames=("s",))
+     def loss(a, s):
+         out = jax_func(a, s)[0]
+         return jp.sum(out)
+
+     n = ARRAY_SIZE
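+     # the trailing dimension of 2 matches the wp.vec2 dtype of scale_vec_kernel's
+     # array argument, so each row of a becomes one vec2 element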
+     a = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2))
+     s = 3.0
+
+     with jax.default_device(wp.device_to_jax(device)):
+         (da,) = jax.grad(loss, argnums=(0,))(a, s)
+
+     wp.synchronize_device(device)
+
+     # d/da sum(a*s) = s
+     ref = np.full_like(np.asarray(a), s)
+     assert_np_equal(np.asarray(da), ref)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_mat22(test, device):
+     from functools import partial
+
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     @wp.kernel
+     def scale_mat_kernel(a: wp.array(dtype=wp.mat22), s: float, out: wp.array(dtype=wp.mat22)):
+         tid = wp.tid()
+         out[tid] = a[tid] * s
+
+     jax_func = jax_kernel(scale_mat_kernel, num_outputs=1, enable_backward=True)
+
+     @partial(jax.jit, static_argnames=("s",))
+     def loss(a, s):
+         out = jax_func(a, s)[0]
+         return jp.sum(out)
+
+     n = 12  # total number of scalar elements; must be divisible by 4 for 2x2 matrices
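+     # likewise, the trailing (2, 2) dimensions map each slice of a to one wp.mat22 element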
+     a = jp.arange(n, dtype=jp.float32).reshape((n // 4, 2, 2))
+     s = 2.5
+
+     with jax.default_device(wp.device_to_jax(device)):
+         (da,) = jax.grad(loss, argnums=(0,))(a, s)
+
+     wp.synchronize_device(device)
+
+     ref = np.full((n // 4, 2, 2), s, dtype=np.float32)
+     assert_np_equal(np.asarray(da), ref)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_static_required(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     # the scalar s must be consumed as a concrete value; grad is used without jit
+     # here, so s stays a plain Python float
+     jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
+
+     def loss(a, b, s):
+         out = jax_func(a, b, s)[0]
+         return jp.sum(out)
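+
+     # under jit, the scalar would have to be marked static, as in the earlier tests:
+     # e.g. jax.jit(loss, static_argnames=("s",))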
+
+     n = ARRAY_SIZE
+     a = jp.arange(n, dtype=jp.float32)
+     b = jp.ones(n, dtype=jp.float32)
+     s = 1.5
+
+     with jax.default_device(wp.device_to_jax(device)):
+         da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
+
+     wp.synchronize_device(device)
+
+     a_np = np.arange(n, dtype=np.float32)
+     b_np = np.ones(n, dtype=np.float32)
+     ref_da = 2.0 * (a_np * s + b_np) * s
+     ref_db = 2.0 * (a_np * s + b_np)
+
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_pmap_triple(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_mul = jax_kernel(triple_kernel, num_outputs=1, enable_backward=True)
+
+     ndev = jax.local_device_count()
+     per_device = ARRAY_SIZE // ndev
+     x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+
+     def per_device_loss(x):
+         y = jax_mul(x)[0]
+         return jp.sum(y)
+
+     grads = jax.pmap(jax.grad(per_device_loss))(x)
+
+     wp.synchronize()
+
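+     # triple_kernel computes y = 3*x, so the per-device gradient of sum(y) is 3 everywhere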
+     assert_np_equal(np.asarray(grads), np.full((ndev, per_device), 3.0, dtype=np.float32))
+
+
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
+ def test_ffi_jax_kernel_autodiff_pmap_multi_output(test, device):
+     import jax
+     import jax.numpy as jp
+
+     from warp.jax_experimental.ffi import jax_kernel
+
+     jax_mo = jax_kernel(multi_out_kernel_v2, num_outputs=2, enable_backward=True)
+
+     ndev = jax.local_device_count()
+     per_device = ARRAY_SIZE // ndev
+     a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+     b = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
+     s = 2.0
+
+     def per_dev_loss(aa, bb):
+         c, d = jax_mo(aa, bb, s)
+         return jp.sum(c + d)
+
+     da, db = jax.pmap(jax.grad(per_dev_loss, argnums=(0, 1)))(a, b)
+
+     wp.synchronize()
+
+     a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
+     b_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
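+     # per device: d/da sum(c + d) = 2*a + b*s, d/db sum(c + d) = a*s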
+     ref_da = 2.0 * a_np + b_np * s
+     ref_db = a_np * s
+     assert_np_equal(np.asarray(da), ref_da)
+     assert_np_equal(np.asarray(db), ref_db)
+
+
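+ # TestJax starts out empty; add_function_test() below attaches each test function
+ # to the class as a unittest method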
+ class TestJax(unittest.TestCase):
+     pass
+
+
+ # try adding Jax tests if Jax is installed correctly
+ try:
+     # prevent Jax from gobbling up GPU memory
+     os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+     os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+     import jax
+
+     # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
+     jax.config.update("jax_enable_x64", True)
+
+     # check which Warp devices work with Jax
+     # CUDA devices may fail if Jax cannot find a CUDA Toolkit
+     test_devices = get_test_devices()
+     jax_compatible_devices = []
+     jax_compatible_cuda_devices = []
+     for d in test_devices:
+         try:
+             with jax.default_device(wp.device_to_jax(d)):
+                 j = jax.numpy.arange(10, dtype=jax.numpy.float32)
+                 j += 1
+             jax_compatible_devices.append(d)
+             if d.is_cuda:
+                 jax_compatible_cuda_devices.append(d)
+         except Exception as e:
+             print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
+
+
+     add_function_test(TestJax, "test_dtype_from_jax", test_dtype_from_jax, devices=None)
+     add_function_test(TestJax, "test_dtype_to_jax", test_dtype_to_jax, devices=None)
+
+     if jax_compatible_devices:
+         add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
+
+     if jax_compatible_cuda_devices:
+         # tests for both custom_call and ffi variants of jax_kernel(), selected by installed JAX version
+         if jax.__version_info__ < (0, 4, 25):
+             # no interop supported
+             ffi_opts = []
+         elif jax.__version_info__ < (0, 5, 0):
+             # only custom_call supported
+             ffi_opts = [False]
+         elif jax.__version_info__ < (0, 8, 0):
+             # both custom_call and ffi supported
+             ffi_opts = [False, True]
+         else:
+             # only ffi supported
+             ffi_opts = [True]
+
+         for use_ffi in ffi_opts:
+             suffix = "ffi" if use_ffi else "cc"
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_basic_{suffix}",
+                 test_jax_kernel_basic,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_scalar_{suffix}",
+                 test_jax_kernel_scalar,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_vecmat_{suffix}",
+                 test_jax_kernel_vecmat,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_multiarg_{suffix}",
+                 test_jax_kernel_multiarg,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+             add_function_test(
+                 TestJax,
+                 f"test_jax_kernel_launch_dims_{suffix}",
+                 test_jax_kernel_launch_dims,
+                 devices=jax_compatible_cuda_devices,
+                 use_ffi=use_ffi,
+             )
+
+         # ffi.jax_kernel() tests
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_add", test_ffi_jax_kernel_add, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_sincos", test_ffi_jax_kernel_sincos, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_diagonal", test_ffi_jax_kernel_diagonal, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_kernel_in_out", test_ffi_jax_kernel_in_out, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_scale_vec_constant",
+             test_ffi_jax_kernel_scale_vec_constant,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_scale_vec_static",
+             test_ffi_jax_kernel_scale_vec_static,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_launch_dims_default",
+             test_ffi_jax_kernel_launch_dims_default,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_launch_dims_custom",
+             test_ffi_jax_kernel_launch_dims_custom,
+             devices=jax_compatible_cuda_devices,
+         )
+
+         # ffi.jax_callable() tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_scale_constant",
+             test_ffi_jax_callable_scale_constant,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_scale_static",
+             test_ffi_jax_callable_scale_static,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax, "test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, devices=jax_compatible_cuda_devices
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_graph_cache",
+             test_ffi_jax_callable_graph_cache,
+             devices=jax_compatible_cuda_devices,
+         )
+
+         # pmap tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_pmap_multi_output",
+             test_ffi_jax_callable_pmap_multi_output,
+             devices=None,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_pmap_mul",
+             test_ffi_jax_callable_pmap_mul,
+             devices=None,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_callable_pmap_multi_stage",
+             test_ffi_jax_callable_pmap_multi_stage,
+             devices=None,
+         )
+
+         # ffi callback tests
+         add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
+
+         # autodiff tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_simple",
+             test_ffi_jax_kernel_autodiff_simple,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_jit_of_grad_simple",
+             test_ffi_jax_kernel_autodiff_jit_of_grad_simple,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_multi_output",
+             test_ffi_jax_kernel_autodiff_multi_output,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output",
+             test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_2d",
+             test_ffi_jax_kernel_autodiff_2d,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_vec2",
+             test_ffi_jax_kernel_autodiff_vec2,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_mat22",
+             test_ffi_jax_kernel_autodiff_mat22,
+             devices=jax_compatible_cuda_devices,
+         )
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_static_required",
+             test_ffi_jax_kernel_autodiff_static_required,
+             devices=jax_compatible_cuda_devices,
+         )
+
+         # autodiff with pmap tests
+         add_function_test(
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_pmap_triple",
+             test_ffi_jax_kernel_autodiff_pmap_triple,
+             devices=None,
+         )
          add_function_test(
-             TestJax, "test_jax_kernel_launch_dims", test_jax_kernel_launch_dims, devices=jax_compatible_cuda_devices
+             TestJax,
+             "test_ffi_jax_kernel_autodiff_pmap_multi_output",
+             test_ffi_jax_kernel_autodiff_pmap_multi_output,
+             devices=None,
          )
 
  except Exception as e: