warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -21,9 +21,12 @@ from typing import Any
21
21
  import numpy as np
22
22
 
23
23
  import warp as wp
24
- from warp.jax import get_jax_device
24
+ from warp._src.jax import get_jax_device
25
25
  from warp.tests.unittest_utils import *
26
26
 
27
+ # default array size for tests
28
+ ARRAY_SIZE = 1024 * 1024
29
+
27
30
 
28
31
  # basic kernel with one input and output
29
32
  @wp.kernel
@@ -46,6 +49,18 @@ def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)
46
49
  output[tid] = input.dtype.dtype(3) * input[tid]
47
50
 
48
51
 
52
+ @wp.kernel
53
+ def inc_1d_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
54
+ tid = wp.tid()
55
+ y[tid] = x[tid] + 1.0
56
+
57
+
58
+ @wp.kernel
59
+ def inc_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
60
+ i, j = wp.tid()
61
+ y[i, j] = x[i, j] + 1.0
62
+
63
+
49
64
  # kernel with multiple inputs and outputs
50
65
  @wp.kernel
51
66
  def multiarg_kernel(
@@ -63,7 +78,7 @@ def multiarg_kernel(
63
78
 
64
79
 
65
80
  # various types for testing
66
- scalar_types = wp.types.scalar_types
81
+ scalar_types = wp._src.types.scalar_types
67
82
  vector_types = []
68
83
  matrix_types = []
69
84
  for dim in [2, 3, 4]:
@@ -146,7 +161,7 @@ def test_jax_kernel_basic(test, device, use_ffi=False):
146
161
 
147
162
  jax_triple = jax_kernel(triple_kernel, quiet=True) # suppress deprecation warnings
148
163
 
149
- n = 64
164
+ n = ARRAY_SIZE
150
165
 
151
166
  @jax.jit
152
167
  def f():
@@ -157,6 +172,8 @@ def test_jax_kernel_basic(test, device, use_ffi=False):
157
172
  with jax.default_device(wp.device_to_jax(device)):
158
173
  y = f()
159
174
 
175
+ wp.synchronize_device(device)
176
+
160
177
  result = np.asarray(y).reshape((n,))
161
178
  expected = 3 * np.arange(n, dtype=np.float32)
162
179
 
@@ -175,6 +192,7 @@ def test_jax_kernel_scalar(test, device, use_ffi=False):
175
192
 
176
193
  kwargs = {"quiet": True}
177
194
 
195
+ # use a smallish size to ensure arange * 3 doesn't overflow
178
196
  n = 64
179
197
 
180
198
  for T in scalar_types:
@@ -196,6 +214,8 @@ def test_jax_kernel_scalar(test, device, use_ffi=False):
196
214
  with jax.default_device(wp.device_to_jax(device)):
197
215
  y = f()
198
216
 
217
+ wp.synchronize_device(device)
218
+
199
219
  result = np.asarray(y).reshape((n,))
200
220
  expected = 3 * np.arange(n, dtype=np_dtype)
201
221
 
@@ -218,6 +238,7 @@ def test_jax_kernel_vecmat(test, device, use_ffi=False):
218
238
  jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
219
239
  np_dtype = wp.dtype_to_numpy(T._wp_scalar_type_)
220
240
 
241
+ # use a smallish size to ensure arange * 3 doesn't overflow
221
242
  n = 64 // T._length_
222
243
  scalar_shape = (n, *T._shape_)
223
244
  scalar_len = n * T._length_
@@ -237,6 +258,8 @@ def test_jax_kernel_vecmat(test, device, use_ffi=False):
237
258
  with jax.default_device(wp.device_to_jax(device)):
238
259
  y = f()
239
260
 
261
+ wp.synchronize_device(device)
262
+
240
263
  result = np.asarray(y).reshape(scalar_shape)
241
264
  expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
242
265
 
@@ -255,7 +278,7 @@ def test_jax_kernel_multiarg(test, device, use_ffi=False):
255
278
 
256
279
  jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)
257
280
 
258
- n = 64
281
+ n = ARRAY_SIZE
259
282
 
260
283
  @jax.jit
261
284
  def f():
@@ -268,6 +291,8 @@ def test_jax_kernel_multiarg(test, device, use_ffi=False):
268
291
  with jax.default_device(wp.device_to_jax(device)):
269
292
  x, y = f()
270
293
 
294
+ wp.synchronize_device(device)
295
+
271
296
  result_x, result_y = np.asarray(x), np.asarray(y)
272
297
  expected_x = np.full(n, 3, dtype=np.float32)
273
298
  expected_y = np.full(n, 5, dtype=np.float32)
@@ -292,40 +317,32 @@ def test_jax_kernel_launch_dims(test, device, use_ffi=False):
292
317
  m = 32
293
318
 
294
319
  # Test with 1D launch dims
295
- @wp.kernel
296
- def add_one_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
297
- tid = wp.tid()
298
- y[tid] = x[tid] + 1.0
299
-
300
- jax_add_one = jax_kernel(
301
- add_one_kernel, launch_dims=(n - 2,), **kwargs
320
+ jax_inc_1d = jax_kernel(
321
+ inc_1d_kernel, launch_dims=(n - 2,), **kwargs
302
322
  ) # Intentionally not the same as the first dimension of the input
303
323
 
304
324
  @jax.jit
305
325
  def f_1d():
306
326
  x = jp.arange(n, dtype=jp.float32)
307
- return jax_add_one(x)
327
+ return jax_inc_1d(x)
308
328
 
309
329
  # Test with 2D launch dims
310
- @wp.kernel
311
- def add_one_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
312
- i, j = wp.tid()
313
- y[i, j] = x[i, j] + 1.0
314
-
315
- jax_add_one_2d = jax_kernel(
316
- add_one_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
330
+ jax_inc_2d = jax_kernel(
331
+ inc_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
317
332
  ) # Intentionally not the same as the first dimension of the input
318
333
 
319
334
  @jax.jit
320
335
  def f_2d():
321
336
  x = jp.zeros((n, m), dtype=jp.float32) + 3.0
322
- return jax_add_one_2d(x)
337
+ return jax_inc_2d(x)
323
338
 
324
339
  # run on the given device
325
340
  with jax.default_device(wp.device_to_jax(device)):
326
341
  y_1d = f_1d()
327
342
  y_2d = f_2d()
328
343
 
344
+ wp.synchronize_device(device)
345
+
329
346
  result_1d = np.asarray(y_1d).reshape((n - 2,))
330
347
  expected_1d = np.arange(n - 2, dtype=np.float32) + 1.0
331
348
 
@@ -342,11 +359,17 @@ def test_jax_kernel_launch_dims(test, device, use_ffi=False):
342
359
 
343
360
 
344
361
  @wp.kernel
345
- def add_kernel(a: wp.array(dtype=int), b: wp.array(dtype=int), output: wp.array(dtype=int)):
362
+ def add_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), output: wp.array(dtype=float)):
346
363
  tid = wp.tid()
347
364
  output[tid] = a[tid] + b[tid]
348
365
 
349
366
 
367
+ @wp.kernel
368
+ def axpy_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float), alpha: float, out: wp.array(dtype=float)):
369
+ tid = wp.tid()
370
+ out[tid] = alpha * x[tid] + y[tid]
371
+
372
+
350
373
  @wp.kernel
351
374
  def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
352
375
  tid = wp.tid()
@@ -408,6 +431,39 @@ def in_out_kernel(
408
431
  c[tid] = 2.0 * a[tid]
409
432
 
410
433
 
434
+ @wp.kernel
435
+ def multi_out_kernel(
436
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
437
+ ):
438
+ tid = wp.tid()
439
+ c[tid] = a[tid] + b[tid]
440
+ d[tid] = s * a[tid]
441
+
442
+
443
+ @wp.kernel
444
+ def multi_out_kernel_v2(
445
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
446
+ ):
447
+ tid = wp.tid()
448
+ c[tid] = a[tid] * a[tid]
449
+ d[tid] = a[tid] * b[tid] * s
450
+
451
+
452
+ @wp.kernel
453
+ def multi_out_kernel_v3(
454
+ a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float), d: wp.array(dtype=float)
455
+ ):
456
+ tid = wp.tid()
457
+ c[tid] = a[tid] ** 2.0
458
+ d[tid] = a[tid] * b[tid] * s
459
+
460
+
461
+ @wp.kernel
462
+ def scale_sum_square_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), s: float, c: wp.array(dtype=float)):
463
+ tid = wp.tid()
464
+ c[tid] = (a[tid] * s + b[tid]) ** 2.0
465
+
466
+
411
467
  # The Python function to call.
412
468
  # Note the argument annotations, just like Warp kernels.
413
469
  def scale_func(
@@ -432,6 +488,15 @@ def in_out_func(
432
488
  wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
433
489
 
434
490
 
491
+ def double_func(
492
+ # inputs
493
+ a: wp.array(dtype=float),
494
+ # outputs
495
+ b: wp.array(dtype=float),
496
+ ):
497
+ wp.launch(scale_kernel, dim=a.shape, inputs=[a, 2.0], outputs=[b])
498
+
499
+
435
500
  @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
436
501
  def test_ffi_jax_kernel_add(test, device):
437
502
  # two inputs and one output
@@ -443,16 +508,18 @@ def test_ffi_jax_kernel_add(test, device):
443
508
 
444
509
  @jax.jit
445
510
  def f():
446
- n = 10
447
- a = jp.arange(n, dtype=jp.int32)
448
- b = jp.ones(n, dtype=jp.int32)
511
+ n = ARRAY_SIZE
512
+ a = jp.arange(n, dtype=jp.float32)
513
+ b = jp.ones(n, dtype=jp.float32)
449
514
  return jax_add(a, b)
450
515
 
451
516
  with jax.default_device(wp.device_to_jax(device)):
452
517
  (y,) = f()
453
518
 
519
+ wp.synchronize_device(device)
520
+
454
521
  result = np.asarray(y)
455
- expected = np.arange(1, 11, dtype=np.int32)
522
+ expected = np.arange(1, ARRAY_SIZE + 1, dtype=np.float32)
456
523
 
457
524
  assert_np_equal(result, expected)
458
525
 
@@ -465,7 +532,8 @@ def test_ffi_jax_kernel_sincos(test, device):
465
532
  from warp.jax_experimental.ffi import jax_kernel
466
533
 
467
534
  jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)
468
- n = 32
535
+
536
+ n = ARRAY_SIZE
469
537
 
470
538
  @jax.jit
471
539
  def f():
@@ -475,6 +543,8 @@ def test_ffi_jax_kernel_sincos(test, device):
475
543
  with jax.default_device(wp.device_to_jax(device)):
476
544
  s, c = f()
477
545
 
546
+ wp.synchronize_device(device)
547
+
478
548
  result_s = np.asarray(s)
479
549
  result_c = np.asarray(c)
480
550
 
@@ -498,6 +568,8 @@ def test_ffi_jax_kernel_diagonal(test, device):
498
568
  # launch dimensions determine output size
499
569
  return jax_diagonal(launch_dims=4)
500
570
 
571
+ wp.synchronize_device(device)
572
+
501
573
  with jax.default_device(wp.device_to_jax(device)):
502
574
  (d,) = f()
503
575
 
@@ -527,12 +599,14 @@ def test_ffi_jax_kernel_in_out(test, device):
527
599
  f = jax.jit(jax_func)
528
600
 
529
601
  with jax.default_device(wp.device_to_jax(device)):
530
- a = jp.ones(10, dtype=jp.float32)
531
- b = jp.arange(10, dtype=jp.float32)
602
+ a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
603
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
532
604
  b, c = f(a, b)
533
605
 
534
- assert_np_equal(b, np.arange(1, 11, dtype=np.float32))
535
- assert_np_equal(c, np.full(10, 2, dtype=np.float32))
606
+ wp.synchronize_device(device)
607
+
608
+ assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
609
+ assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))
536
610
 
537
611
 
538
612
  @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
@@ -546,14 +620,16 @@ def test_ffi_jax_kernel_scale_vec_constant(test, device):
546
620
 
547
621
  @jax.jit
548
622
  def f():
549
- a = jp.arange(10, dtype=jp.float32).reshape((5, 2)) # array of vec2
623
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # array of vec2
550
624
  s = 2.0
551
625
  return jax_scale_vec(a, s)
552
626
 
553
627
  with jax.default_device(wp.device_to_jax(device)):
554
628
  (b,) = f()
555
629
 
556
- expected = 2 * np.arange(10, dtype=np.float32).reshape((5, 2))
630
+ wp.synchronize_device(device)
631
+
632
+ expected = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
557
633
 
558
634
  assert_np_equal(b, expected)
559
635
 
@@ -572,13 +648,15 @@ def test_ffi_jax_kernel_scale_vec_static(test, device):
572
648
  def f(a, s):
573
649
  return jax_scale_vec(a, s)
574
650
 
575
- a = jp.arange(10, dtype=jp.float32).reshape((5, 2)) # array of vec2
651
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # array of vec2
576
652
  s = 3.0
577
653
 
578
654
  with jax.default_device(wp.device_to_jax(device)):
579
655
  (b,) = f(a, s)
580
656
 
581
- expected = 3 * np.arange(10, dtype=np.float32).reshape((5, 2))
657
+ wp.synchronize_device(device)
658
+
659
+ expected = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
582
660
 
583
661
  assert_np_equal(b, expected)
584
662
 
@@ -605,6 +683,8 @@ def test_ffi_jax_kernel_launch_dims_default(test, device):
605
683
  with jax.default_device(wp.device_to_jax(device)):
606
684
  (result,) = f()
607
685
 
686
+ wp.synchronize_device(device)
687
+
608
688
  expected = np.full((3, 4), 12, dtype=np.float32)
609
689
 
610
690
  test.assertEqual(result.shape, expected.shape)
@@ -641,6 +721,8 @@ def test_ffi_jax_kernel_launch_dims_custom(test, device):
641
721
  with jax.default_device(wp.device_to_jax(device)):
642
722
  result1, result2 = f()
643
723
 
724
+ wp.synchronize_device(device)
725
+
644
726
  expected1 = np.full((3, 4), 12, dtype=np.float32)
645
727
  expected2 = np.full((4, 3), 12, dtype=np.float32)
646
728
 
@@ -662,8 +744,8 @@ def test_ffi_jax_callable_scale_constant(test, device):
662
744
  @jax.jit
663
745
  def f():
664
746
  # inputs
665
- a = jp.arange(10, dtype=jp.float32)
666
- b = jp.arange(10, dtype=jp.float32).reshape((5, 2)) # wp.vec2
747
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
748
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # wp.vec2
667
749
  s = 2.0
668
750
 
669
751
  # output shapes
@@ -676,8 +758,10 @@ def test_ffi_jax_callable_scale_constant(test, device):
676
758
  with jax.default_device(wp.device_to_jax(device)):
677
759
  result1, result2 = f()
678
760
 
679
- expected1 = 2 * np.arange(10, dtype=np.float32)
680
- expected2 = 2 * np.arange(10, dtype=np.float32).reshape((5, 2))
761
+ wp.synchronize_device(device)
762
+
763
+ expected1 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32)
764
+ expected2 = 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
681
765
 
682
766
  assert_np_equal(result1, expected1)
683
767
  assert_np_equal(result2, expected2)
@@ -704,13 +788,15 @@ def test_ffi_jax_callable_scale_static(test, device):
704
788
 
705
789
  with jax.default_device(wp.device_to_jax(device)):
706
790
  # inputs
707
- a = jp.arange(10, dtype=jp.float32)
708
- b = jp.arange(10, dtype=jp.float32).reshape((5, 2)) # wp.vec2
791
+ a = jp.arange(ARRAY_SIZE, dtype=jp.float32)
792
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32).reshape((ARRAY_SIZE // 2, 2)) # wp.vec2
709
793
  s = 3.0
710
794
  result1, result2 = f(a, b, s)
711
795
 
712
- expected1 = 3 * np.arange(10, dtype=np.float32)
713
- expected2 = 3 * np.arange(10, dtype=np.float32).reshape((5, 2))
796
+ wp.synchronize_device(device)
797
+
798
+ expected1 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32)
799
+ expected2 = 3 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2))
714
800
 
715
801
  assert_np_equal(result1, expected1)
716
802
  assert_np_equal(result2, expected2)
@@ -728,12 +814,224 @@ def test_ffi_jax_callable_in_out(test, device):
728
814
  f = jax.jit(jax_func)
729
815
 
730
816
  with jax.default_device(wp.device_to_jax(device)):
731
- a = jp.ones(10, dtype=jp.float32)
732
- b = jp.arange(10, dtype=jp.float32)
817
+ a = jp.ones(ARRAY_SIZE, dtype=jp.float32)
818
+ b = jp.arange(ARRAY_SIZE, dtype=jp.float32)
733
819
  b, c = f(a, b)
734
820
 
735
- assert_np_equal(b, np.arange(1, 11, dtype=np.float32))
736
- assert_np_equal(c, np.full(10, 2, dtype=np.float32))
821
+ wp.synchronize_device(device)
822
+
823
+ assert_np_equal(b, np.arange(1, ARRAY_SIZE + 1, dtype=np.float32))
824
+ assert_np_equal(c, np.full(ARRAY_SIZE, 2, dtype=np.float32))
825
+
826
+
827
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
828
+ def test_ffi_jax_callable_graph_cache(test, device):
829
+ # test graph caching limits
830
+ import jax
831
+ import jax.numpy as jp
832
+
833
+ from warp.jax_experimental.ffi import (
834
+ GraphMode,
835
+ clear_jax_callable_graph_cache,
836
+ get_jax_callable_default_graph_cache_max,
837
+ jax_callable,
838
+ set_jax_callable_default_graph_cache_max,
839
+ )
840
+
841
+ # --- test with default cache settings ---
842
+
843
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
844
+ f = jax.jit(jax_double)
845
+ arrays = []
846
+
847
+ test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
848
+
849
+ with jax.default_device(wp.device_to_jax(device)):
850
+ for i in range(10):
851
+ n = 10 + i
852
+ a = jp.arange(n, dtype=jp.float32)
853
+ (b,) = f(a)
854
+
855
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
856
+
857
+ # ensure graph cache is always growing
858
+ test.assertEqual(jax_double.graph_cache_size, i + 1)
859
+
860
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture each time
861
+ arrays.append(a)
862
+
863
+ # --- test clearing one callable's cache ---
864
+
865
+ clear_jax_callable_graph_cache(jax_double)
866
+
867
+ test.assertEqual(jax_double.graph_cache_size, 0)
868
+
869
+ # --- test with a custom cache limit ---
870
+
871
+ graph_cache_max = 5
872
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP, graph_cache_max=graph_cache_max)
873
+ f = jax.jit(jax_double)
874
+ arrays = []
875
+
876
+ test.assertEqual(jax_double.graph_cache_max, graph_cache_max)
877
+
878
+ with jax.default_device(wp.device_to_jax(device)):
879
+ for i in range(10):
880
+ n = 10 + i
881
+ a = jp.arange(n, dtype=jp.float32)
882
+ (b,) = f(a)
883
+
884
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
885
+
886
+ # ensure graph cache size is capped
887
+ test.assertEqual(jax_double.graph_cache_size, min(i + 1, graph_cache_max))
888
+
889
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
890
+ arrays.append(a)
891
+
892
+ # --- test clearing all callables' caches ---
893
+
894
+ clear_jax_callable_graph_cache()
895
+
896
+ with wp.jax_experimental.ffi._FFI_REGISTRY_LOCK:
897
+ for c in wp.jax_experimental.ffi._FFI_CALLABLE_REGISTRY.values():
898
+ test.assertEqual(c.graph_cache_size, 0)
899
+
900
+ # --- test with a custom default cache limit ---
901
+
902
+ saved_max = get_jax_callable_default_graph_cache_max()
903
+ try:
904
+ set_jax_callable_default_graph_cache_max(5)
905
+ jax_double = jax_callable(double_func, graph_mode=GraphMode.WARP)
906
+ f = jax.jit(jax_double)
907
+ arrays = []
908
+
909
+ test.assertEqual(jax_double.graph_cache_max, get_jax_callable_default_graph_cache_max())
910
+
911
+ with jax.default_device(wp.device_to_jax(device)):
912
+ for i in range(10):
913
+ n = 10 + i
914
+ a = jp.arange(n, dtype=jp.float32)
915
+ (b,) = f(a)
916
+
917
+ assert_np_equal(b, 2 * np.arange(n, dtype=np.float32))
918
+
919
+ # ensure graph cache size is capped
920
+ test.assertEqual(
921
+ jax_double.graph_cache_size,
922
+ min(i + 1, get_jax_callable_default_graph_cache_max()),
923
+ )
924
+
925
+ # keep JAX array alive to prevent the memory from being reused, thus forcing a new graph capture
926
+ arrays.append(a)
927
+
928
+ clear_jax_callable_graph_cache()
929
+
930
+ finally:
931
+ set_jax_callable_default_graph_cache_max(saved_max)
932
+
933
+
934
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
935
+ def test_ffi_jax_callable_pmap_mul(test, device):
936
+ import jax
937
+ import jax.numpy as jp
938
+
939
+ from warp.jax_experimental.ffi import jax_callable
940
+
941
+ j = jax_callable(double_func, num_outputs=1)
942
+
943
+ ndev = jax.local_device_count()
944
+ per_device = max(ARRAY_SIZE // ndev, 64)
945
+ x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
946
+
947
+ def per_device_func(v):
948
+ (y,) = j(v)
949
+ return y
950
+
951
+ y = jax.pmap(per_device_func)(x)
952
+
953
+ wp.synchronize()
954
+
955
+ assert_np_equal(np.asarray(y), 2 * np.asarray(x))
956
+
957
+
958
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
959
+ def test_ffi_jax_callable_pmap_multi_output(test, device):
960
+ import jax
961
+ import jax.numpy as jp
962
+
963
+ from warp.jax_experimental.ffi import jax_callable
964
+
965
+ def multi_out_py(
966
+ a: wp.array(dtype=float),
967
+ b: wp.array(dtype=float),
968
+ s: float,
969
+ c: wp.array(dtype=float),
970
+ d: wp.array(dtype=float),
971
+ ):
972
+ wp.launch(multi_out_kernel, dim=a.shape, inputs=[a, b, s], outputs=[c, d])
973
+
974
+ j = jax_callable(multi_out_py, num_outputs=2)
975
+
976
+ ndev = jax.local_device_count()
977
+ per_device = max(ARRAY_SIZE // ndev, 64)
978
+ a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
979
+ b = jp.ones((ndev, per_device), dtype=jp.float32)
980
+ s = 3.0
981
+
982
+ def per_device_func(aa, bb):
983
+ c, d = j(aa, bb, s)
984
+ return c + d # simple combine to exercise both outputs
985
+
986
+ out = jax.pmap(per_device_func)(a, b)
987
+
988
+ wp.synchronize()
989
+
990
+ a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
991
+ b_np = np.ones((ndev, per_device), dtype=np.float32)
992
+ ref = (a_np + b_np) + s * a_np
993
+ assert_np_equal(np.asarray(out), ref)
994
+
995
+
996
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
997
+ def test_ffi_jax_callable_pmap_multi_stage(test, device):
998
+ import jax
999
+ import jax.numpy as jp
1000
+
1001
+ from warp.jax_experimental.ffi import jax_callable
1002
+
1003
+ def multi_stage_py(
1004
+ a: wp.array(dtype=float),
1005
+ b: wp.array(dtype=float),
1006
+ alpha: float,
1007
+ tmp: wp.array(dtype=float),
1008
+ out: wp.array(dtype=float),
1009
+ ):
1010
+ wp.launch(add_kernel, dim=a.shape, inputs=[a, b], outputs=[tmp])
1011
+ wp.launch(axpy_kernel, dim=a.shape, inputs=[tmp, b, alpha], outputs=[out])
1012
+
1013
+ j = jax_callable(multi_stage_py, num_outputs=2)
1014
+
1015
+ ndev = jax.local_device_count()
1016
+ per_device = max(ARRAY_SIZE // ndev, 64)
1017
+ a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
1018
+ b = jp.ones((ndev, per_device), dtype=jp.float32)
1019
+ alpha = 2.5
1020
+
1021
+ def per_device_func(aa, bb):
1022
+ tmp, out = j(aa, bb, alpha)
1023
+ return tmp + out
1024
+
1025
+ combined = jax.pmap(per_device_func)(a, b)
1026
+
1027
+ wp.synchronize()
1028
+
1029
+ a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
1030
+ b_np = np.ones((ndev, per_device), dtype=np.float32)
1031
+ tmp_ref = a_np + b_np
1032
+ out_ref = alpha * (a_np + b_np) + b_np
1033
+ ref = tmp_ref + out_ref
1034
+ assert_np_equal(np.asarray(combined), ref)
737
1035
 
738
1036
 
739
1037
  @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
@@ -770,7 +1068,7 @@ def test_ffi_callback(test, device):
770
1068
  # register callback
771
1069
  register_ffi_callback("warp_func", warp_func)
772
1070
 
773
- n = 10
1071
+ n = ARRAY_SIZE
774
1072
 
775
1073
  with jax.default_device(wp.device_to_jax(device)):
776
1074
  # inputs
@@ -788,8 +1086,344 @@ def test_ffi_callback(test, device):
788
1086
  # call it
789
1087
  c, d = call(a, b, scale=s)
790
1088
 
791
- assert_np_equal(c, 2 * np.arange(10, dtype=np.float32))
792
- assert_np_equal(d, 2 * np.arange(10, dtype=np.float32).reshape((5, 2)))
1089
+ wp.synchronize_device(device)
1090
+
1091
+ assert_np_equal(c, 2 * np.arange(ARRAY_SIZE, dtype=np.float32))
1092
+ assert_np_equal(d, 2 * np.arange(ARRAY_SIZE, dtype=np.float32).reshape((ARRAY_SIZE // 2, 2)))
1093
+
1094
+
1095
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1096
+ def test_ffi_jax_kernel_autodiff_simple(test, device):
1097
+ import jax
1098
+ import jax.numpy as jp
1099
+
1100
+ from warp.jax_experimental.ffi import jax_kernel
1101
+
1102
+ jax_func = jax_kernel(
1103
+ scale_sum_square_kernel,
1104
+ num_outputs=1,
1105
+ enable_backward=True,
1106
+ )
1107
+
1108
+ from functools import partial
1109
+
1110
+ @partial(jax.jit, static_argnames=["s"])
1111
+ def loss(a, b, s):
1112
+ out = jax_func(a, b, s)[0]
1113
+ return jp.sum(out)
1114
+
1115
+ n = ARRAY_SIZE
1116
+ a = jp.arange(n, dtype=jp.float32)
1117
+ b = jp.ones(n, dtype=jp.float32)
1118
+ s = 2.0
1119
+
1120
+ with jax.default_device(wp.device_to_jax(device)):
1121
+ da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
1122
+
1123
+ wp.synchronize_device(device)
1124
+
1125
+ # reference gradients
1126
+ # d/da sum((a*s + b)^2) = sum(2*(a*s + b) * s)
1127
+ # d/db sum((a*s + b)^2) = sum(2*(a*s + b))
1128
+ a_np = np.arange(n, dtype=np.float32)
1129
+ b_np = np.ones(n, dtype=np.float32)
1130
+ ref_da = 2.0 * (a_np * s + b_np) * s
1131
+ ref_db = 2.0 * (a_np * s + b_np)
1132
+
1133
+ assert_np_equal(np.asarray(da), ref_da)
1134
+ assert_np_equal(np.asarray(db), ref_db)
1135
+
1136
+
1137
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1138
+ def test_ffi_jax_kernel_autodiff_jit_of_grad_simple(test, device):
1139
+ import jax
1140
+ import jax.numpy as jp
1141
+
1142
+ from warp.jax_experimental.ffi import jax_kernel
1143
+
1144
+ jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
1145
+
1146
+ def loss(a, b, s):
1147
+ out = jax_func(a, b, s)[0]
1148
+ return jp.sum(out)
1149
+
1150
+ grad_fn = jax.grad(loss, argnums=(0, 1))
1151
+
1152
+ # more typical: jit(grad(...)) with static scalar
1153
+ jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
1154
+
1155
+ n = ARRAY_SIZE
1156
+ a = jp.arange(n, dtype=jp.float32)
1157
+ b = jp.ones(n, dtype=jp.float32)
1158
+ s = 2.0
1159
+
1160
+ with jax.default_device(wp.device_to_jax(device)):
1161
+ da, db = jitted_grad(a, b, s)
1162
+
1163
+ wp.synchronize_device(device)
1164
+
1165
+ a_np = np.arange(n, dtype=np.float32)
1166
+ b_np = np.ones(n, dtype=np.float32)
1167
+ ref_da = 2.0 * (a_np * s + b_np) * s
1168
+ ref_db = 2.0 * (a_np * s + b_np)
1169
+
1170
+ assert_np_equal(np.asarray(da), ref_da)
1171
+ assert_np_equal(np.asarray(db), ref_db)
1172
+
1173
+
1174
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1175
+ def test_ffi_jax_kernel_autodiff_multi_output(test, device):
1176
+ import jax
1177
+ import jax.numpy as jp
1178
+
1179
+ from warp.jax_experimental.ffi import jax_kernel
1180
+
1181
+ jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
1182
+
1183
+ def caller(fn, a, b, s):
1184
+ c, d = fn(a, b, s)
1185
+ return jp.sum(c + d)
1186
+
1187
+ @jax.jit
1188
+ def grads(a, b, s):
1189
+ # mark s as static in the inner call via partial to avoid hashing
1190
+ def _inner(a, b, s):
1191
+ return caller(jax_func, a, b, s)
1192
+
1193
+ return jax.grad(lambda a, b: _inner(a, b, 2.0), argnums=(0, 1))(a, b)
1194
+
1195
+ n = ARRAY_SIZE
1196
+ a = jp.arange(n, dtype=jp.float32)
1197
+ b = jp.ones(n, dtype=jp.float32)
1198
+ s = 2.0
1199
+
1200
+ with jax.default_device(wp.device_to_jax(device)):
1201
+ da, db = grads(a, b, s)
1202
+
1203
+ wp.synchronize_device(device)
1204
+
1205
+ a_np = np.arange(n, dtype=np.float32)
1206
+ b_np = np.ones(n, dtype=np.float32)
1207
+ # d/da sum(c+d) = 2*a + b*s
1208
+ ref_da = 2.0 * a_np + b_np * s
1209
+ # d/db sum(c+d) = a*s
1210
+ ref_db = a_np * s
1211
+
1212
+ assert_np_equal(np.asarray(da), ref_da)
1213
+ assert_np_equal(np.asarray(db), ref_db)
1214
+
1215
+
1216
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1217
+ def test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output(test, device):
1218
+ import jax
1219
+ import jax.numpy as jp
1220
+
1221
+ from warp.jax_experimental.ffi import jax_kernel
1222
+
1223
+ jax_func = jax_kernel(multi_out_kernel_v3, num_outputs=2, enable_backward=True)
1224
+
1225
+ def loss(a, b, s):
1226
+ c, d = jax_func(a, b, s)
1227
+ return jp.sum(c + d)
1228
+
1229
+ grad_fn = jax.grad(loss, argnums=(0, 1))
1230
+ jitted_grad = jax.jit(lambda a, b, s: grad_fn(a, b, s), static_argnames=("s",))
1231
+
1232
+ n = ARRAY_SIZE
1233
+ a = jp.arange(n, dtype=jp.float32)
1234
+ b = jp.ones(n, dtype=jp.float32)
1235
+ s = 2.0
1236
+
1237
+ with jax.default_device(wp.device_to_jax(device)):
1238
+ da, db = jitted_grad(a, b, s)
1239
+
1240
+ wp.synchronize_device(device)
1241
+
1242
+ a_np = np.arange(n, dtype=np.float32)
1243
+ b_np = np.ones(n, dtype=np.float32)
1244
+ ref_da = 2.0 * a_np + b_np * s
1245
+ ref_db = a_np * s
1246
+
1247
+ assert_np_equal(np.asarray(da), ref_da)
1248
+ assert_np_equal(np.asarray(db), ref_db)
1249
+
1250
+
1251
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1252
+ def test_ffi_jax_kernel_autodiff_2d(test, device):
1253
+ import jax
1254
+ import jax.numpy as jp
1255
+
1256
+ from warp.jax_experimental.ffi import jax_kernel
1257
+
1258
+ jax_func = jax_kernel(inc_2d_kernel, num_outputs=1, enable_backward=True)
1259
+
1260
+ @jax.jit
1261
+ def loss(a):
1262
+ out = jax_func(a)[0]
1263
+ return jp.sum(out)
1264
+
1265
+ n, m = 8, 6
1266
+ a = jp.arange(n * m, dtype=jp.float32).reshape((n, m))
1267
+
1268
+ with jax.default_device(wp.device_to_jax(device)):
1269
+ (da,) = jax.grad(loss, argnums=(0,))(a)
1270
+
1271
+ wp.synchronize_device(device)
1272
+
1273
+ ref = np.ones((n, m), dtype=np.float32)
1274
+ assert_np_equal(np.asarray(da), ref)
1275
+
1276
+
1277
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1278
+ def test_ffi_jax_kernel_autodiff_vec2(test, device):
1279
+ import jax
1280
+ import jax.numpy as jp
1281
+
1282
+ from warp.jax_experimental.ffi import jax_kernel
1283
+
1284
+ jax_func = jax_kernel(scale_vec_kernel, num_outputs=1, enable_backward=True)
1285
+
1286
+ from functools import partial
1287
+
1288
+ @partial(jax.jit, static_argnames=("s",))
1289
+ def loss(a, s):
1290
+ out = jax_func(a, s)[0]
1291
+ return jp.sum(out)
1292
+
1293
+ n = ARRAY_SIZE
1294
+ a = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2))
1295
+ s = 3.0
1296
+
1297
+ with jax.default_device(wp.device_to_jax(device)):
1298
+ (da,) = jax.grad(loss, argnums=(0,))(a, s)
1299
+
1300
+ wp.synchronize_device(device)
1301
+
1302
+ # d/da sum(a*s) = s
1303
+ ref = np.full_like(np.asarray(a), s)
1304
+ assert_np_equal(np.asarray(da), ref)
1305
+
1306
+
1307
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1308
+ def test_ffi_jax_kernel_autodiff_mat22(test, device):
1309
+ import jax
1310
+ import jax.numpy as jp
1311
+
1312
+ from warp.jax_experimental.ffi import jax_kernel
1313
+
1314
+ @wp.kernel
1315
+ def scale_mat_kernel(a: wp.array(dtype=wp.mat22), s: float, out: wp.array(dtype=wp.mat22)):
1316
+ tid = wp.tid()
1317
+ out[tid] = a[tid] * s
1318
+
1319
+ jax_func = jax_kernel(scale_mat_kernel, num_outputs=1, enable_backward=True)
1320
+
1321
+ from functools import partial
1322
+
1323
+ @partial(jax.jit, static_argnames=("s",))
1324
+ def loss(a, s):
1325
+ out = jax_func(a, s)[0]
1326
+ return jp.sum(out)
1327
+
1328
+ n = 12 # must be divisible by 4 for 2x2 matrices
1329
+ a = jp.arange(n, dtype=jp.float32).reshape((n // 4, 2, 2))
1330
+ s = 2.5
1331
+
1332
+ with jax.default_device(wp.device_to_jax(device)):
1333
+ (da,) = jax.grad(loss, argnums=(0,))(a, s)
1334
+
1335
+ wp.synchronize_device(device)
1336
+
1337
+ ref = np.full((n // 4, 2, 2), s, dtype=np.float32)
1338
+ assert_np_equal(np.asarray(da), ref)
1339
+
1340
+
1341
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1342
+ def test_ffi_jax_kernel_autodiff_static_required(test, device):
1343
+ import jax
1344
+ import jax.numpy as jp
1345
+
1346
+ from warp.jax_experimental.ffi import jax_kernel
1347
+
1348
+ # Require explicit static_argnames for scalar s
1349
+ jax_func = jax_kernel(scale_sum_square_kernel, num_outputs=1, enable_backward=True)
1350
+
1351
+ def loss(a, b, s):
1352
+ out = jax_func(a, b, s)[0]
1353
+ return jp.sum(out)
1354
+
1355
+ n = ARRAY_SIZE
1356
+ a = jp.arange(n, dtype=jp.float32)
1357
+ b = jp.ones(n, dtype=jp.float32)
1358
+ s = 1.5
1359
+
1360
+ with jax.default_device(wp.device_to_jax(device)):
1361
+ da, db = jax.grad(loss, argnums=(0, 1))(a, b, s)
1362
+
1363
+ wp.synchronize_device(device)
1364
+
1365
+ a_np = np.arange(n, dtype=np.float32)
1366
+ b_np = np.ones(n, dtype=np.float32)
1367
+ ref_da = 2.0 * (a_np * s + b_np) * s
1368
+ ref_db = 2.0 * (a_np * s + b_np)
1369
+
1370
+ assert_np_equal(np.asarray(da), ref_da)
1371
+ assert_np_equal(np.asarray(db), ref_db)
1372
+
1373
+
1374
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1375
+ def test_ffi_jax_kernel_autodiff_pmap_triple(test, device):
1376
+ import jax
1377
+ import jax.numpy as jp
1378
+
1379
+ from warp.jax_experimental.ffi import jax_kernel
1380
+
1381
+ jax_mul = jax_kernel(triple_kernel, num_outputs=1, enable_backward=True)
1382
+
1383
+ ndev = jax.local_device_count()
1384
+ per_device = ARRAY_SIZE // ndev
1385
+ x = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
1386
+
1387
+ def per_device_loss(x):
1388
+ y = jax_mul(x)[0]
1389
+ return jp.sum(y)
1390
+
1391
+ grads = jax.pmap(jax.grad(per_device_loss))(x)
1392
+
1393
+ wp.synchronize()
1394
+
1395
+ assert_np_equal(np.asarray(grads), np.full((ndev, per_device), 3.0, dtype=np.float32))
1396
+
1397
+
1398
+ @unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
1399
+ def test_ffi_jax_kernel_autodiff_pmap_multi_output(test, device):
1400
+ import jax
1401
+ import jax.numpy as jp
1402
+
1403
+ from warp.jax_experimental.ffi import jax_kernel
1404
+
1405
+ jax_mo = jax_kernel(multi_out_kernel_v2, num_outputs=2, enable_backward=True)
1406
+
1407
+ ndev = jax.local_device_count()
1408
+ per_device = ARRAY_SIZE // ndev
1409
+ a = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
1410
+ b = jp.arange(ndev * per_device, dtype=jp.float32).reshape((ndev, per_device))
1411
+ s = 2.0
1412
+
1413
+ def per_dev_loss(aa, bb):
1414
+ c, d = jax_mo(aa, bb, s)
1415
+ return jp.sum(c + d)
1416
+
1417
+ da, db = jax.pmap(jax.grad(per_dev_loss, argnums=(0, 1)))(a, b)
1418
+
1419
+ wp.synchronize()
1420
+
1421
+ a_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
1422
+ b_np = np.arange(ndev * per_device, dtype=np.float32).reshape((ndev, per_device))
1423
+ ref_da = 2.0 * a_np + b_np * s
1424
+ ref_db = a_np * s
1425
+ assert_np_equal(np.asarray(da), ref_da)
1426
+ assert_np_equal(np.asarray(db), ref_db)
793
1427
 
794
1428
 
795
1429
  class TestJax(unittest.TestCase):
@@ -936,10 +1570,99 @@ try:
936
1570
  add_function_test(
937
1571
  TestJax, "test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, devices=jax_compatible_cuda_devices
938
1572
  )
1573
+ add_function_test(
1574
+ TestJax,
1575
+ "test_ffi_jax_callable_graph_cache",
1576
+ test_ffi_jax_callable_graph_cache,
1577
+ devices=jax_compatible_cuda_devices,
1578
+ )
1579
+
1580
+ # pmap tests
1581
+ add_function_test(
1582
+ TestJax,
1583
+ "test_ffi_jax_callable_pmap_multi_output",
1584
+ test_ffi_jax_callable_pmap_multi_output,
1585
+ devices=None,
1586
+ )
1587
+ add_function_test(
1588
+ TestJax,
1589
+ "test_ffi_jax_callable_pmap_mul",
1590
+ test_ffi_jax_callable_pmap_mul,
1591
+ devices=None,
1592
+ )
1593
+ add_function_test(
1594
+ TestJax,
1595
+ "test_ffi_jax_callable_pmap_multi_stage",
1596
+ test_ffi_jax_callable_pmap_multi_stage,
1597
+ devices=None,
1598
+ )
939
1599
 
940
1600
  # ffi callback tests
941
1601
  add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
942
1602
 
1603
+ # autodiff tests
1604
+ add_function_test(
1605
+ TestJax,
1606
+ "test_ffi_jax_kernel_autodiff_simple",
1607
+ test_ffi_jax_kernel_autodiff_simple,
1608
+ devices=jax_compatible_cuda_devices,
1609
+ )
1610
+ add_function_test(
1611
+ TestJax,
1612
+ "test_ffi_jax_kernel_autodiff_jit_of_grad_simple",
1613
+ test_ffi_jax_kernel_autodiff_jit_of_grad_simple,
1614
+ devices=jax_compatible_cuda_devices,
1615
+ )
1616
+ add_function_test(
1617
+ TestJax,
1618
+ "test_ffi_jax_kernel_autodiff_multi_output",
1619
+ test_ffi_jax_kernel_autodiff_multi_output,
1620
+ devices=jax_compatible_cuda_devices,
1621
+ )
1622
+ add_function_test(
1623
+ TestJax,
1624
+ "test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output",
1625
+ test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output,
1626
+ devices=jax_compatible_cuda_devices,
1627
+ )
1628
+ add_function_test(
1629
+ TestJax,
1630
+ "test_ffi_jax_kernel_autodiff_2d",
1631
+ test_ffi_jax_kernel_autodiff_2d,
1632
+ devices=jax_compatible_cuda_devices,
1633
+ )
1634
+ add_function_test(
1635
+ TestJax,
1636
+ "test_ffi_jax_kernel_autodiff_vec2",
1637
+ test_ffi_jax_kernel_autodiff_vec2,
1638
+ devices=jax_compatible_cuda_devices,
1639
+ )
1640
+ add_function_test(
1641
+ TestJax,
1642
+ "test_ffi_jax_kernel_autodiff_mat22",
1643
+ test_ffi_jax_kernel_autodiff_mat22,
1644
+ devices=jax_compatible_cuda_devices,
1645
+ )
1646
+ add_function_test(
1647
+ TestJax,
1648
+ "test_ffi_jax_kernel_autodiff_static_required",
1649
+ test_ffi_jax_kernel_autodiff_static_required,
1650
+ devices=jax_compatible_cuda_devices,
1651
+ )
1652
+
1653
+ # autodiff with pmap tests
1654
+ add_function_test(
1655
+ TestJax,
1656
+ "test_ffi_jax_kernel_autodiff_pmap_triple",
1657
+ test_ffi_jax_kernel_autodiff_pmap_triple,
1658
+ devices=None,
1659
+ )
1660
+ add_function_test(
1661
+ TestJax,
1662
+ "test_ffi_jax_kernel_autodiff_pmap_multi_output",
1663
+ test_ffi_jax_kernel_autodiff_pmap_multi_output,
1664
+ devices=None,
1665
+ )
943
1666
 
944
1667
  except Exception as e:
945
1668
  print(f"Skipping Jax tests due to exception: {e}")