warp-lang 0.15.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +1 -0
- warp/codegen.py +7 -3
- warp/config.py +2 -1
- warp/constants.py +3 -0
- warp/context.py +44 -21
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/core/example_dem.py +6 -6
- warp/examples/core/example_fluid.py +3 -3
- warp/examples/core/example_graph_capture.py +3 -6
- warp/examples/optim/example_bounce.py +9 -8
- warp/examples/optim/example_cloth_throw.py +12 -8
- warp/examples/optim/example_diffray.py +10 -12
- warp/examples/optim/example_drone.py +31 -14
- warp/examples/optim/example_spring_cage.py +10 -15
- warp/examples/optim/example_trajectory.py +7 -24
- warp/examples/sim/example_cartpole.py +3 -9
- warp/examples/sim/example_cloth.py +10 -10
- warp/examples/sim/example_granular.py +3 -3
- warp/examples/sim/example_granular_collision_sdf.py +9 -4
- warp/examples/sim/example_jacobian_ik.py +0 -10
- warp/examples/sim/example_particle_chain.py +4 -4
- warp/examples/sim/example_quadruped.py +15 -11
- warp/examples/sim/example_rigid_chain.py +13 -8
- warp/examples/sim/example_rigid_contact.py +4 -4
- warp/examples/sim/example_rigid_force.py +7 -7
- warp/examples/sim/example_rigid_soft_contact.py +4 -4
- warp/examples/sim/example_soft_body.py +3 -3
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/render/render_opengl.py +188 -95
- warp/render/render_usd.py +34 -10
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +741 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +18 -2
- warp/sim/utils.py +222 -11
- warp/stubs.py +1 -0
- warp/tape.py +6 -9
- warp/tests/test_examples.py +87 -20
- warp/tests/test_grad_customs.py +122 -0
- warp/tests/test_jax.py +254 -0
- warp/tests/test_options.py +13 -53
- warp/tests/test_quat.py +23 -0
- warp/tests/test_snippet.py +2 -0
- warp/tests/test_utils.py +31 -26
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/unittest_suites.py +4 -0
- warp/utils.py +50 -1
- {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/METADATA +1 -1
- {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +73 -64
- warp/examples/env/__init__.py +0 -0
- warp/examples/env/env_ant.py +0 -61
- warp/examples/env/env_cartpole.py +0 -63
- warp/examples/env/env_humanoid.py +0 -65
- warp/examples/env/env_usd.py +0 -97
- warp/examples/env/environment.py +0 -526
- warp/sim/optimizer.py +0 -138
- {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/tests/test_grad_customs.py
CHANGED
|
@@ -22,6 +22,7 @@ wp.init()
|
|
|
22
22
|
def reversible_increment(
|
|
23
23
|
counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
|
|
24
24
|
):
|
|
25
|
+
"""This is a docstring"""
|
|
25
26
|
next_index = wp.atomic_add(counter, counter_index, value)
|
|
26
27
|
thread_values[tid] = next_index
|
|
27
28
|
return next_index
|
|
@@ -31,6 +32,7 @@ def reversible_increment(
|
|
|
31
32
|
def replay_reversible_increment(
|
|
32
33
|
counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
|
|
33
34
|
):
|
|
35
|
+
"""This is a docstring"""
|
|
34
36
|
return thread_values[tid]
|
|
35
37
|
|
|
36
38
|
|
|
@@ -64,28 +66,33 @@ def test_custom_replay_grad(test, device):
|
|
|
64
66
|
|
|
65
67
|
@wp.func
|
|
66
68
|
def overload_fn(x: float, y: float):
|
|
69
|
+
"""This is a docstring"""
|
|
67
70
|
return x * 3.0 + y / 3.0, y**2.5
|
|
68
71
|
|
|
69
72
|
|
|
70
73
|
@wp.func_grad(overload_fn)
|
|
71
74
|
def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
|
|
75
|
+
"""This is a docstring"""
|
|
72
76
|
wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
|
|
73
77
|
wp.adjoint[y] += y * adj_ret1 * 3.0
|
|
74
78
|
|
|
75
79
|
|
|
76
80
|
@wp.struct
|
|
77
81
|
class MyStruct:
|
|
82
|
+
"""This is a docstring"""
|
|
78
83
|
scalar: float
|
|
79
84
|
vec: wp.vec3
|
|
80
85
|
|
|
81
86
|
|
|
82
87
|
@wp.func
|
|
83
88
|
def overload_fn(x: MyStruct):
|
|
89
|
+
"""This is a docstring"""
|
|
84
90
|
return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5
|
|
85
91
|
|
|
86
92
|
|
|
87
93
|
@wp.func_grad(overload_fn)
|
|
88
94
|
def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
|
|
95
|
+
"""This is a docstring"""
|
|
89
96
|
wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
|
|
90
97
|
wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
|
|
91
98
|
wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
|
|
@@ -96,6 +103,7 @@ def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: fl
|
|
|
96
103
|
def run_overload_float_fn(
|
|
97
104
|
xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
|
|
98
105
|
):
|
|
106
|
+
"""This is a docstring"""
|
|
99
107
|
i = wp.tid()
|
|
100
108
|
out0, out1 = overload_fn(xs[i], ys[i])
|
|
101
109
|
output0[i] = out0
|
|
@@ -197,6 +205,118 @@ def test_custom_import_grad(test, device):
|
|
|
197
205
|
assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
|
|
198
206
|
|
|
199
207
|
|
|
208
|
+
@wp.func
|
|
209
|
+
def sigmoid(x: float):
|
|
210
|
+
return 1.0 / (1.0 + wp.exp(-x))
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@wp.func_grad(sigmoid)
|
|
214
|
+
def adj_sigmoid(x: float, adj: float):
|
|
215
|
+
# unused function to test that we don't run into infinite recursion when calling
|
|
216
|
+
# the forward function from within the gradient function
|
|
217
|
+
wp.adjoint[x] += adj * sigmoid(x) * (1.0 - sigmoid(x))
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@wp.func
|
|
221
|
+
def sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
|
|
222
|
+
# test function that does not return anything
|
|
223
|
+
ys[i] = sigmoid(xs[i])
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@wp.func_grad(sigmoid_no_return)
|
|
227
|
+
def adj_sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
|
|
228
|
+
wp.adjoint[xs][i] += ys[i] * (1.0 - ys[i])
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
@wp.kernel
|
|
232
|
+
def eval_sigmoid(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
|
|
233
|
+
i = wp.tid()
|
|
234
|
+
sigmoid_no_return(i, xs, ys)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def test_custom_grad_no_return(test, device):
|
|
238
|
+
xs = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True)
|
|
239
|
+
ys = wp.zeros_like(xs)
|
|
240
|
+
ys.grad.fill_(1.0)
|
|
241
|
+
|
|
242
|
+
tape = wp.Tape()
|
|
243
|
+
with tape:
|
|
244
|
+
wp.launch(eval_sigmoid, dim=len(xs), inputs=[xs], outputs=[ys])
|
|
245
|
+
tape.backward()
|
|
246
|
+
|
|
247
|
+
sigmoids = ys.numpy()
|
|
248
|
+
grad = xs.grad.numpy()
|
|
249
|
+
assert_np_equal(grad, sigmoids * (1.0 - sigmoids))
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def test_wrapped_docstring(test, device):
|
|
253
|
+
assert "This is a docstring" in reversible_increment.__doc__
|
|
254
|
+
assert "This is a docstring" in replay_reversible_increment.__doc__
|
|
255
|
+
assert "This is a docstring" in overload_fn.__doc__
|
|
256
|
+
assert "This is a docstring" in overload_fn_grad.__doc__
|
|
257
|
+
assert "This is a docstring" in run_overload_float_fn.__doc__
|
|
258
|
+
assert "This is a docstring" in MyStruct.__doc__
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@wp.func
|
|
262
|
+
def dense_gemm(
|
|
263
|
+
m: int,
|
|
264
|
+
n: int,
|
|
265
|
+
p: int,
|
|
266
|
+
transpose_A: bool,
|
|
267
|
+
transpose_B: bool,
|
|
268
|
+
add_to_C: bool,
|
|
269
|
+
A: wp.array(dtype=float),
|
|
270
|
+
B: wp.array(dtype=float),
|
|
271
|
+
# outputs
|
|
272
|
+
C: wp.array(dtype=float),
|
|
273
|
+
):
|
|
274
|
+
# this function doesn't get called but it is an important test for code generation
|
|
275
|
+
# multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
|
|
276
|
+
for i in range(m):
|
|
277
|
+
for j in range(n):
|
|
278
|
+
sum = float(0.0)
|
|
279
|
+
for k in range(p):
|
|
280
|
+
if transpose_A:
|
|
281
|
+
a_i = k * m + i
|
|
282
|
+
else:
|
|
283
|
+
a_i = i * p + k
|
|
284
|
+
if transpose_B:
|
|
285
|
+
b_j = j * p + k
|
|
286
|
+
else:
|
|
287
|
+
b_j = k * n + j
|
|
288
|
+
sum += A[a_i] * B[b_j]
|
|
289
|
+
|
|
290
|
+
if add_to_C:
|
|
291
|
+
C[i * n + j] += sum
|
|
292
|
+
else:
|
|
293
|
+
C[i * n + j] = sum
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@wp.func_grad(dense_gemm)
|
|
297
|
+
def adj_dense_gemm(
|
|
298
|
+
m: int,
|
|
299
|
+
n: int,
|
|
300
|
+
p: int,
|
|
301
|
+
transpose_A: bool,
|
|
302
|
+
transpose_B: bool,
|
|
303
|
+
add_to_C: bool,
|
|
304
|
+
A: wp.array(dtype=float),
|
|
305
|
+
B: wp.array(dtype=float),
|
|
306
|
+
# outputs
|
|
307
|
+
C: wp.array(dtype=float),
|
|
308
|
+
):
|
|
309
|
+
# code generation would break here if we didn't defer building the custom grad
|
|
310
|
+
# function until after the forward functions + kernels of the module have been built
|
|
311
|
+
add_to_C = True
|
|
312
|
+
if transpose_A:
|
|
313
|
+
dense_gemm(p, m, n, False, True, add_to_C, B, wp.adjoint[C], wp.adjoint[A])
|
|
314
|
+
dense_gemm(p, n, m, False, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
|
|
315
|
+
else:
|
|
316
|
+
dense_gemm(m, p, n, False, not transpose_B, add_to_C, wp.adjoint[C], B, wp.adjoint[A])
|
|
317
|
+
dense_gemm(p, n, m, True, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
|
|
318
|
+
|
|
319
|
+
|
|
200
320
|
devices = get_test_devices()
|
|
201
321
|
|
|
202
322
|
|
|
@@ -207,6 +327,8 @@ class TestGradCustoms(unittest.TestCase):
|
|
|
207
327
|
add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
|
|
208
328
|
add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
|
|
209
329
|
add_function_test(TestGradCustoms, "test_custom_import_grad", test_custom_import_grad, devices=devices)
|
|
330
|
+
add_function_test(TestGradCustoms, "test_custom_grad_no_return", test_custom_grad_no_return, devices=devices)
|
|
331
|
+
add_function_test(TestGradCustoms, "test_wrapped_docstring", test_wrapped_docstring, devices=devices)
|
|
210
332
|
|
|
211
333
|
|
|
212
334
|
if __name__ == "__main__":
|
warp/tests/test_jax.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import os
|
|
10
|
+
import unittest
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import warp as wp
|
|
14
|
+
from warp.tests.unittest_utils import *
|
|
15
|
+
|
|
16
|
+
wp.init()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# basic kernel with one input and output
|
|
20
|
+
@wp.kernel
|
|
21
|
+
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
|
|
22
|
+
tid = wp.tid()
|
|
23
|
+
output[tid] = 3.0 * input[tid]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# generic kernel with one scalar input and output
|
|
27
|
+
@wp.kernel
|
|
28
|
+
def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
|
|
29
|
+
tid = wp.tid()
|
|
30
|
+
output[tid] = input.dtype(3) * input[tid]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# generic kernel with one vector/matrix input and output
|
|
34
|
+
@wp.kernel
|
|
35
|
+
def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
|
|
36
|
+
tid = wp.tid()
|
|
37
|
+
output[tid] = input.dtype.dtype(3) * input[tid]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# kernel with multiple inputs and outputs
|
|
41
|
+
@wp.kernel
|
|
42
|
+
def multiarg_kernel(
|
|
43
|
+
# inputs
|
|
44
|
+
a: wp.array(dtype=float),
|
|
45
|
+
b: wp.array(dtype=float),
|
|
46
|
+
c: wp.array(dtype=float),
|
|
47
|
+
# outputs
|
|
48
|
+
ab: wp.array(dtype=float),
|
|
49
|
+
bc: wp.array(dtype=float),
|
|
50
|
+
):
|
|
51
|
+
tid = wp.tid()
|
|
52
|
+
ab[tid] = a[tid] + b[tid]
|
|
53
|
+
bc[tid] = b[tid] + c[tid]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# various types for testing
|
|
57
|
+
scalar_types = wp.types.scalar_types
|
|
58
|
+
vector_types = []
|
|
59
|
+
matrix_types = []
|
|
60
|
+
for dim in [2, 3, 4]:
|
|
61
|
+
for T in scalar_types:
|
|
62
|
+
vector_types.append(wp.vec(dim, T))
|
|
63
|
+
matrix_types.append(wp.mat((dim, dim), T))
|
|
64
|
+
|
|
65
|
+
# explicitly overload generic kernels to avoid module reloading during tests
|
|
66
|
+
for T in scalar_types:
|
|
67
|
+
wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
|
|
68
|
+
for T in [*vector_types, *matrix_types]:
|
|
69
|
+
wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _jax_version():
|
|
73
|
+
try:
|
|
74
|
+
import jax
|
|
75
|
+
return jax.__version_info__
|
|
76
|
+
except ImportError:
|
|
77
|
+
return (0, 0, 0)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
|
|
81
|
+
def test_jax_kernel_basic(test, device):
|
|
82
|
+
import jax.numpy as jp
|
|
83
|
+
from warp.jax_experimental import jax_kernel
|
|
84
|
+
|
|
85
|
+
n = 64
|
|
86
|
+
|
|
87
|
+
jax_triple = jax_kernel(triple_kernel)
|
|
88
|
+
|
|
89
|
+
@jax.jit
|
|
90
|
+
def f():
|
|
91
|
+
x = jp.arange(n, dtype=jp.float32)
|
|
92
|
+
return jax_triple(x)
|
|
93
|
+
|
|
94
|
+
# run on the given device
|
|
95
|
+
with jax.default_device(wp.device_to_jax(device)):
|
|
96
|
+
y = f()
|
|
97
|
+
|
|
98
|
+
result = np.asarray(y)
|
|
99
|
+
expected = 3 * np.arange(n, dtype=np.float32)
|
|
100
|
+
|
|
101
|
+
assert_np_equal(result, expected)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
|
|
105
|
+
def test_jax_kernel_scalar(test, device):
|
|
106
|
+
import jax.numpy as jp
|
|
107
|
+
from warp.jax_experimental import jax_kernel
|
|
108
|
+
|
|
109
|
+
n = 64
|
|
110
|
+
|
|
111
|
+
for T in scalar_types:
|
|
112
|
+
|
|
113
|
+
jp_dtype = wp.jax.dtype_to_jax(T)
|
|
114
|
+
np_dtype = wp.types.warp_type_to_np_dtype[T]
|
|
115
|
+
|
|
116
|
+
with test.subTest(msg=T.__name__):
|
|
117
|
+
|
|
118
|
+
# get the concrete overload
|
|
119
|
+
kernel_instance = triple_kernel_scalar.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
|
|
120
|
+
|
|
121
|
+
jax_triple = jax_kernel(kernel_instance)
|
|
122
|
+
|
|
123
|
+
@jax.jit
|
|
124
|
+
def f():
|
|
125
|
+
x = jp.arange(n, dtype=jp_dtype)
|
|
126
|
+
return jax_triple(x)
|
|
127
|
+
|
|
128
|
+
# run on the given device
|
|
129
|
+
with jax.default_device(wp.device_to_jax(device)):
|
|
130
|
+
y = f()
|
|
131
|
+
|
|
132
|
+
result = np.asarray(y)
|
|
133
|
+
expected = 3 * np.arange(n, dtype=np_dtype)
|
|
134
|
+
|
|
135
|
+
assert_np_equal(result, expected)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
|
|
139
|
+
def test_jax_kernel_vecmat(test, device):
|
|
140
|
+
import jax.numpy as jp
|
|
141
|
+
from warp.jax_experimental import jax_kernel
|
|
142
|
+
|
|
143
|
+
for T in [*vector_types, *matrix_types]:
|
|
144
|
+
|
|
145
|
+
jp_dtype = wp.jax.dtype_to_jax(T._wp_scalar_type_)
|
|
146
|
+
np_dtype = wp.types.warp_type_to_np_dtype[T._wp_scalar_type_]
|
|
147
|
+
|
|
148
|
+
n = 64 // T._length_
|
|
149
|
+
scalar_shape = (n, *T._shape_)
|
|
150
|
+
scalar_len = n * T._length_
|
|
151
|
+
|
|
152
|
+
with test.subTest(msg=T.__name__):
|
|
153
|
+
|
|
154
|
+
# get the concrete overload
|
|
155
|
+
kernel_instance = triple_kernel_vecmat.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
|
|
156
|
+
|
|
157
|
+
jax_triple = jax_kernel(kernel_instance)
|
|
158
|
+
|
|
159
|
+
@jax.jit
|
|
160
|
+
def f():
|
|
161
|
+
x = jp.arange(scalar_len, dtype=jp_dtype).reshape(scalar_shape)
|
|
162
|
+
return jax_triple(x)
|
|
163
|
+
|
|
164
|
+
# run on the given device
|
|
165
|
+
with jax.default_device(wp.device_to_jax(device)):
|
|
166
|
+
y = f()
|
|
167
|
+
|
|
168
|
+
result = np.asarray(y)
|
|
169
|
+
expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
|
|
170
|
+
|
|
171
|
+
assert_np_equal(result, expected)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
|
|
175
|
+
def test_jax_kernel_multiarg(test, device):
|
|
176
|
+
import jax.numpy as jp
|
|
177
|
+
from warp.jax_experimental import jax_kernel
|
|
178
|
+
|
|
179
|
+
n = 64
|
|
180
|
+
|
|
181
|
+
jax_multiarg = jax_kernel(multiarg_kernel)
|
|
182
|
+
|
|
183
|
+
@jax.jit
|
|
184
|
+
def f():
|
|
185
|
+
a = jp.full(n, 1, dtype=jp.float32)
|
|
186
|
+
b = jp.full(n, 2, dtype=jp.float32)
|
|
187
|
+
c = jp.full(n, 3, dtype=jp.float32)
|
|
188
|
+
return jax_multiarg(a, b, c)
|
|
189
|
+
|
|
190
|
+
# run on the given device
|
|
191
|
+
with jax.default_device(wp.device_to_jax(device)):
|
|
192
|
+
x, y = f()
|
|
193
|
+
|
|
194
|
+
result_x, result_y = np.asarray(x), np.asarray(y)
|
|
195
|
+
expected_x = np.full(n, 3, dtype=np.float32)
|
|
196
|
+
expected_y = np.full(n, 5, dtype=np.float32)
|
|
197
|
+
|
|
198
|
+
assert_np_equal(result_x, expected_x)
|
|
199
|
+
assert_np_equal(result_y, expected_y)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class TestJax(unittest.TestCase):
|
|
203
|
+
pass
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# try adding Jax tests if Jax is installed correctly
|
|
207
|
+
try:
|
|
208
|
+
# prevent Jax from gobbling up GPU memory
|
|
209
|
+
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
|
|
210
|
+
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
|
211
|
+
|
|
212
|
+
import jax
|
|
213
|
+
import jax.dlpack
|
|
214
|
+
|
|
215
|
+
# NOTE: we must enable 64-bit types in Jax to test the full gamut of types
|
|
216
|
+
jax.config.update("jax_enable_x64", True)
|
|
217
|
+
|
|
218
|
+
# check which Warp devices work with Jax
|
|
219
|
+
# CUDA devices may fail if Jax cannot find a CUDA Toolkit
|
|
220
|
+
test_devices = get_test_devices()
|
|
221
|
+
jax_compatible_devices = []
|
|
222
|
+
jax_compatible_cuda_devices = []
|
|
223
|
+
for d in test_devices:
|
|
224
|
+
try:
|
|
225
|
+
with jax.default_device(wp.device_to_jax(d)):
|
|
226
|
+
j = jax.numpy.arange(10, dtype=jax.numpy.float32)
|
|
227
|
+
j += 1
|
|
228
|
+
jax_compatible_devices.append(d)
|
|
229
|
+
if d.is_cuda:
|
|
230
|
+
jax_compatible_cuda_devices.append(d)
|
|
231
|
+
except Exception as e:
|
|
232
|
+
print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
|
|
233
|
+
|
|
234
|
+
if jax_compatible_cuda_devices:
|
|
235
|
+
add_function_test(
|
|
236
|
+
TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices
|
|
237
|
+
)
|
|
238
|
+
add_function_test(
|
|
239
|
+
TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
|
|
240
|
+
)
|
|
241
|
+
add_function_test(
|
|
242
|
+
TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
|
|
243
|
+
)
|
|
244
|
+
add_function_test(
|
|
245
|
+
TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
|
|
246
|
+
)
|
|
247
|
+
|
|
248
|
+
except Exception as e:
|
|
249
|
+
print(f"Skipping Jax tests due to exception: {e}")
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
if __name__ == "__main__":
|
|
253
|
+
wp.build.clear_kernel_cache()
|
|
254
|
+
unittest.main(verbosity=2)
|
warp/tests/test_options.py
CHANGED
|
@@ -6,6 +6,8 @@
|
|
|
6
6
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
7
|
|
|
8
8
|
import unittest
|
|
9
|
+
import contextlib
|
|
10
|
+
import io
|
|
9
11
|
|
|
10
12
|
import warp as wp
|
|
11
13
|
from warp.tests.unittest_utils import *
|
|
@@ -49,7 +51,12 @@ def test_options_1(test, device):
|
|
|
49
51
|
with tape:
|
|
50
52
|
wp.launch(scale, dim=1, inputs=[x, y], device=device)
|
|
51
53
|
|
|
52
|
-
|
|
54
|
+
with contextlib.redirect_stdout(io.StringIO()) as f:
|
|
55
|
+
tape.backward(y)
|
|
56
|
+
|
|
57
|
+
expected = f"Warp UserWarning: Running the tape backwards may produce incorrect gradients because recorded kernel {scale.key} is defined in a module with the option 'enable_backward=False' set.\n"
|
|
58
|
+
|
|
59
|
+
assert f.getvalue() == expected
|
|
53
60
|
assert_np_equal(tape.gradients[x].numpy(), np.array(0.0))
|
|
54
61
|
|
|
55
62
|
|
|
@@ -91,58 +98,13 @@ def test_options_4(test, device):
|
|
|
91
98
|
with tape:
|
|
92
99
|
wp.launch(scale_2, dim=1, inputs=[x, y], device=device)
|
|
93
100
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
|
|
99
|
-
def test_options_5(test, device):
|
|
100
|
-
wp.set_module_options({"enable_backward": True})
|
|
101
|
-
|
|
102
|
-
@wp.kernel
|
|
103
|
-
def loss_kernel(y: wp.array(dtype=float), loss: wp.array(dtype=float)):
|
|
104
|
-
tid = wp.tid()
|
|
105
|
-
wp.atomic_add(loss, 0, y[tid])
|
|
106
|
-
|
|
107
|
-
A = wp.array(np.ones((2, 2), dtype=float), dtype=float, requires_grad=True, device=device)
|
|
108
|
-
x = wp.array([[1.0], [2.0]], dtype=float, requires_grad=True, device=device)
|
|
109
|
-
b = wp.zeros_like(x)
|
|
110
|
-
y = wp.zeros_like(x)
|
|
111
|
-
loss = wp.zeros(1, requires_grad=True, device=device)
|
|
112
|
-
|
|
113
|
-
tape = wp.Tape()
|
|
114
|
-
|
|
115
|
-
with tape:
|
|
116
|
-
wp.matmul(A, x, b, y)
|
|
117
|
-
wp.launch(loss_kernel, dim=2, inputs=[y.flatten(), loss], device=device)
|
|
118
|
-
|
|
119
|
-
tape.backward(loss)
|
|
120
|
-
assert_np_equal(x.grad.numpy(), np.array([[2.0], [2.0]]))
|
|
101
|
+
with contextlib.redirect_stdout(io.StringIO()) as f:
|
|
102
|
+
tape.backward(y)
|
|
121
103
|
|
|
104
|
+
expected = f"Warp UserWarning: Running the tape backwards may produce incorrect gradients because recorded kernel {scale_2.key} is configured with the option 'enable_backward=False'.\n"
|
|
122
105
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
wp.set_module_options({"enable_backward": False})
|
|
126
|
-
|
|
127
|
-
@wp.kernel
|
|
128
|
-
def loss_kernel(y: wp.array(dtype=float), loss: wp.array(dtype=float)):
|
|
129
|
-
tid = wp.tid()
|
|
130
|
-
wp.atomic_add(loss, 0, y[tid])
|
|
131
|
-
|
|
132
|
-
A = wp.array(np.ones((2, 2), dtype=float), dtype=float, requires_grad=True, device=device)
|
|
133
|
-
x = wp.array([[1.0], [2.0]], dtype=float, requires_grad=True, device=device)
|
|
134
|
-
b = wp.zeros_like(x)
|
|
135
|
-
y = wp.zeros_like(x)
|
|
136
|
-
loss = wp.zeros(1, requires_grad=True, device=device)
|
|
137
|
-
|
|
138
|
-
tape = wp.Tape()
|
|
139
|
-
|
|
140
|
-
with tape:
|
|
141
|
-
wp.matmul(A, x, b, y)
|
|
142
|
-
wp.launch(loss_kernel, dim=2, inputs=[y.flatten(), loss], device=device)
|
|
143
|
-
|
|
144
|
-
tape.backward(loss)
|
|
145
|
-
assert_np_equal(x.grad.numpy(), np.array([[0.0], [0.0]]))
|
|
106
|
+
assert f.getvalue() == expected
|
|
107
|
+
assert_np_equal(tape.gradients[x].numpy(), np.array(0.0))
|
|
146
108
|
|
|
147
109
|
|
|
148
110
|
devices = get_test_devices()
|
|
@@ -156,8 +118,6 @@ add_function_test(TestOptions, "test_options_1", test_options_1, devices=devices
|
|
|
156
118
|
add_function_test(TestOptions, "test_options_2", test_options_2, devices=devices)
|
|
157
119
|
add_function_test(TestOptions, "test_options_3", test_options_3, devices=devices)
|
|
158
120
|
add_function_test(TestOptions, "test_options_4", test_options_4, devices=devices)
|
|
159
|
-
add_function_test(TestOptions, "test_options_5", test_options_5, devices=devices)
|
|
160
|
-
add_function_test(TestOptions, "test_options_6", test_options_6, devices=devices)
|
|
161
121
|
|
|
162
122
|
|
|
163
123
|
if __name__ == "__main__":
|
warp/tests/test_quat.py
CHANGED
|
@@ -11,6 +11,7 @@ import numpy as np
|
|
|
11
11
|
|
|
12
12
|
import warp as wp
|
|
13
13
|
from warp.tests.unittest_utils import *
|
|
14
|
+
import warp.sim
|
|
14
15
|
|
|
15
16
|
wp.init()
|
|
16
17
|
|
|
@@ -1871,6 +1872,21 @@ def test_quat_identity(test, device, dtype, register_kernels=False):
|
|
|
1871
1872
|
assert_np_equal(output.numpy(), expected)
|
|
1872
1873
|
|
|
1873
1874
|
|
|
1875
|
+
############################################################
|
|
1876
|
+
|
|
1877
|
+
|
|
1878
|
+
def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
|
|
1879
|
+
rng = np.random.default_rng(123)
|
|
1880
|
+
N = 3
|
|
1881
|
+
|
|
1882
|
+
rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
|
|
1883
|
+
|
|
1884
|
+
quats_from_euler = [list(wp.sim.quat_from_euler(wp.vec3(*rpy), 0, 1, 2)) for rpy in rpy_arr]
|
|
1885
|
+
quats_from_rpy = [list(wp.quat_rpy(rpy[0], rpy[1], rpy[2])) for rpy in rpy_arr]
|
|
1886
|
+
|
|
1887
|
+
assert_np_equal(np.array(quats_from_euler), np.array(quats_from_rpy), tol=1e-4)
|
|
1888
|
+
|
|
1889
|
+
|
|
1874
1890
|
def test_anon_type_instance(test, device, dtype, register_kernels=False):
|
|
1875
1891
|
rng = np.random.default_rng(123)
|
|
1876
1892
|
wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
|
|
@@ -2053,6 +2069,13 @@ for dtype in np_float_types:
|
|
|
2053
2069
|
add_function_test_register_kernel(
|
|
2054
2070
|
TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
|
|
2055
2071
|
)
|
|
2072
|
+
add_function_test_register_kernel(
|
|
2073
|
+
TestQuat,
|
|
2074
|
+
f"test_quat_euler_conversion_{dtype.__name__}",
|
|
2075
|
+
test_quat_euler_conversion,
|
|
2076
|
+
devices=devices,
|
|
2077
|
+
dtype=dtype,
|
|
2078
|
+
)
|
|
2056
2079
|
add_function_test(
|
|
2057
2080
|
TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
|
|
2058
2081
|
)
|
warp/tests/test_snippet.py
CHANGED
|
@@ -86,6 +86,7 @@ def test_shared_memory(test, device):
|
|
|
86
86
|
|
|
87
87
|
@wp.func_native(snippet)
|
|
88
88
|
def reverse(d: wp.array(dtype=int), N: int, tid: int):
|
|
89
|
+
"""Reverse the array d in place using shared memory."""
|
|
89
90
|
return
|
|
90
91
|
|
|
91
92
|
@wp.kernel
|
|
@@ -100,6 +101,7 @@ def test_shared_memory(test, device):
|
|
|
100
101
|
wp.launch(kernel=reverse_kernel, dim=N, inputs=[x, N], device=device)
|
|
101
102
|
|
|
102
103
|
assert_np_equal(x.numpy(), y)
|
|
104
|
+
assert reverse.__doc__ == "Reverse the array d in place using shared memory."
|
|
103
105
|
|
|
104
106
|
|
|
105
107
|
def test_cpu_snippet(test, device):
|
warp/tests/test_utils.py
CHANGED
|
@@ -267,55 +267,60 @@ class TestUtils(unittest.TestCase):
|
|
|
267
267
|
def test_warn(self):
|
|
268
268
|
# Multiple warnings get printed out each time.
|
|
269
269
|
with contextlib.redirect_stdout(io.StringIO()) as f:
|
|
270
|
-
frame_info = inspect.getframeinfo(inspect.currentframe())
|
|
271
270
|
wp.utils.warn("hello, world!")
|
|
272
271
|
wp.utils.warn("hello, world!")
|
|
273
272
|
|
|
274
273
|
expected = (
|
|
275
|
-
"
|
|
276
|
-
"
|
|
277
|
-
).format(
|
|
278
|
-
frame_info.filename,
|
|
279
|
-
frame_info.lineno + 1,
|
|
280
|
-
"UserWarning: hello, world!\n wp.utils.warn(\"hello, world!\")",
|
|
281
|
-
frame_info.filename,
|
|
282
|
-
frame_info.lineno + 2,
|
|
283
|
-
"UserWarning: hello, world!\n wp.utils.warn(\"hello, world!\")",
|
|
274
|
+
"Warp UserWarning: hello, world!\n"
|
|
275
|
+
"Warp UserWarning: hello, world!\n"
|
|
284
276
|
)
|
|
277
|
+
|
|
285
278
|
self.assertEqual(f.getvalue(), expected)
|
|
286
279
|
|
|
280
|
+
# Test verbose warnings
|
|
281
|
+
saved_verbosity = wp.config.verbose_warnings
|
|
282
|
+
try:
|
|
283
|
+
wp.config.verbose_warnings = True
|
|
284
|
+
with contextlib.redirect_stdout(io.StringIO()) as f:
|
|
285
|
+
frame_info = inspect.getframeinfo(inspect.currentframe())
|
|
286
|
+
wp.utils.warn("hello, world!")
|
|
287
|
+
wp.utils.warn("hello, world!")
|
|
288
|
+
|
|
289
|
+
expected = (
|
|
290
|
+
f"Warp UserWarning: hello, world! ({frame_info.filename}:{frame_info.lineno + 1})\n"
|
|
291
|
+
" wp.utils.warn(\"hello, world!\")\n"
|
|
292
|
+
f"Warp UserWarning: hello, world! ({frame_info.filename}:{frame_info.lineno + 2})\n"
|
|
293
|
+
" wp.utils.warn(\"hello, world!\")\n"
|
|
294
|
+
)
|
|
295
|
+
|
|
296
|
+
self.assertEqual(f.getvalue(), expected)
|
|
297
|
+
|
|
298
|
+
finally:
|
|
299
|
+
# make sure to restore warning verbosity
|
|
300
|
+
wp.config.verbose_warnings = saved_verbosity
|
|
301
|
+
|
|
302
|
+
|
|
287
303
|
# Multiple similar deprecation warnings get printed out only once.
|
|
288
304
|
with contextlib.redirect_stdout(io.StringIO()) as f:
|
|
289
|
-
frame_info = inspect.getframeinfo(inspect.currentframe())
|
|
290
305
|
wp.utils.warn("hello, world!", category=DeprecationWarning)
|
|
291
306
|
wp.utils.warn("hello, world!", category=DeprecationWarning)
|
|
292
307
|
|
|
293
308
|
expected = (
|
|
294
|
-
"
|
|
295
|
-
).format(
|
|
296
|
-
frame_info.filename,
|
|
297
|
-
frame_info.lineno + 1,
|
|
298
|
-
"DeprecationWarning: hello, world!\n wp.utils.warn(\"hello, world!\", category=DeprecationWarning)",
|
|
309
|
+
"Warp DeprecationWarning: hello, world!\n"
|
|
299
310
|
)
|
|
311
|
+
|
|
300
312
|
self.assertEqual(f.getvalue(), expected)
|
|
301
313
|
|
|
302
314
|
# Multiple different deprecation warnings get printed out each time.
|
|
303
315
|
with contextlib.redirect_stdout(io.StringIO()) as f:
|
|
304
|
-
frame_info = inspect.getframeinfo(inspect.currentframe())
|
|
305
316
|
wp.utils.warn("foo", category=DeprecationWarning)
|
|
306
317
|
wp.utils.warn("bar", category=DeprecationWarning)
|
|
307
318
|
|
|
308
319
|
expected = (
|
|
309
|
-
"
|
|
310
|
-
"
|
|
311
|
-
).format(
|
|
312
|
-
frame_info.filename,
|
|
313
|
-
frame_info.lineno + 1,
|
|
314
|
-
"DeprecationWarning: foo\n wp.utils.warn(\"foo\", category=DeprecationWarning)",
|
|
315
|
-
frame_info.filename,
|
|
316
|
-
frame_info.lineno + 2,
|
|
317
|
-
"DeprecationWarning: bar\n wp.utils.warn(\"bar\", category=DeprecationWarning)",
|
|
320
|
+
"Warp DeprecationWarning: foo\n"
|
|
321
|
+
"Warp DeprecationWarning: bar\n"
|
|
318
322
|
)
|
|
323
|
+
|
|
319
324
|
self.assertEqual(f.getvalue(), expected)
|
|
320
325
|
|
|
321
326
|
def test_transform_expand(self):
|