warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic; see the package registry's advisory page for more details.
- warp/__init__.py +8 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +7 -6
- warp/build_dll.py +70 -79
- warp/builtins.py +10 -6
- warp/codegen.py +51 -19
- warp/config.py +7 -8
- warp/constants.py +3 -0
- warp/context.py +948 -245
- warp/dlpack.py +198 -113
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/benchmarks/benchmark_api.py +383 -0
- warp/examples/benchmarks/benchmark_cloth.py +279 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
- warp/examples/benchmarks/benchmark_launches.py +295 -0
- warp/examples/core/example_dem.py +221 -0
- warp/examples/core/example_fluid.py +267 -0
- warp/examples/core/example_graph_capture.py +129 -0
- warp/examples/core/example_marching_cubes.py +177 -0
- warp/examples/core/example_mesh.py +154 -0
- warp/examples/core/example_mesh_intersect.py +193 -0
- warp/examples/core/example_nvdb.py +169 -0
- warp/examples/core/example_raycast.py +89 -0
- warp/examples/core/example_raymarch.py +178 -0
- warp/examples/core/example_render_opengl.py +141 -0
- warp/examples/core/example_sph.py +389 -0
- warp/examples/core/example_torch.py +181 -0
- warp/examples/core/example_wave.py +249 -0
- warp/examples/fem/bsr_utils.py +380 -0
- warp/examples/fem/example_apic_fluid.py +391 -0
- warp/examples/fem/example_convection_diffusion.py +168 -0
- warp/examples/fem/example_convection_diffusion_dg.py +209 -0
- warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
- warp/examples/fem/example_deformed_geometry.py +159 -0
- warp/examples/fem/example_diffusion.py +173 -0
- warp/examples/fem/example_diffusion_3d.py +152 -0
- warp/examples/fem/example_diffusion_mgpu.py +214 -0
- warp/examples/fem/example_mixed_elasticity.py +222 -0
- warp/examples/fem/example_navier_stokes.py +243 -0
- warp/examples/fem/example_stokes.py +192 -0
- warp/examples/fem/example_stokes_transfer.py +249 -0
- warp/examples/fem/mesh_utils.py +109 -0
- warp/examples/fem/plot_utils.py +287 -0
- warp/examples/optim/example_bounce.py +248 -0
- warp/examples/optim/example_cloth_throw.py +210 -0
- warp/examples/optim/example_diffray.py +535 -0
- warp/examples/optim/example_drone.py +850 -0
- warp/examples/optim/example_inverse_kinematics.py +169 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
- warp/examples/optim/example_spring_cage.py +234 -0
- warp/examples/optim/example_trajectory.py +201 -0
- warp/examples/sim/example_cartpole.py +128 -0
- warp/examples/sim/example_cloth.py +184 -0
- warp/examples/sim/example_granular.py +113 -0
- warp/examples/sim/example_granular_collision_sdf.py +185 -0
- warp/examples/sim/example_jacobian_ik.py +213 -0
- warp/examples/sim/example_particle_chain.py +106 -0
- warp/examples/sim/example_quadruped.py +179 -0
- warp/examples/sim/example_rigid_chain.py +191 -0
- warp/examples/sim/example_rigid_contact.py +176 -0
- warp/examples/sim/example_rigid_force.py +126 -0
- warp/examples/sim/example_rigid_gyroscopic.py +97 -0
- warp/examples/sim/example_rigid_soft_contact.py +124 -0
- warp/examples/sim/example_soft_body.py +178 -0
- warp/fabric.py +29 -20
- warp/fem/cache.py +0 -1
- warp/fem/dirichlet.py +0 -2
- warp/fem/integrate.py +0 -1
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/native/builtin.h +12 -0
- warp/native/bvh.cu +18 -18
- warp/native/clang/clang.cpp +8 -3
- warp/native/cuda_util.cpp +94 -5
- warp/native/cuda_util.h +35 -6
- warp/native/cutlass_gemm.cpp +1 -1
- warp/native/cutlass_gemm.cu +4 -1
- warp/native/error.cpp +66 -0
- warp/native/error.h +27 -0
- warp/native/mesh.cu +2 -2
- warp/native/reduce.cu +4 -4
- warp/native/runlength_encode.cu +2 -2
- warp/native/scan.cu +2 -2
- warp/native/sparse.cu +0 -1
- warp/native/temp_buffer.h +2 -2
- warp/native/warp.cpp +95 -60
- warp/native/warp.cu +1053 -218
- warp/native/warp.h +49 -32
- warp/optim/linear.py +33 -16
- warp/render/render_opengl.py +202 -101
- warp/render/render_usd.py +82 -40
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +765 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +35 -13
- warp/sim/utils.py +222 -11
- warp/stubs.py +8 -0
- warp/tape.py +16 -1
- warp/tests/aux_test_grad_customs.py +23 -0
- warp/tests/test_array.py +190 -1
- warp/tests/test_async.py +656 -0
- warp/tests/test_bool.py +50 -0
- warp/tests/test_dlpack.py +164 -11
- warp/tests/test_examples.py +166 -74
- warp/tests/test_fem.py +8 -1
- warp/tests/test_generics.py +15 -5
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +172 -12
- warp/tests/test_jax.py +254 -0
- warp/tests/test_large.py +29 -6
- warp/tests/test_launch.py +25 -0
- warp/tests/test_linear_solvers.py +20 -3
- warp/tests/test_matmul.py +61 -16
- warp/tests/test_matmul_lite.py +13 -13
- warp/tests/test_mempool.py +186 -0
- warp/tests/test_multigpu.py +3 -0
- warp/tests/test_options.py +16 -2
- warp/tests/test_peer.py +137 -0
- warp/tests/test_print.py +3 -1
- warp/tests/test_quat.py +23 -0
- warp/tests/test_sim_kinematics.py +97 -0
- warp/tests/test_snippet.py +126 -3
- warp/tests/test_streams.py +108 -79
- warp/tests/test_torch.py +16 -8
- warp/tests/test_utils.py +32 -27
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/test_volume.py +1 -1
- warp/tests/unittest_serial.py +2 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +14 -7
- warp/thirdparty/unittest_parallel.py +15 -3
- warp/torch.py +10 -8
- warp/types.py +363 -246
- warp/utils.py +143 -19
- warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
- warp_lang-1.0.0.dist-info/METADATA +394 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
- warp/sim/optimizer.py +0 -138
- warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
- warp_lang-0.11.0.dist-info/METADATA +0 -238
- /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/tests/test_bool.py
CHANGED
|
@@ -83,6 +83,54 @@ def test_bool_constant(test, device):
|
|
|
83
83
|
test.assertTrue(compile_constant_value.numpy()[0])
|
|
84
84
|
|
|
85
85
|
|
|
86
|
+
def test_bool_constant_vec(test, device):
    """Kernels can branch on the components of a ``wp.constant`` bool vector.

    Every enabled component adds a distinct power of two to each output element,
    so the final per-element sum identifies exactly which components read as True.
    """
    flags = [True, False, True]

    vec3bool = wp.vec(length=3, dtype=wp.bool)
    bool_selector_vec = wp.constant(vec3bool(flags))

    @wp.kernel
    def sum_from_bool_vec(sum_array: wp.array(dtype=wp.int32)):
        i = wp.tid()

        if bool_selector_vec[0]:
            sum_array[i] = sum_array[i] + 1
        if bool_selector_vec[1]:
            sum_array[i] = sum_array[i] + 2
        if bool_selector_vec[2]:
            sum_array[i] = sum_array[i] + 4

    out = wp.zeros(10, dtype=wp.int32, device=device)
    wp.launch(sum_from_bool_vec, dim=out.shape, inputs=[out], device=device)

    # component k contributes 2**k when set: (True, False, True) -> 1 + 4 = 5
    expected = sum(1 << k for k, on in enumerate(flags) if on)
    assert_np_equal(out.numpy(), np.full(out.shape, expected))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def test_bool_constant_mat(test, device):
    """Kernels can branch on the entries of a ``wp.constant`` bool matrix.

    Entry (r, c) adds 2**(2*r + c) to every output element when it is True, so the
    resulting sum encodes which entries were read as set.
    """
    # flags listed row by row; the diagonal pattern gives 1 + 8 = 9 either way
    flags = [True, False, False, True]

    mat22bool = wp.mat((2, 2), dtype=wp.bool)
    bool_selector_mat = wp.constant(mat22bool(flags))

    @wp.kernel
    def sum_from_bool_mat(sum_array: wp.array(dtype=wp.int32)):
        i = wp.tid()

        if bool_selector_mat[0, 0]:
            sum_array[i] = sum_array[i] + 1
        if bool_selector_mat[0, 1]:
            sum_array[i] = sum_array[i] + 2
        if bool_selector_mat[1, 0]:
            sum_array[i] = sum_array[i] + 4
        if bool_selector_mat[1, 1]:
            sum_array[i] = sum_array[i] + 8

    out = wp.zeros(10, dtype=wp.int32, device=device)
    wp.launch(sum_from_bool_mat, dim=out.shape, inputs=[out], device=device)

    expected = sum(1 << k for k, on in enumerate(flags) if on)
    assert_np_equal(out.numpy(), np.full(out.shape, expected))
|
|
132
|
+
|
|
133
|
+
|
|
86
134
|
devices = get_test_devices()
|
|
87
135
|
|
|
88
136
|
|
|
@@ -92,6 +140,8 @@ class TestBool(unittest.TestCase):
|
|
|
92
140
|
|
|
93
141
|
add_function_test(TestBool, "test_bool_identity_ops", test_bool_identity_ops, devices=devices)
|
|
94
142
|
add_function_test(TestBool, "test_bool_constant", test_bool_constant, devices=devices)
|
|
143
|
+
add_function_test(TestBool, "test_bool_constant_vec", test_bool_constant_vec, devices=devices)
|
|
144
|
+
add_function_test(TestBool, "test_bool_constant_mat", test_bool_constant_mat, devices=devices)
|
|
95
145
|
|
|
96
146
|
|
|
97
147
|
if __name__ == "__main__":
|
warp/tests/test_dlpack.py
CHANGED
|
@@ -14,9 +14,19 @@ import numpy as np
|
|
|
14
14
|
import warp as wp
|
|
15
15
|
from warp.tests.unittest_utils import *
|
|
16
16
|
|
|
17
|
+
N = 1024 * 1024
|
|
18
|
+
|
|
17
19
|
wp.init()
|
|
18
20
|
|
|
19
21
|
|
|
22
|
+
def _jax_version():
|
|
23
|
+
try:
|
|
24
|
+
import jax
|
|
25
|
+
return jax.__version_info__
|
|
26
|
+
except ImportError:
|
|
27
|
+
return (0, 0, 0)
|
|
28
|
+
|
|
29
|
+
|
|
20
30
|
@wp.kernel
|
|
21
31
|
def inc(a: wp.array(dtype=float)):
|
|
22
32
|
tid = wp.tid()
|
|
@@ -24,7 +34,7 @@ def inc(a: wp.array(dtype=float)):
|
|
|
24
34
|
|
|
25
35
|
|
|
26
36
|
def test_dlpack_warp_to_warp(test, device):
|
|
27
|
-
a1 = wp.array(data=np.arange(
|
|
37
|
+
a1 = wp.array(data=np.arange(N, dtype=np.float32), device=device)
|
|
28
38
|
|
|
29
39
|
a2 = wp.from_dlpack(wp.to_dlpack(a1))
|
|
30
40
|
|
|
@@ -44,7 +54,7 @@ def test_dlpack_warp_to_warp(test, device):
|
|
|
44
54
|
def test_dlpack_dtypes_and_shapes(test, device):
|
|
45
55
|
# automatically determine scalar dtype
|
|
46
56
|
def wrap_scalar_tensor_implicit(dtype):
|
|
47
|
-
a1 = wp.zeros(
|
|
57
|
+
a1 = wp.zeros(N, dtype=dtype, device=device)
|
|
48
58
|
a2 = wp.from_dlpack(wp.to_dlpack(a1))
|
|
49
59
|
|
|
50
60
|
test.assertEqual(a1.ptr, a2.ptr)
|
|
@@ -55,7 +65,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
|
|
|
55
65
|
|
|
56
66
|
# explicitly specify scalar dtype
|
|
57
67
|
def wrap_scalar_tensor_explicit(dtype, target_dtype):
|
|
58
|
-
a1 = wp.zeros(
|
|
68
|
+
a1 = wp.zeros(N, dtype=dtype, device=device)
|
|
59
69
|
a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=target_dtype)
|
|
60
70
|
|
|
61
71
|
test.assertEqual(a1.ptr, a2.ptr)
|
|
@@ -70,7 +80,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
|
|
|
70
80
|
scalar_type = vec_dtype._wp_scalar_type_
|
|
71
81
|
scalar_size = ctypes.sizeof(vec_dtype._type_)
|
|
72
82
|
|
|
73
|
-
a1 = wp.zeros(
|
|
83
|
+
a1 = wp.zeros(N, dtype=vec_dtype, device=device)
|
|
74
84
|
a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=scalar_type)
|
|
75
85
|
|
|
76
86
|
test.assertEqual(a1.ptr, a2.ptr)
|
|
@@ -86,7 +96,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
|
|
|
86
96
|
scalar_type = vec_dtype._wp_scalar_type_
|
|
87
97
|
scalar_size = ctypes.sizeof(vec_dtype._type_)
|
|
88
98
|
|
|
89
|
-
a1 = wp.zeros((
|
|
99
|
+
a1 = wp.zeros((N, vec_dtype._length_), dtype=scalar_type, device=device)
|
|
90
100
|
a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=vec_dtype)
|
|
91
101
|
|
|
92
102
|
test.assertEqual(a1.ptr, a2.ptr)
|
|
@@ -102,7 +112,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
|
|
|
102
112
|
scalar_type = mat_dtype._wp_scalar_type_
|
|
103
113
|
scalar_size = ctypes.sizeof(mat_dtype._type_)
|
|
104
114
|
|
|
105
|
-
a1 = wp.zeros(
|
|
115
|
+
a1 = wp.zeros(N, dtype=mat_dtype, device=device)
|
|
106
116
|
a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=scalar_type)
|
|
107
117
|
|
|
108
118
|
test.assertEqual(a1.ptr, a2.ptr)
|
|
@@ -118,7 +128,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
|
|
|
118
128
|
scalar_type = mat_dtype._wp_scalar_type_
|
|
119
129
|
scalar_size = ctypes.sizeof(mat_dtype._type_)
|
|
120
130
|
|
|
121
|
-
a1 = wp.zeros((
|
|
131
|
+
a1 = wp.zeros((N, *mat_dtype._shape_), dtype=scalar_type, device=device)
|
|
122
132
|
a2 = wp.from_dlpack(wp.to_dlpack(a1), dtype=mat_dtype)
|
|
123
133
|
|
|
124
134
|
test.assertEqual(a1.ptr, a2.ptr)
|
|
@@ -182,7 +192,7 @@ def test_dlpack_dtypes_and_shapes(test, device):
|
|
|
182
192
|
def test_dlpack_warp_to_torch(test, device):
|
|
183
193
|
import torch.utils.dlpack
|
|
184
194
|
|
|
185
|
-
a = wp.array(data=np.arange(
|
|
195
|
+
a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
|
|
186
196
|
|
|
187
197
|
t = torch.utils.dlpack.from_dlpack(wp.to_dlpack(a))
|
|
188
198
|
|
|
@@ -205,11 +215,40 @@ def test_dlpack_warp_to_torch(test, device):
|
|
|
205
215
|
assert_np_equal(a.numpy(), t.cpu().numpy())
|
|
206
216
|
|
|
207
217
|
|
|
218
|
+
def test_dlpack_warp_to_torch_v2(test, device):
    """Hand a Warp array to PyTorch through the ``__dlpack__`` protocol.

    Mirrors ``test_dlpack_warp_to_torch``, except that the array object itself is
    passed to ``torch.utils.dlpack.from_dlpack`` instead of a capsule produced by
    ``wp.to_dlpack``. The test checks that the tensor aliases the Warp allocation,
    and that writes made by either library are visible to the other.
    """
    import torch.utils.dlpack

    arr = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # no intermediate capsule: torch invokes arr.__dlpack__() itself
    tensor = torch.utils.dlpack.from_dlpack(arr)

    elem_bytes = wp.types.type_size_in_bytes(arr.dtype)

    def check_in_sync():
        assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_torch(tensor.device))
    test.assertEqual(arr.dtype, wp.torch.dtype_from_torch(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    # torch reports strides in elements, Warp in bytes
    test.assertEqual(arr.strides, tuple(elem_bytes * s for s in tensor.stride()))

    check_in_sync()

    # write through Warp, read through torch
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    check_in_sync()

    # write through torch, read through Warp
    tensor.add_(1)
    check_in_sync()
|
|
245
|
+
|
|
246
|
+
|
|
208
247
|
def test_dlpack_torch_to_warp(test, device):
|
|
209
248
|
import torch
|
|
210
249
|
import torch.utils.dlpack
|
|
211
250
|
|
|
212
|
-
t = torch.arange(
|
|
251
|
+
t = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))
|
|
213
252
|
|
|
214
253
|
a = wp.from_dlpack(torch.utils.dlpack.to_dlpack(t))
|
|
215
254
|
|
|
@@ -232,11 +271,40 @@ def test_dlpack_torch_to_warp(test, device):
|
|
|
232
271
|
assert_np_equal(a.numpy(), t.cpu().numpy())
|
|
233
272
|
|
|
234
273
|
|
|
274
|
+
def test_dlpack_torch_to_warp_v2(test, device):
    """Import a PyTorch tensor into Warp through the ``__dlpack__`` protocol.

    Mirrors ``test_dlpack_torch_to_warp``, except that the tensor itself is passed
    to ``wp.from_dlpack`` instead of a capsule from ``torch.utils.dlpack.to_dlpack``.
    The test checks zero-copy aliasing and two-way visibility of writes.
    """
    import torch

    tensor = torch.arange(N, dtype=torch.float32, device=wp.device_to_torch(device))

    # no intermediate capsule: Warp invokes tensor.__dlpack__() itself
    arr = wp.from_dlpack(tensor)

    elem_bytes = wp.types.type_size_in_bytes(arr.dtype)

    def check_in_sync():
        assert_np_equal(arr.numpy(), tensor.cpu().numpy())

    test.assertEqual(arr.ptr, tensor.data_ptr())
    test.assertEqual(arr.device, wp.device_from_torch(tensor.device))
    test.assertEqual(arr.dtype, wp.torch.dtype_from_torch(tensor.dtype))
    test.assertEqual(arr.shape, tuple(tensor.shape))
    # torch reports strides in elements, Warp in bytes
    test.assertEqual(arr.strides, tuple(elem_bytes * s for s in tensor.stride()))

    check_in_sync()

    # write through Warp, read through torch
    wp.launch(inc, dim=arr.size, inputs=[arr], device=device)
    check_in_sync()

    # write through torch, read through Warp
    tensor.add_(1)
    check_in_sync()
|
|
301
|
+
|
|
302
|
+
|
|
235
303
|
def test_dlpack_warp_to_jax(test, device):
|
|
236
304
|
import jax
|
|
237
305
|
import jax.dlpack
|
|
238
306
|
|
|
239
|
-
a = wp.array(data=np.arange(
|
|
307
|
+
a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
|
|
240
308
|
|
|
241
309
|
# use generic dlpack conversion
|
|
242
310
|
j1 = jax.dlpack.from_dlpack(wp.to_dlpack(a))
|
|
@@ -266,12 +334,49 @@ def test_dlpack_warp_to_jax(test, device):
|
|
|
266
334
|
assert_np_equal(a.numpy(), np.asarray(j2))
|
|
267
335
|
|
|
268
336
|
|
|
337
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_warp_to_jax_v2(test, device):
    """Share a Warp array with JAX through the ``__dlpack__`` protocol.

    Same as ``test_dlpack_warp_to_jax``, but JAX consumes the Warp array directly
    (``jax.dlpack.from_dlpack(a)``) rather than a capsule from ``wp.to_dlpack``.
    The test checks that both the generic path and ``wp.to_jax`` alias the Warp
    buffer with matching device and shape, and that Warp-side writes show up in JAX.
    """

    import jax
    import jax.dlpack

    def jax_device(arr):
        # Newer JAX releases turned ``Array.device`` into a property, so calling
        # ``.device()`` breaks there. ``devices()`` returns the set of devices
        # across versions, and unpacking it also asserts a single-device array.
        (dev,) = arr.devices()
        return dev

    a = wp.array(data=np.arange(N, dtype=np.float32), device=device)

    # pass warp array directly
    j1 = jax.dlpack.from_dlpack(a)

    # use jax wrapper
    j2 = wp.to_jax(a)

    test.assertEqual(a.ptr, j1.unsafe_buffer_pointer())
    test.assertEqual(a.ptr, j2.unsafe_buffer_pointer())
    test.assertEqual(a.device, wp.device_from_jax(jax_device(j1)))
    test.assertEqual(a.device, wp.device_from_jax(jax_device(j2)))
    test.assertEqual(a.shape, j1.shape)
    test.assertEqual(a.shape, j2.shape)

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))

    wp.launch(inc, dim=a.size, inputs=[a], device=device)
    wp.synchronize_device(device)

    # HACK: JAX arrays are immutable, so `+= 0` rebinds each name to a freshly
    # computed array that re-reads the shared buffer after Warp modified it.
    # Presumably this sidesteps a host-side copy cached by the earlier
    # np.asarray() calls (TODO: confirm against JAX's caching behavior).
    j1 += 0
    j2 += 0

    assert_np_equal(a.numpy(), np.asarray(j1))
    assert_np_equal(a.numpy(), np.asarray(j2))
|
|
372
|
+
|
|
373
|
+
|
|
269
374
|
def test_dlpack_jax_to_warp(test, device):
|
|
270
375
|
import jax
|
|
271
376
|
import jax.dlpack
|
|
272
377
|
|
|
273
378
|
with jax.default_device(wp.device_to_jax(device)):
|
|
274
|
-
j = jax.numpy.arange(
|
|
379
|
+
j = jax.numpy.arange(N, dtype=jax.numpy.float32)
|
|
275
380
|
|
|
276
381
|
# use generic dlpack conversion
|
|
277
382
|
a1 = wp.from_dlpack(jax.dlpack.to_dlpack(j))
|
|
@@ -300,6 +405,42 @@ def test_dlpack_jax_to_warp(test, device):
|
|
|
300
405
|
assert_np_equal(a2.numpy(), np.asarray(j))
|
|
301
406
|
|
|
302
407
|
|
|
408
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
def test_dlpack_jax_to_warp_v2(test, device):
    """Import a JAX array into Warp through the ``__dlpack__`` protocol.

    Same as ``test_dlpack_jax_to_warp``, but the JAX array itself is handed to
    ``wp.from_dlpack`` rather than a capsule from ``jax.dlpack.to_dlpack``. The
    test checks that both the generic path and ``wp.from_jax`` alias the JAX
    buffer with matching device and shape, and that Warp-side writes show up in JAX.
    """

    import jax

    def jax_device(arr):
        # Newer JAX releases turned ``Array.device`` into a property, so calling
        # ``.device()`` breaks there. ``devices()`` returns the set of devices
        # across versions, and unpacking it also asserts a single-device array.
        (dev,) = arr.devices()
        return dev

    with jax.default_device(wp.device_to_jax(device)):
        j = jax.numpy.arange(N, dtype=jax.numpy.float32)

    # pass jax array directly
    a1 = wp.from_dlpack(j)

    # use jax wrapper
    a2 = wp.from_jax(j)

    expected_device = wp.device_from_jax(jax_device(j))

    test.assertEqual(a1.ptr, j.unsafe_buffer_pointer())
    test.assertEqual(a2.ptr, j.unsafe_buffer_pointer())
    test.assertEqual(a1.device, expected_device)
    test.assertEqual(a2.device, expected_device)
    test.assertEqual(a1.shape, j.shape)
    test.assertEqual(a2.shape, j.shape)

    assert_np_equal(a1.numpy(), np.asarray(j))
    assert_np_equal(a2.numpy(), np.asarray(j))

    wp.launch(inc, dim=a1.size, inputs=[a1], device=device)
    wp.synchronize_device(device)

    # HACK: JAX arrays are immutable, so `+= 0` rebinds `j` to a freshly computed
    # array that re-reads the shared buffer after Warp modified it. Presumably
    # this sidesteps a host-side copy cached by the earlier np.asarray() calls
    # (TODO: confirm against JAX's caching behavior).
    j += 0

    assert_np_equal(a1.numpy(), np.asarray(j))
    assert_np_equal(a2.numpy(), np.asarray(j))
|
|
442
|
+
|
|
443
|
+
|
|
303
444
|
class TestDLPack(unittest.TestCase):
    """Test case whose methods are attached at import time by ``add_function_test``."""
|
|
305
446
|
|
|
@@ -330,9 +471,15 @@ try:
|
|
|
330
471
|
add_function_test(
|
|
331
472
|
TestDLPack, "test_dlpack_warp_to_torch", test_dlpack_warp_to_torch, devices=torch_compatible_devices
|
|
332
473
|
)
|
|
474
|
+
add_function_test(
|
|
475
|
+
TestDLPack, "test_dlpack_warp_to_torch_v2", test_dlpack_warp_to_torch_v2, devices=torch_compatible_devices
|
|
476
|
+
)
|
|
333
477
|
add_function_test(
|
|
334
478
|
TestDLPack, "test_dlpack_torch_to_warp", test_dlpack_torch_to_warp, devices=torch_compatible_devices
|
|
335
479
|
)
|
|
480
|
+
add_function_test(
|
|
481
|
+
TestDLPack, "test_dlpack_torch_to_warp_v2", test_dlpack_torch_to_warp_v2, devices=torch_compatible_devices
|
|
482
|
+
)
|
|
336
483
|
|
|
337
484
|
except Exception as e:
|
|
338
485
|
print(f"Skipping Torch DLPack tests due to exception: {e}")
|
|
@@ -363,9 +510,15 @@ try:
|
|
|
363
510
|
add_function_test(
|
|
364
511
|
TestDLPack, "test_dlpack_warp_to_jax", test_dlpack_warp_to_jax, devices=jax_compatible_devices
|
|
365
512
|
)
|
|
513
|
+
add_function_test(
|
|
514
|
+
TestDLPack, "test_dlpack_warp_to_jax_v2", test_dlpack_warp_to_jax_v2, devices=jax_compatible_devices
|
|
515
|
+
)
|
|
366
516
|
add_function_test(
|
|
367
517
|
TestDLPack, "test_dlpack_jax_to_warp", test_dlpack_jax_to_warp, devices=jax_compatible_devices
|
|
368
518
|
)
|
|
519
|
+
add_function_test(
|
|
520
|
+
TestDLPack, "test_dlpack_jax_to_warp_v2", test_dlpack_jax_to_warp_v2, devices=jax_compatible_devices
|
|
521
|
+
)
|
|
369
522
|
|
|
370
523
|
except Exception as e:
|
|
371
524
|
print(f"Skipping Jax DLPack tests due to exception: {e}")
|