warp-lang 0.15.0-py3-none-manylinux2014_x86_64.whl → 1.0.0-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (80)
  1. warp/__init__.py +1 -0
  2. warp/codegen.py +7 -3
  3. warp/config.py +2 -1
  4. warp/constants.py +3 -0
  5. warp/context.py +44 -21
  6. warp/examples/assets/bunny.usd +0 -0
  7. warp/examples/assets/cartpole.urdf +110 -0
  8. warp/examples/assets/crazyflie.usd +0 -0
  9. warp/examples/assets/cube.usda +42 -0
  10. warp/examples/assets/nv_ant.xml +92 -0
  11. warp/examples/assets/nv_humanoid.xml +183 -0
  12. warp/examples/assets/quadruped.urdf +268 -0
  13. warp/examples/assets/rocks.nvdb +0 -0
  14. warp/examples/assets/rocks.usd +0 -0
  15. warp/examples/assets/sphere.usda +56 -0
  16. warp/examples/assets/torus.usda +105 -0
  17. warp/examples/core/example_dem.py +6 -6
  18. warp/examples/core/example_fluid.py +3 -3
  19. warp/examples/core/example_graph_capture.py +3 -6
  20. warp/examples/optim/example_bounce.py +9 -8
  21. warp/examples/optim/example_cloth_throw.py +12 -8
  22. warp/examples/optim/example_diffray.py +10 -12
  23. warp/examples/optim/example_drone.py +31 -14
  24. warp/examples/optim/example_spring_cage.py +10 -15
  25. warp/examples/optim/example_trajectory.py +7 -24
  26. warp/examples/sim/example_cartpole.py +3 -9
  27. warp/examples/sim/example_cloth.py +10 -10
  28. warp/examples/sim/example_granular.py +3 -3
  29. warp/examples/sim/example_granular_collision_sdf.py +9 -4
  30. warp/examples/sim/example_jacobian_ik.py +0 -10
  31. warp/examples/sim/example_particle_chain.py +4 -4
  32. warp/examples/sim/example_quadruped.py +15 -11
  33. warp/examples/sim/example_rigid_chain.py +13 -8
  34. warp/examples/sim/example_rigid_contact.py +4 -4
  35. warp/examples/sim/example_rigid_force.py +7 -7
  36. warp/examples/sim/example_rigid_soft_contact.py +4 -4
  37. warp/examples/sim/example_soft_body.py +3 -3
  38. warp/jax.py +45 -0
  39. warp/jax_experimental.py +339 -0
  40. warp/render/render_opengl.py +188 -95
  41. warp/render/render_usd.py +34 -10
  42. warp/sim/__init__.py +13 -4
  43. warp/sim/articulation.py +4 -5
  44. warp/sim/collide.py +320 -175
  45. warp/sim/import_mjcf.py +25 -30
  46. warp/sim/import_urdf.py +94 -63
  47. warp/sim/import_usd.py +51 -36
  48. warp/sim/inertia.py +3 -2
  49. warp/sim/integrator.py +233 -0
  50. warp/sim/integrator_euler.py +447 -469
  51. warp/sim/integrator_featherstone.py +1991 -0
  52. warp/sim/integrator_xpbd.py +1420 -640
  53. warp/sim/model.py +741 -487
  54. warp/sim/particles.py +2 -1
  55. warp/sim/render.py +18 -2
  56. warp/sim/utils.py +222 -11
  57. warp/stubs.py +1 -0
  58. warp/tape.py +6 -9
  59. warp/tests/test_examples.py +87 -20
  60. warp/tests/test_grad_customs.py +122 -0
  61. warp/tests/test_jax.py +254 -0
  62. warp/tests/test_options.py +13 -53
  63. warp/tests/test_quat.py +23 -0
  64. warp/tests/test_snippet.py +2 -0
  65. warp/tests/test_utils.py +31 -26
  66. warp/tests/test_verify_fp.py +65 -0
  67. warp/tests/unittest_suites.py +4 -0
  68. warp/utils.py +50 -1
  69. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/METADATA +1 -1
  70. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +73 -64
  71. warp/examples/env/__init__.py +0 -0
  72. warp/examples/env/env_ant.py +0 -61
  73. warp/examples/env/env_cartpole.py +0 -63
  74. warp/examples/env/env_humanoid.py +0 -65
  75. warp/examples/env/env_usd.py +0 -97
  76. warp/examples/env/environment.py +0 -526
  77. warp/sim/optimizer.py +0 -138
  78. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/LICENSE.md +0 -0
  79. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  80. {warp_lang-0.15.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
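
Note on the new JAX interop modules (warp/jax.py, warp/jax_experimental.py): the test file warp/tests/test_jax.py added below drives an experimental jax_kernel wrapper that exposes a Warp kernel to jitted JAX code. A minimal sketch of that usage, mirroring the test; the kernel and the array size are illustrative, and a CUDA device visible to both Warp and JAX is assumed:

import jax
import jax.numpy as jp
import numpy as np

import warp as wp
from warp.jax_experimental import jax_kernel

wp.init()


@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]


# wrap the Warp kernel so it can be called from JAX-transformed code
jax_triple = jax_kernel(triple_kernel)


@jax.jit
def f():
    x = jp.arange(64, dtype=jp.float32)
    return jax_triple(x)


print(np.asarray(f()))  # 0, 3, 6, ... as produced by the kernel
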
warp/tests/test_grad_customs.py CHANGED
@@ -22,6 +22,7 @@ wp.init()
 def reversible_increment(
     counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
 ):
+    """This is a docstring"""
     next_index = wp.atomic_add(counter, counter_index, value)
     thread_values[tid] = next_index
     return next_index
@@ -31,6 +32,7 @@ def reversible_increment(
 def replay_reversible_increment(
     counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
 ):
+    """This is a docstring"""
     return thread_values[tid]


@@ -64,28 +66,33 @@ def test_custom_replay_grad(test, device):

 @wp.func
 def overload_fn(x: float, y: float):
+    """This is a docstring"""
     return x * 3.0 + y / 3.0, y**2.5


 @wp.func_grad(overload_fn)
 def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
+    """This is a docstring"""
     wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
     wp.adjoint[y] += y * adj_ret1 * 3.0


 @wp.struct
 class MyStruct:
+    """This is a docstring"""
     scalar: float
     vec: wp.vec3


 @wp.func
 def overload_fn(x: MyStruct):
+    """This is a docstring"""
     return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5


 @wp.func_grad(overload_fn)
 def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
+    """This is a docstring"""
     wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
     wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
     wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
@@ -96,6 +103,7 @@ def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: fl
 def run_overload_float_fn(
     xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
 ):
+    """This is a docstring"""
     i = wp.tid()
     out0, out1 = overload_fn(xs[i], ys[i])
     output0[i] = out0
@@ -197,6 +205,118 @@ def test_custom_import_grad(test, device):
     assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)


+@wp.func
+def sigmoid(x: float):
+    return 1.0 / (1.0 + wp.exp(-x))
+
+
+@wp.func_grad(sigmoid)
+def adj_sigmoid(x: float, adj: float):
+    # unused function to test that we don't run into infinite recursion when calling
+    # the forward function from within the gradient function
+    wp.adjoint[x] += adj * sigmoid(x) * (1.0 - sigmoid(x))
+
+
+@wp.func
+def sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    # test function that does not return anything
+    ys[i] = sigmoid(xs[i])
+
+
+@wp.func_grad(sigmoid_no_return)
+def adj_sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    wp.adjoint[xs][i] += ys[i] * (1.0 - ys[i])
+
+
+@wp.kernel
+def eval_sigmoid(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    i = wp.tid()
+    sigmoid_no_return(i, xs, ys)
+
+
+def test_custom_grad_no_return(test, device):
+    xs = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True)
+    ys = wp.zeros_like(xs)
+    ys.grad.fill_(1.0)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(eval_sigmoid, dim=len(xs), inputs=[xs], outputs=[ys])
+    tape.backward()
+
+    sigmoids = ys.numpy()
+    grad = xs.grad.numpy()
+    assert_np_equal(grad, sigmoids * (1.0 - sigmoids))
+
+
+def test_wrapped_docstring(test, device):
+    assert "This is a docstring" in reversible_increment.__doc__
+    assert "This is a docstring" in replay_reversible_increment.__doc__
+    assert "This is a docstring" in overload_fn.__doc__
+    assert "This is a docstring" in overload_fn_grad.__doc__
+    assert "This is a docstring" in run_overload_float_fn.__doc__
+    assert "This is a docstring" in MyStruct.__doc__
+
+
+@wp.func
+def dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # this function doesn't get called but it is an important test for code generation
+    # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
+    for i in range(m):
+        for j in range(n):
+            sum = float(0.0)
+            for k in range(p):
+                if transpose_A:
+                    a_i = k * m + i
+                else:
+                    a_i = i * p + k
+                if transpose_B:
+                    b_j = j * p + k
+                else:
+                    b_j = k * n + j
+                sum += A[a_i] * B[b_j]
+
+            if add_to_C:
+                C[i * n + j] += sum
+            else:
+                C[i * n + j] = sum
+
+
+@wp.func_grad(dense_gemm)
+def adj_dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # code generation would break here if we didn't defer building the custom grad
+    # function until after the forward functions + kernels of the module have been built
+    add_to_C = True
+    if transpose_A:
+        dense_gemm(p, m, n, False, True, add_to_C, B, wp.adjoint[C], wp.adjoint[A])
+        dense_gemm(p, n, m, False, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+    else:
+        dense_gemm(m, p, n, False, not transpose_B, add_to_C, wp.adjoint[C], B, wp.adjoint[A])
+        dense_gemm(p, n, m, True, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+
+
 devices = get_test_devices()


@@ -207,6 +327,8 @@ class TestGradCustoms(unittest.TestCase):
 add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_import_grad", test_custom_import_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_grad_no_return", test_custom_grad_no_return, devices=devices)
+add_function_test(TestGradCustoms, "test_wrapped_docstring", test_wrapped_docstring, devices=devices)


 if __name__ == "__main__":
warp/tests/test_jax.py ADDED
@@ -0,0 +1,254 @@
+# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import os
+import unittest
+from typing import Any
+
+import warp as wp
+from warp.tests.unittest_utils import *
+
+wp.init()
+
+
+# basic kernel with one input and output
+@wp.kernel
+def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
+    tid = wp.tid()
+    output[tid] = 3.0 * input[tid]
+
+
+# generic kernel with one scalar input and output
+@wp.kernel
+def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
+    tid = wp.tid()
+    output[tid] = input.dtype(3) * input[tid]
+
+
+# generic kernel with one vector/matrix input and output
+@wp.kernel
+def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
+    tid = wp.tid()
+    output[tid] = input.dtype.dtype(3) * input[tid]
+
+
+# kernel with multiple inputs and outputs
+@wp.kernel
+def multiarg_kernel(
+    # inputs
+    a: wp.array(dtype=float),
+    b: wp.array(dtype=float),
+    c: wp.array(dtype=float),
+    # outputs
+    ab: wp.array(dtype=float),
+    bc: wp.array(dtype=float),
+):
+    tid = wp.tid()
+    ab[tid] = a[tid] + b[tid]
+    bc[tid] = b[tid] + c[tid]
+
+
+# various types for testing
+scalar_types = wp.types.scalar_types
+vector_types = []
+matrix_types = []
+for dim in [2, 3, 4]:
+    for T in scalar_types:
+        vector_types.append(wp.vec(dim, T))
+        matrix_types.append(wp.mat((dim, dim), T))
+
+# explicitly overload generic kernels to avoid module reloading during tests
+for T in scalar_types:
+    wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
+for T in [*vector_types, *matrix_types]:
+    wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
+
+
+def _jax_version():
+    try:
+        import jax
+        return jax.__version_info__
+    except ImportError:
+        return (0, 0, 0)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_basic(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    jax_triple = jax_kernel(triple_kernel)
+
+    @jax.jit
+    def f():
+        x = jp.arange(n, dtype=jp.float32)
+        return jax_triple(x)
+
+    # run on the given device
+    with jax.default_device(wp.device_to_jax(device)):
+        y = f()
+
+    result = np.asarray(y)
+    expected = 3 * np.arange(n, dtype=np.float32)
+
+    assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_scalar(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    for T in scalar_types:
+
+        jp_dtype = wp.jax.dtype_to_jax(T)
+        np_dtype = wp.types.warp_type_to_np_dtype[T]
+
+        with test.subTest(msg=T.__name__):
+
+            # get the concrete overload
+            kernel_instance = triple_kernel_scalar.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
+
+            jax_triple = jax_kernel(kernel_instance)
+
+            @jax.jit
+            def f():
+                x = jp.arange(n, dtype=jp_dtype)
+                return jax_triple(x)
+
+            # run on the given device
+            with jax.default_device(wp.device_to_jax(device)):
+                y = f()
+
+            result = np.asarray(y)
+            expected = 3 * np.arange(n, dtype=np_dtype)
+
+            assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_vecmat(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    for T in [*vector_types, *matrix_types]:
+
+        jp_dtype = wp.jax.dtype_to_jax(T._wp_scalar_type_)
+        np_dtype = wp.types.warp_type_to_np_dtype[T._wp_scalar_type_]
+
+        n = 64 // T._length_
+        scalar_shape = (n, *T._shape_)
+        scalar_len = n * T._length_
+
+        with test.subTest(msg=T.__name__):
+
+            # get the concrete overload
+            kernel_instance = triple_kernel_vecmat.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
+
+            jax_triple = jax_kernel(kernel_instance)
+
+            @jax.jit
+            def f():
+                x = jp.arange(scalar_len, dtype=jp_dtype).reshape(scalar_shape)
+                return jax_triple(x)
+
+            # run on the given device
+            with jax.default_device(wp.device_to_jax(device)):
+                y = f()
+
+            result = np.asarray(y)
+            expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
+
+            assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_multiarg(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    jax_multiarg = jax_kernel(multiarg_kernel)
+
+    @jax.jit
+    def f():
+        a = jp.full(n, 1, dtype=jp.float32)
+        b = jp.full(n, 2, dtype=jp.float32)
+        c = jp.full(n, 3, dtype=jp.float32)
+        return jax_multiarg(a, b, c)
+
+    # run on the given device
+    with jax.default_device(wp.device_to_jax(device)):
+        x, y = f()
+
+    result_x, result_y = np.asarray(x), np.asarray(y)
+    expected_x = np.full(n, 3, dtype=np.float32)
+    expected_y = np.full(n, 5, dtype=np.float32)
+
+    assert_np_equal(result_x, expected_x)
+    assert_np_equal(result_y, expected_y)
+
+
+class TestJax(unittest.TestCase):
+    pass
+
+
+# try adding Jax tests if Jax is installed correctly
+try:
+    # prevent Jax from gobbling up GPU memory
+    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+    import jax
+    import jax.dlpack
+
+    # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
+    jax.config.update("jax_enable_x64", True)
+
+    # check which Warp devices work with Jax
+    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
+    test_devices = get_test_devices()
+    jax_compatible_devices = []
+    jax_compatible_cuda_devices = []
+    for d in test_devices:
+        try:
+            with jax.default_device(wp.device_to_jax(d)):
+                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
+                j += 1
+            jax_compatible_devices.append(d)
+            if d.is_cuda:
+                jax_compatible_cuda_devices.append(d)
+        except Exception as e:
+            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
+
+    if jax_compatible_cuda_devices:
+        add_function_test(
+            TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
+        )

+except Exception as e:
+    print(f"Skipping Jax tests due to exception: {e}")
+
+
+if __name__ == "__main__":
+    wp.build.clear_kernel_cache()
+    unittest.main(verbosity=2)
warp/tests/test_options.py CHANGED
@@ -6,6 +6,8 @@
 # license agreement from NVIDIA CORPORATION is strictly prohibited.

 import unittest
+import contextlib
+import io

 import warp as wp
 from warp.tests.unittest_utils import *
@@ -49,7 +51,12 @@ def test_options_1(test, device):
     with tape:
         wp.launch(scale, dim=1, inputs=[x, y], device=device)

-    tape.backward(y)
+    with contextlib.redirect_stdout(io.StringIO()) as f:
+        tape.backward(y)
+
+    expected = f"Warp UserWarning: Running the tape backwards may produce incorrect gradients because recorded kernel {scale.key} is defined in a module with the option 'enable_backward=False' set.\n"
+
+    assert f.getvalue() == expected
     assert_np_equal(tape.gradients[x].numpy(), np.array(0.0))


@@ -91,58 +98,13 @@ def test_options_4(test, device):
     with tape:
         wp.launch(scale_2, dim=1, inputs=[x, y], device=device)

-    tape.backward(y)
-    assert_np_equal(tape.gradients[x].numpy(), np.array(0.0))
-
-
-@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_options_5(test, device):
-    wp.set_module_options({"enable_backward": True})
-
-    @wp.kernel
-    def loss_kernel(y: wp.array(dtype=float), loss: wp.array(dtype=float)):
-        tid = wp.tid()
-        wp.atomic_add(loss, 0, y[tid])
-
-    A = wp.array(np.ones((2, 2), dtype=float), dtype=float, requires_grad=True, device=device)
-    x = wp.array([[1.0], [2.0]], dtype=float, requires_grad=True, device=device)
-    b = wp.zeros_like(x)
-    y = wp.zeros_like(x)
-    loss = wp.zeros(1, requires_grad=True, device=device)
-
-    tape = wp.Tape()
-
-    with tape:
-        wp.matmul(A, x, b, y)
-        wp.launch(loss_kernel, dim=2, inputs=[y.flatten(), loss], device=device)
-
-    tape.backward(loss)
-    assert_np_equal(x.grad.numpy(), np.array([[2.0], [2.0]]))
+    with contextlib.redirect_stdout(io.StringIO()) as f:
+        tape.backward(y)

+    expected = f"Warp UserWarning: Running the tape backwards may produce incorrect gradients because recorded kernel {scale_2.key} is configured with the option 'enable_backward=False'.\n"

-@unittest.skipUnless(runtime.core.is_cutlass_enabled(), "Warp was not built with CUTLASS support")
-def test_options_6(test, device):
-    wp.set_module_options({"enable_backward": False})
-
-    @wp.kernel
-    def loss_kernel(y: wp.array(dtype=float), loss: wp.array(dtype=float)):
-        tid = wp.tid()
-        wp.atomic_add(loss, 0, y[tid])
-
-    A = wp.array(np.ones((2, 2), dtype=float), dtype=float, requires_grad=True, device=device)
-    x = wp.array([[1.0], [2.0]], dtype=float, requires_grad=True, device=device)
-    b = wp.zeros_like(x)
-    y = wp.zeros_like(x)
-    loss = wp.zeros(1, requires_grad=True, device=device)
-
-    tape = wp.Tape()
-
-    with tape:
-        wp.matmul(A, x, b, y)
-        wp.launch(loss_kernel, dim=2, inputs=[y.flatten(), loss], device=device)
-
-    tape.backward(loss)
-    assert_np_equal(x.grad.numpy(), np.array([[0.0], [0.0]]))
+    assert f.getvalue() == expected
+    assert_np_equal(tape.gradients[x].numpy(), np.array(0.0))


 devices = get_test_devices()
@@ -156,8 +118,6 @@ add_function_test(TestOptions, "test_options_1", test_options_1, devices=devices
 add_function_test(TestOptions, "test_options_2", test_options_2, devices=devices)
 add_function_test(TestOptions, "test_options_3", test_options_3, devices=devices)
 add_function_test(TestOptions, "test_options_4", test_options_4, devices=devices)
-add_function_test(TestOptions, "test_options_5", test_options_5, devices=devices)
-add_function_test(TestOptions, "test_options_6", test_options_6, devices=devices)


 if __name__ == "__main__":
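
Aside from dropping the two CUTLASS-dependent cases (test_options_5/test_options_6, which exercised wp.matmul), the reworked tests above pin down the exact warning printed when a tape is replayed backwards over a kernel whose adjoint was never compiled. A minimal sketch of the behaviour being asserted; the kernel body is illustrative, only the module option and the warning text come from the diff above:

import numpy as np
import warp as wp

wp.init()

# compile this module without adjoint (backward) code
wp.set_module_options({"enable_backward": False})


@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    y[0] = x[0] ** 2.0


x = wp.array([3.0], dtype=float, requires_grad=True)
y = wp.zeros_like(x)

tape = wp.Tape()
with tape:
    wp.launch(scale, dim=1, inputs=[x, y])

# Warp 1.0 prints "Warp UserWarning: Running the tape backwards may produce
# incorrect gradients ..." here, and the recorded gradient stays zero.
tape.backward(y)
print(tape.gradients[x].numpy())  # [0.]
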
warp/tests/test_quat.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np

 import warp as wp
 from warp.tests.unittest_utils import *
+import warp.sim

 wp.init()

@@ -1871,6 +1872,21 @@ def test_quat_identity(test, device, dtype, register_kernels=False):
     assert_np_equal(output.numpy(), expected)


+############################################################
+
+
+def test_quat_euler_conversion(test, device, dtype, register_kernels=False):
+    rng = np.random.default_rng(123)
+    N = 3
+
+    rpy_arr = rng.uniform(low=-np.pi, high=np.pi, size=(N, 3))
+
+    quats_from_euler = [list(wp.sim.quat_from_euler(wp.vec3(*rpy), 0, 1, 2)) for rpy in rpy_arr]
+    quats_from_rpy = [list(wp.quat_rpy(rpy[0], rpy[1], rpy[2])) for rpy in rpy_arr]
+
+    assert_np_equal(np.array(quats_from_euler), np.array(quats_from_rpy), tol=1e-4)
+
+
 def test_anon_type_instance(test, device, dtype, register_kernels=False):
     rng = np.random.default_rng(123)
     wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
@@ -2053,6 +2069,13 @@ for dtype in np_float_types:
     add_function_test_register_kernel(
         TestQuat, f"test_quat_to_matrix_{dtype.__name__}", test_quat_to_matrix, devices=devices, dtype=dtype
     )
+    add_function_test_register_kernel(
+        TestQuat,
+        f"test_quat_euler_conversion_{dtype.__name__}",
+        test_quat_euler_conversion,
+        devices=devices,
+        dtype=dtype,
+    )
     add_function_test(
         TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
     )
warp/tests/test_snippet.py CHANGED
@@ -86,6 +86,7 @@ def test_shared_memory(test, device):

     @wp.func_native(snippet)
     def reverse(d: wp.array(dtype=int), N: int, tid: int):
+        """Reverse the array d in place using shared memory."""
         return

     @wp.kernel
@@ -100,6 +101,7 @@ def test_shared_memory(test, device):
     wp.launch(kernel=reverse_kernel, dim=N, inputs=[x, N], device=device)

     assert_np_equal(x.numpy(), y)
+    assert reverse.__doc__ == "Reverse the array d in place using shared memory."


 def test_cpu_snippet(test, device):
warp/tests/test_utils.py CHANGED
@@ -267,55 +267,60 @@ class TestUtils(unittest.TestCase):
     def test_warn(self):
         # Multiple warnings get printed out each time.
         with contextlib.redirect_stdout(io.StringIO()) as f:
-            frame_info = inspect.getframeinfo(inspect.currentframe())
             wp.utils.warn("hello, world!")
             wp.utils.warn("hello, world!")

         expected = (
-            "{}:{}: {}\n"
-            "{}:{}: {}\n"
-        ).format(
-            frame_info.filename,
-            frame_info.lineno + 1,
-            "UserWarning: hello, world!\n wp.utils.warn(\"hello, world!\")",
-            frame_info.filename,
-            frame_info.lineno + 2,
-            "UserWarning: hello, world!\n wp.utils.warn(\"hello, world!\")",
+            "Warp UserWarning: hello, world!\n"
+            "Warp UserWarning: hello, world!\n"
         )
+
         self.assertEqual(f.getvalue(), expected)

+        # Test verbose warnings
+        saved_verbosity = wp.config.verbose_warnings
+        try:
+            wp.config.verbose_warnings = True
+            with contextlib.redirect_stdout(io.StringIO()) as f:
+                frame_info = inspect.getframeinfo(inspect.currentframe())
+                wp.utils.warn("hello, world!")
+                wp.utils.warn("hello, world!")
+
+            expected = (
+                f"Warp UserWarning: hello, world! ({frame_info.filename}:{frame_info.lineno + 1})\n"
+                " wp.utils.warn(\"hello, world!\")\n"
+                f"Warp UserWarning: hello, world! ({frame_info.filename}:{frame_info.lineno + 2})\n"
+                " wp.utils.warn(\"hello, world!\")\n"
+            )
+
+            self.assertEqual(f.getvalue(), expected)
+
+        finally:
+            # make sure to restore warning verbosity
+            wp.config.verbose_warnings = saved_verbosity
+
+
         # Multiple similar deprecation warnings get printed out only once.
         with contextlib.redirect_stdout(io.StringIO()) as f:
-            frame_info = inspect.getframeinfo(inspect.currentframe())
             wp.utils.warn("hello, world!", category=DeprecationWarning)
             wp.utils.warn("hello, world!", category=DeprecationWarning)

         expected = (
-            "{}:{}: {}\n"
-        ).format(
-            frame_info.filename,
-            frame_info.lineno + 1,
-            "DeprecationWarning: hello, world!\n wp.utils.warn(\"hello, world!\", category=DeprecationWarning)",
+            "Warp DeprecationWarning: hello, world!\n"
         )
+
         self.assertEqual(f.getvalue(), expected)

         # Multiple different deprecation warnings get printed out each time.
         with contextlib.redirect_stdout(io.StringIO()) as f:
-            frame_info = inspect.getframeinfo(inspect.currentframe())
             wp.utils.warn("foo", category=DeprecationWarning)
             wp.utils.warn("bar", category=DeprecationWarning)

         expected = (
-            "{}:{}: {}\n"
-            "{}:{}: {}\n"
-        ).format(
-            frame_info.filename,
-            frame_info.lineno + 1,
-            "DeprecationWarning: foo\n wp.utils.warn(\"foo\", category=DeprecationWarning)",
-            frame_info.filename,
-            frame_info.lineno + 2,
-            "DeprecationWarning: bar\n wp.utils.warn(\"bar\", category=DeprecationWarning)",
+            "Warp DeprecationWarning: foo\n"
+            "Warp DeprecationWarning: bar\n"
        )
+
         self.assertEqual(f.getvalue(), expected)

     def test_transform_expand(self):
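
For reference, the warning format asserted above can be reproduced directly; a short sketch of the default versus verbose output, with the wording taken from the updated test expectations:

import warp as wp

wp.init()

# default: compact single-line warnings
wp.utils.warn("hello, world!")
# prints: Warp UserWarning: hello, world!

# verbose: append file/line information and echo the calling source line
wp.config.verbose_warnings = True
wp.utils.warn("hello, world!")
# prints: Warp UserWarning: hello, world! (<file>:<line>)
#          wp.utils.warn("hello, world!")
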