warp-lang 0.11.0-py3-none-manylinux2014_x86_64.whl → 1.0.0-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/tests/test_grad_customs.py CHANGED
@@ -22,6 +22,7 @@ wp.init()
 def reversible_increment(
     counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
 ):
+    """This is a docstring"""
     next_index = wp.atomic_add(counter, counter_index, value)
     thread_values[tid] = next_index
     return next_index
@@ -31,6 +32,7 @@ def reversible_increment(
 def replay_reversible_increment(
     counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
 ):
+    """This is a docstring"""
     return thread_values[tid]
 
 
@@ -58,34 +60,39 @@ def test_custom_replay_grad(test, device):
             run_atomic_add, dim=num_threads, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device
         )
 
-    tape.backward(grads={outputs: wp.array(np.ones(num_threads, dtype=np.float32), device=device)})
+    tape.backward(grads={outputs: wp.ones(num_threads, dtype=wp.float32, device=device)})
     assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)
 
 
 @wp.func
 def overload_fn(x: float, y: float):
+    """This is a docstring"""
     return x * 3.0 + y / 3.0, y**2.5
 
 
 @wp.func_grad(overload_fn)
 def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
+    """This is a docstring"""
     wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
     wp.adjoint[y] += y * adj_ret1 * 3.0
 
 
 @wp.struct
 class MyStruct:
+    """This is a docstring"""
     scalar: float
     vec: wp.vec3
 
 
 @wp.func
 def overload_fn(x: MyStruct):
+    """This is a docstring"""
     return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5
 
 
 @wp.func_grad(overload_fn)
 def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
+    """This is a docstring"""
     wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
     wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
     wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
@@ -96,6 +103,7 @@ def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: fl
 def run_overload_float_fn(
     xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
 ):
+    """This is a docstring"""
     i = wp.tid()
     out0, out1 = overload_fn(xs[i], ys[i])
     output0[i] = out0
@@ -111,17 +119,19 @@ def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=
 
 def test_custom_overload_grad(test, device):
     dim = 3
-    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
-    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
-    out0_float = wp.zeros(dim)
-    out1_float = wp.zeros(dim)
+    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True, device=device)
+    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True, device=device)
+    out0_float = wp.zeros(dim, device=device)
+    out1_float = wp.zeros(dim, device=device)
     tape = wp.Tape()
     with tape:
-        wp.launch(run_overload_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float])
+        wp.launch(
+            run_overload_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float], device=device
+        )
     tape.backward(
         grads={
-            out0_float: wp.array(np.ones(dim), dtype=wp.float32),
-            out1_float: wp.array(np.ones(dim), dtype=wp.float32),
+            out0_float: wp.ones(dim, dtype=wp.float32, device=device),
+            out1_float: wp.ones(dim, dtype=wp.float32, device=device),
         }
     )
     assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
@@ -136,12 +146,12 @@ def test_custom_overload_grad(test, device):
     x2 = MyStruct()
     x2.vec = wp.vec3(8.0, 9.0, 10.0)
     x2.scalar = 19.0
-    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True)
-    out_struct = wp.zeros(dim)
+    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True, device=device)
+    out_struct = wp.zeros(dim, device=device)
     tape = wp.Tape()
     with tape:
-        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct])
-    tape.backward(grads={out_struct: wp.array(np.ones(dim), dtype=wp.float32)})
+        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct], device=device)
+    tape.backward(grads={out_struct: wp.ones(dim, dtype=wp.float32, device=device)})
     xs_struct_np = xs_struct.numpy()
     struct_grads = xs_struct.grad.numpy()
     # fmt: off
@@ -160,6 +170,153 @@ def test_custom_overload_grad(test, device):
     # fmt: on
 
 
+def test_custom_import_grad(test, device):
+    from warp.tests.aux_test_grad_customs import aux_custom_fn
+
+    @wp.kernel
+    def run_defined_float_fn(
+        xs: wp.array(dtype=float),
+        ys: wp.array(dtype=float),
+        output0: wp.array(dtype=float),
+        output1: wp.array(dtype=float),
+    ):
+        i = wp.tid()
+        out0, out1 = aux_custom_fn(xs[i], ys[i])
+        output0[i] = out0
+        output1[i] = out1
+
+    dim = 3
+    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True, device=device)
+    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True, device=device)
+    out0_float = wp.zeros(dim, device=device)
+    out1_float = wp.zeros(dim, device=device)
+    tape = wp.Tape()
+    with tape:
+        wp.launch(
+            run_defined_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float], device=device
+        )
+    tape.backward(
+        grads={
+            out0_float: wp.ones(dim, dtype=wp.float32, device=device),
+            out1_float: wp.ones(dim, dtype=wp.float32, device=device),
+        }
+    )
+    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
+    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
+
+
+@wp.func
+def sigmoid(x: float):
+    return 1.0 / (1.0 + wp.exp(-x))
+
+
+@wp.func_grad(sigmoid)
+def adj_sigmoid(x: float, adj: float):
+    # unused function to test that we don't run into infinite recursion when calling
+    # the forward function from within the gradient function
+    wp.adjoint[x] += adj * sigmoid(x) * (1.0 - sigmoid(x))
+
+
+@wp.func
+def sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    # test function that does not return anything
+    ys[i] = sigmoid(xs[i])
+
+
+@wp.func_grad(sigmoid_no_return)
+def adj_sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    wp.adjoint[xs][i] += ys[i] * (1.0 - ys[i])
+
+
+@wp.kernel
+def eval_sigmoid(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    i = wp.tid()
+    sigmoid_no_return(i, xs, ys)
+
+
+def test_custom_grad_no_return(test, device):
+    xs = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True)
+    ys = wp.zeros_like(xs)
+    ys.grad.fill_(1.0)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(eval_sigmoid, dim=len(xs), inputs=[xs], outputs=[ys])
+    tape.backward()
+
+    sigmoids = ys.numpy()
+    grad = xs.grad.numpy()
+    assert_np_equal(grad, sigmoids * (1.0 - sigmoids))
+
+
+def test_wrapped_docstring(test, device):
+    assert "This is a docstring" in reversible_increment.__doc__
+    assert "This is a docstring" in replay_reversible_increment.__doc__
+    assert "This is a docstring" in overload_fn.__doc__
+    assert "This is a docstring" in overload_fn_grad.__doc__
+    assert "This is a docstring" in run_overload_float_fn.__doc__
+    assert "This is a docstring" in MyStruct.__doc__
+
+
+@wp.func
+def dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # this function doesn't get called but it is an important test for code generation
+    # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
+    for i in range(m):
+        for j in range(n):
+            sum = float(0.0)
+            for k in range(p):
+                if transpose_A:
+                    a_i = k * m + i
+                else:
+                    a_i = i * p + k
+                if transpose_B:
+                    b_j = j * p + k
+                else:
+                    b_j = k * n + j
+                sum += A[a_i] * B[b_j]
+
+            if add_to_C:
+                C[i * n + j] += sum
+            else:
+                C[i * n + j] = sum
+
+
+@wp.func_grad(dense_gemm)
+def adj_dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # code generation would break here if we didn't defer building the custom grad
+    # function until after the forward functions + kernels of the module have been built
+    add_to_C = True
+    if transpose_A:
+        dense_gemm(p, m, n, False, True, add_to_C, B, wp.adjoint[C], wp.adjoint[A])
+        dense_gemm(p, n, m, False, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+    else:
+        dense_gemm(m, p, n, False, not transpose_B, add_to_C, wp.adjoint[C], B, wp.adjoint[A])
+        dense_gemm(p, n, m, True, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+
+
 devices = get_test_devices()
 
 
@@ -169,6 +326,9 @@ class TestGradCustoms(unittest.TestCase):
 
 add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_import_grad", test_custom_import_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_grad_no_return", test_custom_grad_no_return, devices=devices)
+add_function_test(TestGradCustoms, "test_wrapped_docstring", test_wrapped_docstring, devices=devices)
 
 
 if __name__ == "__main__":
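Note: the pattern exercised throughout this test file is Warp's custom-gradient API: `@wp.func_grad(f)` registers a hand-written adjoint that replaces the code-generated one, and `wp.adjoint[arg]` accumulates a gradient into an argument. The following minimal sketch is modeled on the tests above and is illustrative only; the `square` function and `eval_square` kernel are hypothetical names, not part of this diff.

    import numpy as np
    import warp as wp

    wp.init()


    @wp.func
    def square(x: float):
        return x * x


    # custom adjoint: Warp uses this in place of the auto-generated gradient code
    @wp.func_grad(square)
    def adj_square(x: float, adj_ret: float):
        wp.adjoint[x] += 2.0 * x * adj_ret


    @wp.kernel
    def eval_square(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
        i = wp.tid()
        ys[i] = square(xs[i])


    xs = wp.array(np.arange(1.0, 4.0), dtype=wp.float32, requires_grad=True)
    ys = wp.zeros(3, dtype=wp.float32)

    tape = wp.Tape()
    with tape:
        wp.launch(eval_square, dim=3, inputs=[xs], outputs=[ys])
    tape.backward(grads={ys: wp.ones(3, dtype=wp.float32)})

    print(xs.grad.numpy())  # expected: [2. 4. 6.]
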
warp/tests/test_jax.py ADDED
@@ -0,0 +1,254 @@
+# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import os
+import unittest
+from typing import Any
+
+import warp as wp
+from warp.tests.unittest_utils import *
+
+wp.init()
+
+
+# basic kernel with one input and output
+@wp.kernel
+def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
+    tid = wp.tid()
+    output[tid] = 3.0 * input[tid]
+
+
+# generic kernel with one scalar input and output
+@wp.kernel
+def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
+    tid = wp.tid()
+    output[tid] = input.dtype(3) * input[tid]
+
+
+# generic kernel with one vector/matrix input and output
+@wp.kernel
+def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
+    tid = wp.tid()
+    output[tid] = input.dtype.dtype(3) * input[tid]
+
+
+# kernel with multiple inputs and outputs
+@wp.kernel
+def multiarg_kernel(
+    # inputs
+    a: wp.array(dtype=float),
+    b: wp.array(dtype=float),
+    c: wp.array(dtype=float),
+    # outputs
+    ab: wp.array(dtype=float),
+    bc: wp.array(dtype=float),
+):
+    tid = wp.tid()
+    ab[tid] = a[tid] + b[tid]
+    bc[tid] = b[tid] + c[tid]
+
+
+# various types for testing
+scalar_types = wp.types.scalar_types
+vector_types = []
+matrix_types = []
+for dim in [2, 3, 4]:
+    for T in scalar_types:
+        vector_types.append(wp.vec(dim, T))
+        matrix_types.append(wp.mat((dim, dim), T))
+
+# explicitly overload generic kernels to avoid module reloading during tests
+for T in scalar_types:
+    wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
+for T in [*vector_types, *matrix_types]:
+    wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
+
+
+def _jax_version():
+    try:
+        import jax
+        return jax.__version_info__
+    except ImportError:
+        return (0, 0, 0)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_basic(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    jax_triple = jax_kernel(triple_kernel)
+
+    @jax.jit
+    def f():
+        x = jp.arange(n, dtype=jp.float32)
+        return jax_triple(x)
+
+    # run on the given device
+    with jax.default_device(wp.device_to_jax(device)):
+        y = f()
+
+    result = np.asarray(y)
+    expected = 3 * np.arange(n, dtype=np.float32)
+
+    assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_scalar(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    for T in scalar_types:
+
+        jp_dtype = wp.jax.dtype_to_jax(T)
+        np_dtype = wp.types.warp_type_to_np_dtype[T]
+
+        with test.subTest(msg=T.__name__):
+
+            # get the concrete overload
+            kernel_instance = triple_kernel_scalar.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
+
+            jax_triple = jax_kernel(kernel_instance)
+
+            @jax.jit
+            def f():
+                x = jp.arange(n, dtype=jp_dtype)
+                return jax_triple(x)
+
+            # run on the given device
+            with jax.default_device(wp.device_to_jax(device)):
+                y = f()
+
+            result = np.asarray(y)
+            expected = 3 * np.arange(n, dtype=np_dtype)
+
+            assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_vecmat(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    for T in [*vector_types, *matrix_types]:
+
+        jp_dtype = wp.jax.dtype_to_jax(T._wp_scalar_type_)
+        np_dtype = wp.types.warp_type_to_np_dtype[T._wp_scalar_type_]
+
+        n = 64 // T._length_
+        scalar_shape = (n, *T._shape_)
+        scalar_len = n * T._length_
+
+        with test.subTest(msg=T.__name__):
+
+            # get the concrete overload
+            kernel_instance = triple_kernel_vecmat.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
+
+            jax_triple = jax_kernel(kernel_instance)
+
+            @jax.jit
+            def f():
+                x = jp.arange(scalar_len, dtype=jp_dtype).reshape(scalar_shape)
+                return jax_triple(x)
+
+            # run on the given device
+            with jax.default_device(wp.device_to_jax(device)):
+                y = f()
+
+            result = np.asarray(y)
+            expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
+
+            assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_multiarg(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    jax_multiarg = jax_kernel(multiarg_kernel)
+
+    @jax.jit
+    def f():
+        a = jp.full(n, 1, dtype=jp.float32)
+        b = jp.full(n, 2, dtype=jp.float32)
+        c = jp.full(n, 3, dtype=jp.float32)
+        return jax_multiarg(a, b, c)
+
+    # run on the given device
+    with jax.default_device(wp.device_to_jax(device)):
+        x, y = f()
+
+    result_x, result_y = np.asarray(x), np.asarray(y)
+    expected_x = np.full(n, 3, dtype=np.float32)
+    expected_y = np.full(n, 5, dtype=np.float32)
+
+    assert_np_equal(result_x, expected_x)
+    assert_np_equal(result_y, expected_y)
+
+
+class TestJax(unittest.TestCase):
+    pass
+
+
+# try adding Jax tests if Jax is installed correctly
+try:
+    # prevent Jax from gobbling up GPU memory
+    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+    import jax
+    import jax.dlpack
+
+    # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
+    jax.config.update("jax_enable_x64", True)
+
+    # check which Warp devices work with Jax
+    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
+    test_devices = get_test_devices()
+    jax_compatible_devices = []
+    jax_compatible_cuda_devices = []
+    for d in test_devices:
+        try:
+            with jax.default_device(wp.device_to_jax(d)):
+                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
+                j += 1
+            jax_compatible_devices.append(d)
+            if d.is_cuda:
+                jax_compatible_cuda_devices.append(d)
+        except Exception as e:
+            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
+
+    if jax_compatible_cuda_devices:
+        add_function_test(
+            TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
+        )
+
+except Exception as e:
+    print(f"Skipping Jax tests due to exception: {e}")
+
+
+if __name__ == "__main__":
+    wp.build.clear_kernel_cache()
+    unittest.main(verbosity=2)
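Note: this new module tests `warp.jax_experimental.jax_kernel`, which wraps a Warp kernel as a callable usable inside `jax.jit`-compiled functions; per the suite registration above it is only exercised on CUDA devices, and only with JAX 0.4.25 or newer. An illustrative usage sketch distilled from the tests follows; `scale_kernel` is a hypothetical name, and a CUDA-capable JAX install is assumed.

    import jax
    import jax.numpy as jp
    import numpy as np

    import warp as wp
    from warp.jax_experimental import jax_kernel

    wp.init()


    @wp.kernel
    def scale_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = 2.0 * x[tid]


    # wrap once, then call from jitted JAX code
    jax_scale = jax_kernel(scale_kernel)


    @jax.jit
    def f(x):
        return jax_scale(x)


    device = wp.get_device("cuda:0")  # assumes a CUDA-capable device is present
    with jax.default_device(wp.device_to_jax(device)):
        y = f(jp.arange(16, dtype=jp.float32))

    print(np.asarray(y))  # expected: [0., 2., 4., ..., 30.]
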
warp/tests/test_large.py CHANGED
@@ -81,8 +81,8 @@ def test_large_arrays_slow(test, device):
     # without changes to support how frequently a test may be run
     total_elements = 2**31 + 8
 
-    # 1-D to 4-D arrays: test zero_, fill_, then zero_ for scalar data types:
-    for total_dims in range(1, 5):
+    # 2-D to 4-D arrays: test zero_, fill_, then zero_ for scalar data types:
+    for total_dims in range(2, 5):
         dim_x = math.ceil(total_elements ** (1 / total_dims))
         shape_tuple = tuple([dim_x] * total_dims)
 
@@ -99,21 +99,42 @@ def test_large_arrays_slow(test, device):
 
 def test_large_arrays_fast(test, device):
     # A truncated version of test_large_arrays_slow meant to catch basic errors
-    total_elements = 2**31 + 8
+
+    # Make it so that a (dim_x, dim_x) array has more than 2**31 elements
+    dim_x = math.ceil(math.sqrt(2**31))
 
     nptype = np.dtype(np.int8)
     wptype = wp.types.np_dtype_to_warp_type[nptype]
 
-    a1 = wp.zeros((total_elements,), dtype=wptype, device=device)
-    assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))
-
+    a1 = wp.zeros((dim_x, dim_x), dtype=wptype, device=device)
     a1.fill_(127)
+
     assert_np_equal(a1.numpy(), 127 * np.ones_like(a1.numpy()))
 
     a1.zero_()
     assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))
 
 
+def test_large_array_excessive_zeros(test, device):
+    # Tests the allocation of an array with length exceeding 2**31-1 in a dimension
+
+    with test.assertRaisesRegex(
+        ValueError, "Array shapes must not exceed the maximum representable value of a signed 32-bit integer"
+    ):
+        _ = wp.zeros((2**31), dtype=int, device=device)
+
+
+def test_large_array_excessive_numpy(test, device):
+    # Tests the allocation of an array from a numpy array with length exceeding 2**31-1 in a dimension
+
+    large_np_array = np.empty((2**31), dtype=int)
+
+    with test.assertRaisesRegex(
+        ValueError, "Array shapes must not exceed the maximum representable value of a signed 32-bit integer"
+    ):
+        _ = wp.array(large_np_array, device=device)
+
+
 devices = get_test_devices()
 
 
@@ -134,6 +155,8 @@ add_function_test(
 )
 
 add_function_test(TestLarge, "test_large_arrays_fast", test_large_arrays_fast, devices=devices)
+add_function_test(TestLarge, "test_large_array_excessive_zeros", test_large_array_excessive_zeros, devices=devices)
+add_function_test(TestLarge, "test_large_array_excessive_numpy", test_large_array_excessive_numpy, devices=devices)
 
 
 if __name__ == "__main__":
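Note: the two new excessive-size tests pin down the limit these changes enforce: no single array dimension may exceed 2**31 - 1 elements, while the total element count may still exceed 2**31 when spread over multiple dimensions. A small illustrative sketch of that behavior follows (assumes roughly 2 GiB of free memory for the second allocation):

    import math

    import warp as wp

    wp.init()

    # a single dimension larger than 2**31 - 1 is rejected
    try:
        _ = wp.zeros((2**31,), dtype=wp.int8)
    except ValueError as e:
        print(e)

    # a comparable element count split over two dimensions is allowed
    dim_x = math.ceil(math.sqrt(2**31))  # 46341
    a = wp.zeros((dim_x, dim_x), dtype=wp.int8)
    print(a.shape, a.size)
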
warp/tests/test_launch.py CHANGED
@@ -301,7 +301,30 @@ def test_launch_tuple_args(test, device):
         outputs=(out,),
         device=device,
     )
+    assert_np_equal(out.numpy(), np.array((0, 3, 6, 9)))
 
+    wp.launch(
+        kernel_mul,
+        dim=len(values),
+        inputs=(
+            values,
+            coeff,
+            out,
+        ),
+        device=device,
+    )
+    assert_np_equal(out.numpy(), np.array((0, 3, 6, 9)))
+
+    wp.launch(
+        kernel_mul,
+        dim=len(values),
+        outputs=(
+            values,
+            coeff,
+            out,
+        ),
+        device=device,
+    )
     assert_np_equal(out.numpy(), np.array((0, 3, 6, 9)))
 
 
@@ -323,6 +346,8 @@ add_function_test(TestLaunch, "test_launch_cmd_set_ctype", test_launch_cmd_set_c
 add_function_test(TestLaunch, "test_launch_cmd_set_dim", test_launch_cmd_set_dim, devices=devices)
 add_function_test(TestLaunch, "test_launch_cmd_empty", test_launch_cmd_empty, devices=devices)
 
+add_function_test(TestLaunch, "test_launch_tuple_args", test_launch_tuple_args, devices=devices)
+
 
 if __name__ == "__main__":
     wp.build.clear_kernel_cache()
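Note: the extended test_launch_tuple_args shows that `wp.launch` accepts tuples as well as lists for `inputs`/`outputs`, and that the kernel's positional arguments may be packed into either keyword so long as their combined order matches the kernel signature. A short illustrative sketch follows; the `kernel_mul` here is reconstructed to make the snippet self-contained and may differ in detail from the one in the test file.

    import numpy as np

    import warp as wp

    wp.init()


    @wp.kernel
    def kernel_mul(values: wp.array(dtype=int), coeff: int, out: wp.array(dtype=int)):
        tid = wp.tid()
        out[tid] = values[tid] * coeff


    values = wp.array(np.arange(4), dtype=int)
    out = wp.zeros_like(values)

    # all three arguments packed into a single tuple, via inputs= or outputs=
    wp.launch(kernel_mul, dim=len(values), inputs=(values, 3, out))
    print(out.numpy())  # [0 3 6 9]

    wp.launch(kernel_mul, dim=len(values), outputs=(values, 3, out))
    print(out.numpy())  # [0 3 6 9]
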
warp/tests/test_linear_solvers.py CHANGED
@@ -7,9 +7,10 @@ import unittest
 from warp.optim.linear import preconditioner, cg, bicgstab, gmres
 from warp.tests.unittest_utils import *
 
-
 wp.init()
 
+from warp.context import runtime  # noqa: E402
+
 
 def _check_linear_solve(test, A, b, func, *args, **kwargs):
     # test from zero
@@ -75,6 +76,15 @@ def _make_indefinite_system(n: int, seed: int, dtype, device, spd=False):
     return wp.array(A, dtype=dtype, device=device), wp.array(b, dtype=dtype, device=device)
 
 
+def _make_identity_system(n: int, seed: int, dtype, device):
+    rng = np.random.default_rng(seed)
+
+    A = np.eye(n)
+    b = rng.uniform(low=-1.0, high=1.0, size=(n,))
+
+    return wp.array(A, dtype=dtype, device=device), wp.array(b, dtype=dtype, device=device)
+
+
 def test_cg(test, device):
     A, b = _make_spd_system(n=64, seed=123, device=device, dtype=wp.float64)
     M = preconditioner(A, "diag")
@@ -88,6 +98,9 @@ def test_cg(test, device):
     _check_linear_solve(test, A, b, cg, maxiter=1000)
     _check_linear_solve(test, A, b, cg, M=M, maxiter=1000)
 
+    A, b = _make_identity_system(n=5, seed=321, device=device, dtype=wp.float32)
+    _check_linear_solve(test, A, b, cg, maxiter=30)
+
 
 def test_bicgstab(test, device):
     A, b = _make_nonsymmetric_system(n=64, seed=123, device=device, dtype=wp.float64)
@@ -111,6 +124,9 @@ def test_bicgstab(test, device):
     _check_linear_solve(test, A, b, bicgstab, M=M, maxiter=1000)
     _check_linear_solve(test, A, b, bicgstab, M=M, maxiter=1000, is_left_preconditioner=True)
 
+    A, b = _make_identity_system(n=5, seed=321, device=device, dtype=wp.float32)
+    _check_linear_solve(test, A, b, bicgstab, maxiter=30)
+
 
 def test_gmres(test, device):
     A, b = _make_nonsymmetric_system(n=64, seed=456, device=device, dtype=wp.float64)
@@ -127,6 +143,9 @@ def test_gmres(test, device):
     _check_linear_solve(test, A, b, gmres, M=M, maxiter=1000, tol=1.0e-5)
     _check_linear_solve(test, A, b, gmres, M=M, maxiter=1000, tol=1.0e-5, is_left_preconditioner=True)
 
+    A, b = _make_identity_system(n=5, seed=123, device=device, dtype=wp.float32)
+    _check_linear_solve(test, A, b, gmres, maxiter=120)
+
 
 class TestLinearSolvers(unittest.TestCase):
     pass
@@ -134,8 +153,6 @@ class TestLinearSolvers(unittest.TestCase):
 
 devices = get_test_devices()
 
-from warp.context import runtime
-
 if not runtime.core.is_cutlass_enabled():
     devices = [d for d in devices if not d.is_cuda]
     print("Skipping CUDA linear solver tests because CUTLASS is not supported in this build")