warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
 
3
18
  import numpy as np
@@ -446,8 +461,8 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
446
461
  node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
447
462
 
448
463
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
449
- cx = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
450
- cy = wp.select(node_j == 0, coords[1], 1.0 - coords[1])
464
+ cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
465
+ cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
451
466
 
452
467
  w = cx * cy
453
468
 
@@ -460,7 +475,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
460
475
 
461
476
  w = float(1.0)
462
477
  if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
463
- w *= wp.select(node_i == 0, coords[0], 1.0 - coords[0])
478
+ w *= wp.where(node_i == 0, 1.0 - coords[0], coords[0])
464
479
  else:
465
480
  for k in range(ORDER_PLUS_ONE):
466
481
  if k != node_i:
@@ -469,7 +484,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
469
484
  w *= LAGRANGE_SCALE[node_i]
470
485
 
471
486
  if node_type == SquareSerendipityShapeFunctions.EDGE_X:
472
- w *= wp.select(node_j == 0, coords[1], 1.0 - coords[1])
487
+ w *= wp.where(node_j == 0, 1.0 - coords[1], coords[1])
473
488
  else:
474
489
  for k in range(ORDER_PLUS_ONE):
475
490
  if k != node_j:
@@ -498,11 +513,11 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
498
513
  node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
499
514
 
500
515
  if node_type == SquareSerendipityShapeFunctions.VERTEX:
501
- cx = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
502
- cy = wp.select(node_j == 0, coords[1], 1.0 - coords[1])
516
+ cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
517
+ cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
503
518
 
504
- gx = wp.select(node_i == 0, 1.0, -1.0)
505
- gy = wp.select(node_j == 0, 1.0, -1.0)
519
+ gx = wp.where(node_i == 0, -1.0, 1.0)
520
+ gy = wp.where(node_j == 0, -1.0, 1.0)
506
521
 
507
522
  if ORDER == 2:
508
523
  w = cx + cy - 2.0 + LOBATTO_COORDS[1]
@@ -522,7 +537,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
522
537
  return wp.vec2(grad_x, grad_y) * DEGREE_3_CIRCLE_SCALE
523
538
 
524
539
  if node_type == SquareSerendipityShapeFunctions.EDGE_X:
525
- prefix_x = wp.select(node_j == 0, coords[1], 1.0 - coords[1])
540
+ prefix_x = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
526
541
  else:
527
542
  prefix_x = LAGRANGE_SCALE[node_j]
528
543
  for k in range(ORDER_PLUS_ONE):
@@ -530,7 +545,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
530
545
  prefix_x *= coords[1] - LOBATTO_COORDS[k]
531
546
 
532
547
  if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
533
- prefix_y = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
548
+ prefix_y = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
534
549
  else:
535
550
  prefix_y = LAGRANGE_SCALE[node_i]
536
551
  for k in range(ORDER_PLUS_ONE):
@@ -538,7 +553,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
538
553
  prefix_y *= coords[0] - LOBATTO_COORDS[k]
539
554
 
540
555
  if node_type == SquareSerendipityShapeFunctions.EDGE_X:
541
- grad_y = wp.select(node_j == 0, 1.0, -1.0) * prefix_y
556
+ grad_y = wp.where(node_j == 0, -1.0, 1.0) * prefix_y
542
557
  else:
543
558
  prefix_y *= LAGRANGE_SCALE[node_j]
544
559
  grad_y = float(0.0)
@@ -549,7 +564,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
549
564
  prefix_y *= delta_y
550
565
 
551
566
  if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
552
- grad_x = wp.select(node_i == 0, 1.0, -1.0) * prefix_x
567
+ grad_x = wp.where(node_i == 0, -1.0, 1.0) * prefix_x
553
568
  else:
554
569
  prefix_x *= LAGRANGE_SCALE[node_i]
555
570
  grad_x = float(0.0)
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import numpy as np
2
17
 
3
18
  import warp as wp
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import numpy as np
2
17
 
3
18
  import warp as wp
@@ -181,7 +196,7 @@ class TrianglePolynomialShapeFunctions(TriangleShapeFunction):
181
196
  def trace_node_quadrature_weight(node_index_in_element: int):
182
197
  node_type, type_index = self.node_type_and_type_index(node_index_in_element)
183
198
 
184
- return wp.select(node_type == TrianglePolynomialShapeFunctions.VERTEX, EDGE_WEIGHT, VERTEX_WEIGHT)
199
+ return wp.where(node_type == TrianglePolynomialShapeFunctions.VERTEX, VERTEX_WEIGHT, EDGE_WEIGHT)
185
200
 
186
201
  return trace_node_quadrature_weight
187
202
 
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import warp as wp
2
17
  from warp.fem import cache
3
18
  from warp.fem.geometry import Tetmesh
@@ -229,10 +244,10 @@ class TetmeshSpaceTopology(SpaceTopology):
229
244
  edge = type_index // INTERIOR_NODES_PER_EDGE
230
245
  c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)
231
246
 
232
- return wp.select(
247
+ return wp.where(
233
248
  geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2],
234
- 1.0,
235
249
  -1.0,
250
+ 1.0,
236
251
  )
237
252
 
238
253
  if wp.static(INTERIOR_NODES_PER_FACE > 0):
@@ -242,7 +257,7 @@ class TetmeshSpaceTopology(SpaceTopology):
242
257
  global_face_index = topo_arg.tet_face_indices[element_index][face]
243
258
  inner = topo_arg.face_tet_indices[global_face_index][0]
244
259
 
245
- return wp.select(inner == element_index, -1.0, 1.0)
260
+ return wp.where(inner == element_index, 1.0, -1.0)
246
261
 
247
262
  return 1.0
248
263
 
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Optional, Tuple, Type
2
17
 
3
18
  import warp as wp
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import warp as wp
2
17
  from warp.fem import cache
3
18
  from warp.fem.geometry import Trimesh
@@ -160,11 +175,11 @@ class TrimeshSpaceTopology(SpaceTopology):
160
175
  edge = type_index // INTERIOR_NODES_PER_SIDE
161
176
 
162
177
  global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
163
- return wp.select(
178
+ return wp.where(
164
179
  topo_arg.edge_vertex_indices[global_edge_index][0]
165
180
  == geo_arg.topology.tri_vertex_indices[element_index][edge],
166
- -1.0,
167
181
  1.0,
182
+ -1.0,
168
183
  )
169
184
 
170
185
  return 1.0
warp/fem/types.py CHANGED
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from enum import Enum
2
17
 
3
18
  import warp as wp
warp/fem/utils.py CHANGED
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Tuple, Union
2
17
 
3
18
  import numpy as np
@@ -41,13 +56,11 @@ def compress_node_indices(
41
56
  sorted_node_indices = sorted_node_indices_temp.array
42
57
  sorted_array_indices = sorted_array_indices_temp.array
43
58
 
44
- wp.copy(dest=sorted_node_indices, src=node_indices, count=index_count)
45
-
46
59
  indices_per_element = 1 if node_indices.ndim == 1 else node_indices.shape[-1]
47
60
  wp.launch(
48
- kernel=_iota_kernel,
61
+ kernel=_prepare_node_sort_kernel,
49
62
  dim=index_count,
50
- inputs=[sorted_array_indices, indices_per_element],
63
+ inputs=[node_indices.flatten(), sorted_node_indices, sorted_array_indices, indices_per_element],
51
64
  )
52
65
 
53
66
  # Sort indices
@@ -154,8 +167,16 @@ def masked_indices(
154
167
 
155
168
 
156
169
  @wp.kernel
157
- def _iota_kernel(indices: wp.array(dtype=int), divisor: int):
158
- indices[wp.tid()] = wp.tid() // divisor
170
+ def _prepare_node_sort_kernel(
171
+ node_indices: wp.array(dtype=int),
172
+ sort_keys: wp.array(dtype=int),
173
+ sort_values: wp.array(dtype=int),
174
+ divisor: int,
175
+ ):
176
+ i = wp.tid()
177
+ node = node_indices[i]
178
+ sort_keys[i] = wp.where(node >= 0, node, NULL_NODE_INDEX)
179
+ sort_values[i] = i // divisor
159
180
 
160
181
 
161
182
  @wp.kernel
warp/jax.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import warp
9
17
 
@@ -50,6 +58,19 @@ def device_from_jax(jax_device) -> warp.context.Device:
50
58
  raise RuntimeError(f"Unsupported Jax device platform '{jax_device.platform}'")
51
59
 
52
60
 
61
+ def get_jax_device():
62
+ """Get the current Jax device."""
63
+ import jax
64
+
65
+ # TODO: is there a simpler way of getting the Jax "current" device?
66
+ # check if jax.default_device() context manager is active
67
+ device = jax.config.jax_default_device
68
+ # if default device is not set, use first device
69
+ if device is None:
70
+ device = jax.local_devices()[0]
71
+ return device
72
+
73
+
53
74
  def dtype_to_jax(warp_dtype):
54
75
  """Return the Jax dtype corresponding to a Warp dtype.
55
76
 
@@ -148,7 +169,7 @@ def to_jax(warp_array):
148
169
  """
149
170
  import jax.dlpack
150
171
 
151
- return jax.dlpack.from_dlpack(warp.to_dlpack(warp_array))
172
+ return jax.dlpack.from_dlpack(warp_array)
152
173
 
153
174
 
154
175
  def from_jax(jax_array, dtype=None) -> warp.array:
@@ -0,0 +1,16 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .custom_call import jax_kernel
@@ -1,16 +1,23 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import ctypes
9
17
 
10
- import jax
11
-
12
18
  import warp as wp
13
19
  from warp.context import type_str
20
+ from warp.jax import get_jax_device
14
21
  from warp.types import array_t, launch_bounds_t, strides_from_shape
15
22
 
16
23
  _jax_warp_p = None
@@ -21,35 +28,33 @@ _registered_kernels = [None]
21
28
  _registered_kernel_to_id = {}
22
29
 
23
30
 
24
- def jax_kernel(wp_kernel, launch_dims=None):
31
+ def jax_kernel(kernel, launch_dims=None):
25
32
  """Create a Jax primitive from a Warp kernel.
26
33
 
27
34
  NOTE: This is an experimental feature under development.
28
35
 
29
36
  Args:
30
- wp_kernel: The Warp kernel to be wrapped.
37
+ kernel: The Warp kernel to be wrapped.
31
38
  launch_dims: Optional. Specify the kernel launch dimensions. If None,
32
39
  dimensions are inferred from the shape of the first argument.
33
40
  This option when set will specify the output dimensions.
34
41
 
35
- Current limitations:
36
- - All kernel arguments must be arrays.
37
- - If launch_dims is not provided, kernel launch dimensions are inferred from the shape of the first argument.
38
- - Input arguments are followed by output arguments in the Warp kernel definition.
39
- - There must be at least one input argument and at least one output argument.
40
- - All arrays must be contiguous.
41
- - Only the CUDA backend is supported.
42
+ Limitations:
43
+ - All kernel arguments must be contiguous arrays.
44
+ - Input arguments are followed by output arguments in the Warp kernel definition.
45
+ - There must be at least one input argument and at least one output argument.
46
+ - Only the CUDA backend is supported.
42
47
  """
43
48
 
44
49
  if _jax_warp_p is None:
45
50
  # Create and register the primitive
46
51
  _create_jax_warp_primitive()
47
- if wp_kernel not in _registered_kernel_to_id:
52
+ if kernel not in _registered_kernel_to_id:
48
53
  id = len(_registered_kernels)
49
- _registered_kernels.append(wp_kernel)
50
- _registered_kernel_to_id[wp_kernel] = id
54
+ _registered_kernels.append(kernel)
55
+ _registered_kernel_to_id[kernel] = id
51
56
  else:
52
- id = _registered_kernel_to_id[wp_kernel]
57
+ id = _registered_kernel_to_id[kernel]
53
58
 
54
59
  def bind(*args):
55
60
  return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims)
@@ -94,7 +99,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
94
99
  kernel_params[i + 1] = arg_ptr
95
100
 
96
101
  # Get current device.
97
- device = wp.device_from_jax(_get_jax_device())
102
+ device = wp.device_from_jax(get_jax_device())
98
103
 
99
104
  # Get kernel hooks.
100
105
  # Note: module was loaded during jit lowering.
@@ -107,16 +112,6 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
107
112
  )
108
113
 
109
114
 
110
- # TODO: is there a simpler way of getting the Jax "current" device?
111
- def _get_jax_device():
112
- # check if jax.default_device() context manager is active
113
- device = jax.config.jax_default_device
114
- # if default device is not set, use first device
115
- if device is None:
116
- device = jax.local_devices()[0]
117
- return device
118
-
119
-
120
115
  def _create_jax_warp_primitive():
121
116
  from functools import reduce
122
117
 
@@ -280,7 +275,7 @@ def _create_jax_warp_primitive():
280
275
  # TODO This may not be necessary, but it is perhaps better not to be
281
276
  # mucking with kernel loading while already running the workload.
282
277
  module = wp_kernel.module
283
- device = wp.device_from_jax(_get_jax_device())
278
+ device = wp.device_from_jax(get_jax_device())
284
279
  if not module.load(device):
285
280
  raise Exception("Could not load kernel on device")
286
281