warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry page for more details.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/tests/test_bool.py CHANGED
@@ -83,6 +83,54 @@ def test_bool_constant(test, device):
83
83
  test.assertTrue(compile_constant_value.numpy()[0])
84
84
 
85
85
 
86
def test_bool_constant_vec(test, device):
    """Check that a ``wp.constant`` bool vector can drive branches inside a kernel.

    The constant ``[True, False, True]`` enables the branches adding 1 and 4, so
    every element of the output array is expected to equal 5.
    """
    bool_vec3_type = wp.vec(length=3, dtype=wp.bool)
    bool_selector_vec = wp.constant(bool_vec3_type([True, False, True]))

    # Kernel body intentionally kept verbatim: it is compiled by Warp's codegen.
    @wp.kernel
    def sum_from_bool_vec(sum_array: wp.array(dtype=wp.int32)):
        i = wp.tid()

        if bool_selector_vec[0]:
            sum_array[i] = sum_array[i] + 1
        if bool_selector_vec[1]:
            sum_array[i] = sum_array[i] + 2
        if bool_selector_vec[2]:
            sum_array[i] = sum_array[i] + 4

    sums = wp.zeros(10, dtype=wp.int32, device=device)
    expected = np.full(sums.shape, 5)

    wp.launch(kernel=sum_from_bool_vec, dim=sums.shape, inputs=[sums], device=device)

    assert_np_equal(sums.numpy(), expected)
107
+
108
+
109
def test_bool_constant_mat(test, device):
    """Check that a ``wp.constant`` 2x2 bool matrix can drive branches inside a kernel.

    The constant is built from ``[True, False, False, True]``, so only the
    ``[0, 0]`` and ``[1, 1]`` branches fire (adding 1 and 8). Every element of
    the output array is expected to equal 9.
    """
    bool_mat22_type = wp.mat((2, 2), dtype=wp.bool)
    bool_selector_mat = wp.constant(bool_mat22_type([True, False, False, True]))

    # Kernel body intentionally kept verbatim: it is compiled by Warp's codegen.
    @wp.kernel
    def sum_from_bool_mat(sum_array: wp.array(dtype=wp.int32)):
        i = wp.tid()

        if bool_selector_mat[0, 0]:
            sum_array[i] = sum_array[i] + 1
        if bool_selector_mat[0, 1]:
            sum_array[i] = sum_array[i] + 2
        if bool_selector_mat[1, 0]:
            sum_array[i] = sum_array[i] + 4
        if bool_selector_mat[1, 1]:
            sum_array[i] = sum_array[i] + 8

    sums = wp.zeros(10, dtype=wp.int32, device=device)
    expected = np.full(sums.shape, 9)

    wp.launch(kernel=sum_from_bool_mat, dim=sums.shape, inputs=[sums], device=device)

    assert_np_equal(sums.numpy(), expected)
132
+
133
+
86
134
  devices = get_test_devices()
87
135
 
88
136
 
@@ -92,6 +140,8 @@ class TestBool(unittest.TestCase):
92
140
 
93
141
  add_function_test(TestBool, "test_bool_identity_ops", test_bool_identity_ops, devices=devices)
94
142
  add_function_test(TestBool, "test_bool_constant", test_bool_constant, devices=devices)
143
+ add_function_test(TestBool, "test_bool_constant_vec", test_bool_constant_vec, devices=devices)
144
+ add_function_test(TestBool, "test_bool_constant_mat", test_bool_constant_mat, devices=devices)
95
145
 
96
146
 
97
147
  if __name__ == "__main__":
warp/tests/test_dlpack.py CHANGED
@@ -14,9 +14,19 @@ import numpy as np
14
14
  import warp as wp
15
15
  from warp.tests.unittest_utils import *
16
16
 
17
# Element count used for the arrays created by the DLPack tests (2**20 = 1,048,576).
N = 2**20
18
+
17
19
  wp.init()
18
20
 
19
21
 
22
+ def _jax_version():
23
+ try:
24
+ import jax
25
+ return jax.__version_info__
26
+ except ImportError:
27
+ return (0, 0, 0)
28
+
29
+
20
30
  @wp.kernel
21
31
  def inc(a: wp.array(dtype=float)):
22
32
  tid = wp.tid()
@@ -24,7 +34,7 @@ def inc(a: wp.array(dtype=float)):
24
34
 
25
35
 
26
36
  def test_dlpack_warp_to_warp(test, device):
27
- a1 = wp.array(data=np.arange(10, dtype=np.float32), device=device)
37
+ a1 = wp.array(data=np.arange(N, dtype=np.float32), device=device)
28
38
 
29
39
  a2 = wp.from_dlpack(wp.to_dlpack(a1))
30
40
 
@@ -44,7 +54,7 @@ def test_dlpack_warp_to_warp(test, device):
44
54
  def test_dlpack_dtypes_and_shapes(test, device):
45
55
  # automatically determine scalar dtype
46
56
  def wrap_scalar_tensor_implicit(dtype):
47
- a1 = wp.zeros(10, dtype=dtype, device=device)
57
+ a1 = wp.zeros(N, dtype=dtype, device=device)
48
58
  a2 = wp.from_dlpack(wp.to_dlpack(a1))
49
59
 
50
60
  test.assertEqual(a1.ptr, a2.ptr)
@@ -55,7 +65,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
55
65
 
56
66
  # explicitly specify scalar dtype
57
67
  def wrap_scalar_tensor_explicit(dtype, target_dtype):
58
- a1 = wp.zeros(10, dtype=dtype, device=device)
68
+ a1 = wp.zeros(N, dtype=dtype, device=device)
59
69
  a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=target_dtype)
60
70
 
61
71
  test.assertEqual(a1.ptr, a2.ptr)
@@ -70,7 +80,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
70
80
  scalar_type = vec_dtype._wp_scalar_type_
71
81
  scalar_size = ctypes.sizeof(vec_dtype._type_)
72
82
 
73
- a1 = wp.zeros(10, dtype=vec_dtype, device=device)
83
+ a1 = wp.zeros(N, dtype=vec_dtype, device=device)
74
84
  a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=scalar_type)
75
85
 
76
86
  test.assertEqual(a1.ptr, a2.ptr)
@@ -86,7 +96,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
86
96
  scalar_type = vec_dtype._wp_scalar_type_
87
97
  scalar_size = ctypes.sizeof(vec_dtype._type_)
88
98
 
89
- a1 = wp.zeros((10, vec_dtype._length_), dtype=scalar_type, device=device)
99
+ a1 = wp.zeros((N, vec_dtype._length_), dtype=scalar_type, device=device)
90
100
  a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=vec_dtype)
91
101
 
92
102
  test.assertEqual(a1.ptr, a2.ptr)
@@ -102,7 +112,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
102
112
  scalar_type = mat_dtype._wp_scalar_type_
103
113
  scalar_size = ctypes.sizeof(mat_dtype._type_)
104
114
 
105
- a1 = wp.zeros(10, dtype=mat_dtype, device=device)
115
+ a1 = wp.zeros(N, dtype=mat_dtype, device=device)
106
116
  a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=scalar_type)
107
117
 
108
118
  test.assertEqual(a1.ptr, a2.ptr)
@@ -118,7 +128,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
118
128
  scalar_type = mat_dtype._wp_scalar_type_
119
129
  scalar_size = ctypes.sizeof(mat_dtype._type_)
120
130
 
121
- a1 = wp.zeros((10, *mat_dtype._shape_), dtype=scalar_type, device=device)
131
+ a1 = wp.zeros((N, *mat_dtype._shape_), dtype=scalar_type, device=device)
122
132
  a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=mat_dtype)
123
133
 
124
134
  test.assertEqual(a1.ptr, a2.ptr)
@@ -182,7 +192,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
182
192
  def test_dlpack_warp_to_torch(test, device):
183
193
  import torch.utils.dlpack
184
194
 
185
- a = wp.array(data=np.arange(10, dtype=np.float32), device=device)
195
+ a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
186
196
 
187
197
  t = torch.utils.dlpack.from_dlpack(wp.to_dlpack(a))
188
198
 
@@ -205,11 +215,40 @@ def test_dlpack_warp_to_torch(test, device):
205
215
  assert_np_equal(a.numpy(), t.cpu().numpy())
206
216
 
207
217
 
218
def test_dlpack_warp_to_torch_v2(test, device):
    """Share a Warp array with Torch through the ``__dlpack__`` protocol.

    Mirrors ``test_dlpack_warp_to_torch``, except that the Warp array itself is
    handed to ``torch.utils.dlpack.from_dlpack`` rather than a capsule from
    ``wp.to_dlpack``. The resulting tensor must alias the array's memory, and
    updates made on either side must be visible on the other.
    """
    import torch.utils.dlpack

    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # no explicit capsule: the array object goes straight to Torch
    t = torch.utils.dlpack.from_dlpack(a)

    itemsize = wp.types.type_size_in_bytes(a.dtype)
    torch_strides_in_bytes = tuple(stride * itemsize for stride in t.stride())

    # same memory, device, dtype, shape and (byte) strides
    test.assertEqual(a.ptr, t.data_ptr())
    test.assertEqual(a.device, wp.device_from_torch(t.device))
    test.assertEqual(a.dtype, wp.torch.dtype_from_torch(t.dtype))
    test.assertEqual(a.shape, tuple(t.shape))
    test.assertEqual(a.strides, torch_strides_in_bytes)

    assert_np_equal(a.numpy(), t.cpu().numpy())

    # after the `inc` kernel runs on the Warp side, the Torch view must agree
    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    assert_np_equal(a.numpy(), t.cpu().numpy())

    # after an in-place update on the Torch side, the Warp array must agree
    t += 1
    assert_np_equal(a.numpy(), t.cpu().numpy())
245
+
246
+
208
247
  def test_dlpack_torch_to_warp(test, device):
209
248
  import torch
210
249
  import torch.utils.dlpack
211
250
 
212
- t = torch.arange(10, dtype=torch.float32, device=wp.device_to_torch(device))
251
+ t = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))
213
252
 
214
253
  a = wp.from_dlpack(torch.utils.dlpack.to_dlpack(t))
215
254
 
@@ -232,11 +271,40 @@ def test_dlpack_torch_to_warp(test, device):
232
271
  assert_np_equal(a.numpy(), t.cpu().numpy())
233
272
 
234
273
 
274
def test_dlpack_torch_to_warp_v2(test, device):
    """Import a Torch tensor into Warp through the ``__dlpack__`` protocol.

    Mirrors ``test_dlpack_torch_to_warp``, except that the tensor itself is
    handed to ``wp.from_dlpack`` rather than a capsule from
    ``torch.utils.dlpack.to_dlpack``. The resulting Warp array must alias the
    tensor's memory, and updates made on either side must be visible on the
    other.
    """
    import torch

    t = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))

    # no explicit capsule: the tensor object goes straight to Warp
    a = wp.from_dlpack(t)

    itemsize = wp.types.type_size_in_bytes(a.dtype)
    torch_strides_in_bytes = tuple(stride * itemsize for stride in t.stride())

    # same memory, device, dtype, shape and (byte) strides
    test.assertEqual(a.ptr, t.data_ptr())
    test.assertEqual(a.device, wp.device_from_torch(t.device))
    test.assertEqual(a.dtype, wp.torch.dtype_from_torch(t.dtype))
    test.assertEqual(a.shape, tuple(t.shape))
    test.assertEqual(a.strides, torch_strides_in_bytes)

    assert_np_equal(a.numpy(), t.cpu().numpy())

    # after the `inc` kernel runs on the Warp side, the Torch tensor must agree
    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    assert_np_equal(a.numpy(), t.cpu().numpy())

    # after an in-place update on the Torch side, the Warp array must agree
    t += 1
    assert_np_equal(a.numpy(), t.cpu().numpy())
301
+
302
+
235
303
  def test_dlpack_warp_to_jax(test, device):
236
304
  import jax
237
305
  import jax.dlpack
238
306
 
239
- a = wp.array(data=np.arange(10, dtype=np.float32), device=device)
307
+ a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
240
308
 
241
309
  # use generic dlpack conversion
242
310
  j1 = jax.dlpack.from_dlpack(wp.to_dlpack(a))
@@ -266,12 +334,49 @@ def test_dlpack_warp_to_jax(test, device):
266
334
  assert_np_equal(a.numpy(), np.asarray(j2))
267
335
 
268
336
 
337
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_warp_to_jax_v2(test, device):
    """Share a Warp array with Jax through the ``__dlpack__`` protocol.

    Mirrors ``test_dlpack_warp_to_jax``, except that the Warp array itself is
    passed to ``jax.dlpack.from_dlpack`` rather than a capsule from
    ``wp.to_dlpack``. The result is checked alongside Warp's own ``wp.to_jax``
    wrapper. Skipped when the installed Jax is older than 0.4.15.
    """
    import jax.dlpack  # also binds the top-level `jax` package name

    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    j_direct = jax.dlpack.from_dlpack(a)  # array object passed as-is
    j_wrapped = wp.to_jax(a)  # Warp's Jax interop helper

    # both Jax arrays must alias the Warp allocation
    test.assertEqual(a.ptr, j_direct.unsafe_buffer_pointer())
    test.assertEqual(a.ptr, j_wrapped.unsafe_buffer_pointer())
    test.assertEqual(a.device, wp.device_from_jax(j_direct.device()))
    test.assertEqual(a.device, wp.device_from_jax(j_wrapped.device()))
    test.assertEqual(a.shape, j_direct.shape)
    test.assertEqual(a.shape, j_wrapped.shape)

    assert_np_equal(a.numpy(), np.asarray(j_direct))
    assert_np_equal(a.numpy(), np.asarray(j_wrapped))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # NOTE(workaround carried over from the original): a no-op arithmetic op
    # makes Jax produce arrays reflecting the buffer contents Warp just wrote,
    # instead of possibly stale values. Confirm whether this is still needed.
    j_direct += 0
    j_wrapped += 0

    assert_np_equal(a.numpy(), np.asarray(j_direct))
    assert_np_equal(a.numpy(), np.asarray(j_wrapped))
372
+
373
+
269
374
  def test_dlpack_jax_to_warp(test, device):
270
375
  import jax
271
376
  import jax.dlpack
272
377
 
273
378
  with jax.default_device(wp.device_to_jax(device)):
274
- j = jax.numpy.arange(10, dtype=jax.numpy.float32)
379
+ j = jax.numpy.arange(N, dtype=jax.numpy.float32)
275
380
 
276
381
  # use generic dlpack conversion
277
382
  a1 = wp.from_dlpack(jax.dlpack.to_dlpack(j))
@@ -300,6 +405,42 @@ def test_dlpack_jax_to_warp(test, device):
300
405
  assert_np_equal(a2.numpy(), np.asarray(j))
301
406
 
302
407
 
408
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_jax_to_warp_v2(test, device):
    """Import a Jax array into Warp through the ``__dlpack__`` protocol.

    Mirrors ``test_dlpack_jax_to_warp``, except that the Jax array itself is
    passed to ``wp.from_dlpack`` rather than a capsule from
    ``jax.dlpack.to_dlpack``. The result is checked alongside Warp's own
    ``wp.from_jax`` wrapper. Skipped when the installed Jax is older than 0.4.15.
    """
    import jax

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(N, dtype=jax.numpy.float32)

    a_direct = wp.from_dlpack(j)  # Jax array object passed as-is
    a_wrapped = wp.from_jax(j)  # Warp's Jax interop helper

    # both Warp arrays must alias the Jax buffer
    test.assertEqual(a_direct.ptr, j.unsafe_buffer_pointer())
    test.assertEqual(a_wrapped.ptr, j.unsafe_buffer_pointer())
    test.assertEqual(a_direct.device, wp.device_from_jax(j.device()))
    test.assertEqual(a_wrapped.device, wp.device_from_jax(j.device()))
    test.assertEqual(a_direct.shape, j.shape)
    test.assertEqual(a_wrapped.shape, j.shape)

    assert_np_equal(a_direct.numpy(), np.asarray(j))
    assert_np_equal(a_wrapped.numpy(), np.asarray(j))

    wp.launch(inc, dim=a_direct.size, inputs=[a_direct], device=device)
    wp.synchronize_device(device)

    # NOTE(workaround carried over from the original): a no-op arithmetic op
    # makes Jax produce an array reflecting the buffer contents Warp just wrote,
    # instead of possibly stale values. Confirm whether this is still needed.
    j += 0

    assert_np_equal(a_direct.numpy(), np.asarray(j))
    assert_np_equal(a_wrapped.numpy(), np.asarray(j))
442
+
443
+
303
444
  class TestDLPack(unittest.TestCase):
304
445
  pass
305
446
 
@@ -330,9 +471,15 @@ try:
330
471
  add_function_test(
331
472
  TestDLPack, "test_dlpack_warp_to_torch", test_dlpack_warp_to_torch, devices=torch_compatible_devices
332
473
  )
474
+ add_function_test(
475
+ TestDLPack, "test_dlpack_warp_to_torch_v2", test_dlpack_warp_to_torch_v2, devices=torch_compatible_devices
476
+ )
333
477
  add_function_test(
334
478
  TestDLPack, "test_dlpack_torch_to_warp", test_dlpack_torch_to_warp, devices=torch_compatible_devices
335
479
  )
480
+ add_function_test(
481
+ TestDLPack, "test_dlpack_torch_to_warp_v2", test_dlpack_torch_to_warp_v2, devices=torch_compatible_devices
482
+ )
336
483
 
337
484
  except Exception as e:
338
485
  print(f"Skipping Torch DLPack tests due to exception: {e}")
@@ -363,9 +510,15 @@ try:
363
510
  add_function_test(
364
511
  TestDLPack, "test_dlpack_warp_to_jax", test_dlpack_warp_to_jax, devices=jax_compatible_devices
365
512
  )
513
+ add_function_test(
514
+ TestDLPack, "test_dlpack_warp_to_jax_v2", test_dlpack_warp_to_jax_v2, devices=jax_compatible_devices
515
+ )
366
516
  add_function_test(
367
517
  TestDLPack, "test_dlpack_jax_to_warp", test_dlpack_jax_to_warp, devices=jax_compatible_devices
368
518
  )
519
+ add_function_test(
520
+ TestDLPack, "test_dlpack_jax_to_warp_v2", test_dlpack_jax_to_warp_v2, devices=jax_compatible_devices
521
+ )
369
522
 
370
523
  except Exception as e:
371
524
  print(f"Skipping Jax DLPack tests due to exception: {e}")