warp-lang 0.11.0-py3-none-manylinux2014_x86_64.whl → 1.0.0-py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +8 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +7 -6
- warp/build_dll.py +70 -79
- warp/builtins.py +10 -6
- warp/codegen.py +51 -19
- warp/config.py +7 -8
- warp/constants.py +3 -0
- warp/context.py +948 -245
- warp/dlpack.py +198 -113
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/benchmarks/benchmark_api.py +383 -0
- warp/examples/benchmarks/benchmark_cloth.py +279 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
- warp/examples/benchmarks/benchmark_launches.py +295 -0
- warp/examples/core/example_dem.py +221 -0
- warp/examples/core/example_fluid.py +267 -0
- warp/examples/core/example_graph_capture.py +129 -0
- warp/examples/core/example_marching_cubes.py +177 -0
- warp/examples/core/example_mesh.py +154 -0
- warp/examples/core/example_mesh_intersect.py +193 -0
- warp/examples/core/example_nvdb.py +169 -0
- warp/examples/core/example_raycast.py +89 -0
- warp/examples/core/example_raymarch.py +178 -0
- warp/examples/core/example_render_opengl.py +141 -0
- warp/examples/core/example_sph.py +389 -0
- warp/examples/core/example_torch.py +181 -0
- warp/examples/core/example_wave.py +249 -0
- warp/examples/fem/bsr_utils.py +380 -0
- warp/examples/fem/example_apic_fluid.py +391 -0
- warp/examples/fem/example_convection_diffusion.py +168 -0
- warp/examples/fem/example_convection_diffusion_dg.py +209 -0
- warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
- warp/examples/fem/example_deformed_geometry.py +159 -0
- warp/examples/fem/example_diffusion.py +173 -0
- warp/examples/fem/example_diffusion_3d.py +152 -0
- warp/examples/fem/example_diffusion_mgpu.py +214 -0
- warp/examples/fem/example_mixed_elasticity.py +222 -0
- warp/examples/fem/example_navier_stokes.py +243 -0
- warp/examples/fem/example_stokes.py +192 -0
- warp/examples/fem/example_stokes_transfer.py +249 -0
- warp/examples/fem/mesh_utils.py +109 -0
- warp/examples/fem/plot_utils.py +287 -0
- warp/examples/optim/example_bounce.py +248 -0
- warp/examples/optim/example_cloth_throw.py +210 -0
- warp/examples/optim/example_diffray.py +535 -0
- warp/examples/optim/example_drone.py +850 -0
- warp/examples/optim/example_inverse_kinematics.py +169 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
- warp/examples/optim/example_spring_cage.py +234 -0
- warp/examples/optim/example_trajectory.py +201 -0
- warp/examples/sim/example_cartpole.py +128 -0
- warp/examples/sim/example_cloth.py +184 -0
- warp/examples/sim/example_granular.py +113 -0
- warp/examples/sim/example_granular_collision_sdf.py +185 -0
- warp/examples/sim/example_jacobian_ik.py +213 -0
- warp/examples/sim/example_particle_chain.py +106 -0
- warp/examples/sim/example_quadruped.py +179 -0
- warp/examples/sim/example_rigid_chain.py +191 -0
- warp/examples/sim/example_rigid_contact.py +176 -0
- warp/examples/sim/example_rigid_force.py +126 -0
- warp/examples/sim/example_rigid_gyroscopic.py +97 -0
- warp/examples/sim/example_rigid_soft_contact.py +124 -0
- warp/examples/sim/example_soft_body.py +178 -0
- warp/fabric.py +29 -20
- warp/fem/cache.py +0 -1
- warp/fem/dirichlet.py +0 -2
- warp/fem/integrate.py +0 -1
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/native/builtin.h +12 -0
- warp/native/bvh.cu +18 -18
- warp/native/clang/clang.cpp +8 -3
- warp/native/cuda_util.cpp +94 -5
- warp/native/cuda_util.h +35 -6
- warp/native/cutlass_gemm.cpp +1 -1
- warp/native/cutlass_gemm.cu +4 -1
- warp/native/error.cpp +66 -0
- warp/native/error.h +27 -0
- warp/native/mesh.cu +2 -2
- warp/native/reduce.cu +4 -4
- warp/native/runlength_encode.cu +2 -2
- warp/native/scan.cu +2 -2
- warp/native/sparse.cu +0 -1
- warp/native/temp_buffer.h +2 -2
- warp/native/warp.cpp +95 -60
- warp/native/warp.cu +1053 -218
- warp/native/warp.h +49 -32
- warp/optim/linear.py +33 -16
- warp/render/render_opengl.py +202 -101
- warp/render/render_usd.py +82 -40
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +765 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +35 -13
- warp/sim/utils.py +222 -11
- warp/stubs.py +8 -0
- warp/tape.py +16 -1
- warp/tests/aux_test_grad_customs.py +23 -0
- warp/tests/test_array.py +190 -1
- warp/tests/test_async.py +656 -0
- warp/tests/test_bool.py +50 -0
- warp/tests/test_dlpack.py +164 -11
- warp/tests/test_examples.py +166 -74
- warp/tests/test_fem.py +8 -1
- warp/tests/test_generics.py +15 -5
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +172 -12
- warp/tests/test_jax.py +254 -0
- warp/tests/test_large.py +29 -6
- warp/tests/test_launch.py +25 -0
- warp/tests/test_linear_solvers.py +20 -3
- warp/tests/test_matmul.py +61 -16
- warp/tests/test_matmul_lite.py +13 -13
- warp/tests/test_mempool.py +186 -0
- warp/tests/test_multigpu.py +3 -0
- warp/tests/test_options.py +16 -2
- warp/tests/test_peer.py +137 -0
- warp/tests/test_print.py +3 -1
- warp/tests/test_quat.py +23 -0
- warp/tests/test_sim_kinematics.py +97 -0
- warp/tests/test_snippet.py +126 -3
- warp/tests/test_streams.py +108 -79
- warp/tests/test_torch.py +16 -8
- warp/tests/test_utils.py +32 -27
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/test_volume.py +1 -1
- warp/tests/unittest_serial.py +2 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +14 -7
- warp/thirdparty/unittest_parallel.py +15 -3
- warp/torch.py +10 -8
- warp/types.py +363 -246
- warp/utils.py +143 -19
- warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
- warp_lang-1.0.0.dist-info/METADATA +394 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
- warp/sim/optimizer.py +0 -138
- warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
- warp_lang-0.11.0.dist-info/METADATA +0 -238
- /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/tests/test_grad_customs.py
CHANGED
@@ -22,6 +22,7 @@ wp.init()
 def reversible_increment(
     counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
 ):
+    """This is a docstring"""
     next_index = wp.atomic_add(counter, counter_index, value)
     thread_values[tid] = next_index
     return next_index
@@ -31,6 +32,7 @@ def reversible_increment(
 def replay_reversible_increment(
     counter: wp.array(dtype=int), counter_index: int, value: int, thread_values: wp.array(dtype=int), tid: int
 ):
+    """This is a docstring"""
     return thread_values[tid]


@@ -58,34 +60,39 @@ def test_custom_replay_grad(test, device):
         run_atomic_add, dim=num_threads, inputs=[inputs, counter, thread_ids], outputs=[outputs], device=device
     )

-    tape.backward(grads={outputs: wp.
+    tape.backward(grads={outputs: wp.ones(num_threads, dtype=wp.float32, device=device)})
     assert_np_equal(inputs.grad.numpy(), 2.0 * inputs.numpy(), tol=1e-4)


 @wp.func
 def overload_fn(x: float, y: float):
+    """This is a docstring"""
     return x * 3.0 + y / 3.0, y**2.5


 @wp.func_grad(overload_fn)
 def overload_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
+    """This is a docstring"""
     wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
     wp.adjoint[y] += y * adj_ret1 * 3.0


 @wp.struct
 class MyStruct:
+    """This is a docstring"""
     scalar: float
     vec: wp.vec3


 @wp.func
 def overload_fn(x: MyStruct):
+    """This is a docstring"""
     return x.vec[0] * x.vec[1] * x.vec[2] * 4.0, wp.length(x.vec), x.scalar**0.5


 @wp.func_grad(overload_fn)
 def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: float):
+    """This is a docstring"""
     wp.adjoint[x.scalar] += x.scalar * adj_ret0 * 10.0
     wp.adjoint[x.vec][0] += adj_ret0 * x.vec[1] * x.vec[2] * 20.0
     wp.adjoint[x.vec][1] += adj_ret1 * x.vec[0] * x.vec[2] * 30.0
@@ -96,6 +103,7 @@ def overload_fn_grad(x: MyStruct, adj_ret0: float, adj_ret1: float, adj_ret2: fl
 def run_overload_float_fn(
     xs: wp.array(dtype=float), ys: wp.array(dtype=float), output0: wp.array(dtype=float), output1: wp.array(dtype=float)
 ):
+    """This is a docstring"""
     i = wp.tid()
     out0, out1 = overload_fn(xs[i], ys[i])
     output0[i] = out0
@@ -111,17 +119,19 @@ def run_overload_struct_fn(xs: wp.array(dtype=MyStruct), output: wp.array(dtype=

 def test_custom_overload_grad(test, device):
     dim = 3
-    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True)
-    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True)
-    out0_float = wp.zeros(dim)
-    out1_float = wp.zeros(dim)
+    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True, device=device)
+    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True, device=device)
+    out0_float = wp.zeros(dim, device=device)
+    out1_float = wp.zeros(dim, device=device)
     tape = wp.Tape()
     with tape:
-        wp.launch(
+        wp.launch(
+            run_overload_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float], device=device
+        )
     tape.backward(
         grads={
-            out0_float: wp.
-            out1_float: wp.
+            out0_float: wp.ones(dim, dtype=wp.float32, device=device),
+            out1_float: wp.ones(dim, dtype=wp.float32, device=device),
         }
     )
     assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
@@ -136,12 +146,12 @@ def test_custom_overload_grad(test, device):
     x2 = MyStruct()
     x2.vec = wp.vec3(8.0, 9.0, 10.0)
     x2.scalar = 19.0
-    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True)
-    out_struct = wp.zeros(dim)
+    xs_struct = wp.array([x0, x1, x2], dtype=MyStruct, requires_grad=True, device=device)
+    out_struct = wp.zeros(dim, device=device)
     tape = wp.Tape()
     with tape:
-        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct])
-    tape.backward(grads={out_struct: wp.
+        wp.launch(run_overload_struct_fn, dim=dim, inputs=[xs_struct], outputs=[out_struct], device=device)
+    tape.backward(grads={out_struct: wp.ones(dim, dtype=wp.float32, device=device)})
     xs_struct_np = xs_struct.numpy()
     struct_grads = xs_struct.grad.numpy()
     # fmt: off
@@ -160,6 +170,153 @@ def test_custom_overload_grad(test, device):
     # fmt: on


+def test_custom_import_grad(test, device):
+    from warp.tests.aux_test_grad_customs import aux_custom_fn
+
+    @wp.kernel
+    def run_defined_float_fn(
+        xs: wp.array(dtype=float),
+        ys: wp.array(dtype=float),
+        output0: wp.array(dtype=float),
+        output1: wp.array(dtype=float),
+    ):
+        i = wp.tid()
+        out0, out1 = aux_custom_fn(xs[i], ys[i])
+        output0[i] = out0
+        output1[i] = out1
+
+    dim = 3
+    xs_float = wp.array(np.arange(1.0, dim + 1.0), dtype=wp.float32, requires_grad=True, device=device)
+    ys_float = wp.array(np.arange(10.0, dim + 10.0), dtype=wp.float32, requires_grad=True, device=device)
+    out0_float = wp.zeros(dim, device=device)
+    out1_float = wp.zeros(dim, device=device)
+    tape = wp.Tape()
+    with tape:
+        wp.launch(
+            run_defined_float_fn, dim=dim, inputs=[xs_float, ys_float], outputs=[out0_float, out1_float], device=device
+        )
+    tape.backward(
+        grads={
+            out0_float: wp.ones(dim, dtype=wp.float32, device=device),
+            out1_float: wp.ones(dim, dtype=wp.float32, device=device),
+        }
+    )
+    assert_np_equal(xs_float.grad.numpy(), xs_float.numpy() * 42.0 + ys_float.numpy() * 10.0)
+    assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
+
+
+@wp.func
+def sigmoid(x: float):
+    return 1.0 / (1.0 + wp.exp(-x))
+
+
+@wp.func_grad(sigmoid)
+def adj_sigmoid(x: float, adj: float):
+    # unused function to test that we don't run into infinite recursion when calling
+    # the forward function from within the gradient function
+    wp.adjoint[x] += adj * sigmoid(x) * (1.0 - sigmoid(x))
+
+
+@wp.func
+def sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    # test function that does not return anything
+    ys[i] = sigmoid(xs[i])
+
+
+@wp.func_grad(sigmoid_no_return)
+def adj_sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    wp.adjoint[xs][i] += ys[i] * (1.0 - ys[i])
+
+
+@wp.kernel
+def eval_sigmoid(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    i = wp.tid()
+    sigmoid_no_return(i, xs, ys)
+
+
+def test_custom_grad_no_return(test, device):
+    xs = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True)
+    ys = wp.zeros_like(xs)
+    ys.grad.fill_(1.0)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(eval_sigmoid, dim=len(xs), inputs=[xs], outputs=[ys])
+    tape.backward()
+
+    sigmoids = ys.numpy()
+    grad = xs.grad.numpy()
+    assert_np_equal(grad, sigmoids * (1.0 - sigmoids))
+
+
+def test_wrapped_docstring(test, device):
+    assert "This is a docstring" in reversible_increment.__doc__
+    assert "This is a docstring" in replay_reversible_increment.__doc__
+    assert "This is a docstring" in overload_fn.__doc__
+    assert "This is a docstring" in overload_fn_grad.__doc__
+    assert "This is a docstring" in run_overload_float_fn.__doc__
+    assert "This is a docstring" in MyStruct.__doc__
+
+
+@wp.func
+def dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # this function doesn't get called but it is an important test for code generation
+    # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
+    for i in range(m):
+        for j in range(n):
+            sum = float(0.0)
+            for k in range(p):
+                if transpose_A:
+                    a_i = k * m + i
+                else:
+                    a_i = i * p + k
+                if transpose_B:
+                    b_j = j * p + k
+                else:
+                    b_j = k * n + j
+                sum += A[a_i] * B[b_j]
+
+            if add_to_C:
+                C[i * n + j] += sum
+            else:
+                C[i * n + j] = sum
+
+
+@wp.func_grad(dense_gemm)
+def adj_dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # code generation would break here if we didn't defer building the custom grad
+    # function until after the forward functions + kernels of the module have been built
+    add_to_C = True
+    if transpose_A:
+        dense_gemm(p, m, n, False, True, add_to_C, B, wp.adjoint[C], wp.adjoint[A])
+        dense_gemm(p, n, m, False, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+    else:
+        dense_gemm(m, p, n, False, not transpose_B, add_to_C, wp.adjoint[C], B, wp.adjoint[A])
+        dense_gemm(p, n, m, True, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+
+
 devices = get_test_devices()


@@ -169,6 +326,9 @@ class TestGradCustoms(unittest.TestCase):

 add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_import_grad", test_custom_import_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_grad_no_return", test_custom_grad_no_return, devices=devices)
+add_function_test(TestGradCustoms, "test_wrapped_docstring", test_wrapped_docstring, devices=devices)


 if __name__ == "__main__":
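Most hunks above thread device=device through array construction and launches, seed output gradients with the new wp.ones(), and add docstrings whose survival through Warp's decorators test_wrapped_docstring then verifies. As a minimal standalone sketch of the @wp.func_grad custom-adjoint mechanism these tests exercise (the kernel and values below are illustrative, not taken from the package):

import numpy as np
import warp as wp

wp.init()

@wp.func
def square(x: float):
    return x * x

# override the adjoint Warp would otherwise generate for square()
@wp.func_grad(square)
def adj_square(x: float, adj_ret: float):
    wp.adjoint[x] += 2.0 * x * adj_ret  # d(x^2)/dx = 2x

@wp.kernel
def eval_square(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
    i = wp.tid()
    ys[i] = square(xs[i])

xs = wp.array(np.arange(1.0, 5.0), dtype=wp.float32, requires_grad=True)
ys = wp.zeros_like(xs)

tape = wp.Tape()
with tape:
    wp.launch(eval_square, dim=len(xs), inputs=[xs], outputs=[ys])
tape.backward(grads={ys: wp.ones(len(xs), dtype=wp.float32)})

print(xs.grad.numpy())  # [2. 4. 6. 8.]

wp.adjoint[x] += ... accumulates into the tape's adjoint buffers, so custom gradients compose with the adjoints Warp generates for the rest of the kernel.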
warp/tests/test_jax.py
ADDED
@@ -0,0 +1,254 @@
+# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import numpy as np
+import os
+import unittest
+from typing import Any
+
+import warp as wp
+from warp.tests.unittest_utils import *
+
+wp.init()
+
+
+# basic kernel with one input and output
+@wp.kernel
+def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
+    tid = wp.tid()
+    output[tid] = 3.0 * input[tid]
+
+
+# generic kernel with one scalar input and output
+@wp.kernel
+def triple_kernel_scalar(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
+    tid = wp.tid()
+    output[tid] = input.dtype(3) * input[tid]
+
+
+# generic kernel with one vector/matrix input and output
+@wp.kernel
+def triple_kernel_vecmat(input: wp.array(dtype=Any), output: wp.array(dtype=Any)):
+    tid = wp.tid()
+    output[tid] = input.dtype.dtype(3) * input[tid]
+
+
+# kernel with multiple inputs and outputs
+@wp.kernel
+def multiarg_kernel(
+    # inputs
+    a: wp.array(dtype=float),
+    b: wp.array(dtype=float),
+    c: wp.array(dtype=float),
+    # outputs
+    ab: wp.array(dtype=float),
+    bc: wp.array(dtype=float),
+):
+    tid = wp.tid()
+    ab[tid] = a[tid] + b[tid]
+    bc[tid] = b[tid] + c[tid]
+
+
+# various types for testing
+scalar_types = wp.types.scalar_types
+vector_types = []
+matrix_types = []
+for dim in [2, 3, 4]:
+    for T in scalar_types:
+        vector_types.append(wp.vec(dim, T))
+        matrix_types.append(wp.mat((dim, dim), T))
+
+# explicitly overload generic kernels to avoid module reloading during tests
+for T in scalar_types:
+    wp.overload(triple_kernel_scalar, [wp.array(dtype=T), wp.array(dtype=T)])
+for T in [*vector_types, *matrix_types]:
+    wp.overload(triple_kernel_vecmat, [wp.array(dtype=T), wp.array(dtype=T)])
+
+
+def _jax_version():
+    try:
+        import jax
+        return jax.__version_info__
+    except ImportError:
+        return (0, 0, 0)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_basic(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    jax_triple = jax_kernel(triple_kernel)
+
+    @jax.jit
+    def f():
+        x = jp.arange(n, dtype=jp.float32)
+        return jax_triple(x)
+
+    # run on the given device
+    with jax.default_device(wp.device_to_jax(device)):
+        y = f()
+
+    result = np.asarray(y)
+    expected = 3 * np.arange(n, dtype=np.float32)
+
+    assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_scalar(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    for T in scalar_types:
+
+        jp_dtype = wp.jax.dtype_to_jax(T)
+        np_dtype = wp.types.warp_type_to_np_dtype[T]
+
+        with test.subTest(msg=T.__name__):
+
+            # get the concrete overload
+            kernel_instance = triple_kernel_scalar.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
+
+            jax_triple = jax_kernel(kernel_instance)
+
+            @jax.jit
+            def f():
+                x = jp.arange(n, dtype=jp_dtype)
+                return jax_triple(x)
+
+            # run on the given device
+            with jax.default_device(wp.device_to_jax(device)):
+                y = f()
+
+            result = np.asarray(y)
+            expected = 3 * np.arange(n, dtype=np_dtype)
+
+            assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_vecmat(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    for T in [*vector_types, *matrix_types]:
+
+        jp_dtype = wp.jax.dtype_to_jax(T._wp_scalar_type_)
+        np_dtype = wp.types.warp_type_to_np_dtype[T._wp_scalar_type_]
+
+        n = 64 // T._length_
+        scalar_shape = (n, *T._shape_)
+        scalar_len = n * T._length_
+
+        with test.subTest(msg=T.__name__):
+
+            # get the concrete overload
+            kernel_instance = triple_kernel_vecmat.get_overload([wp.array(dtype=T), wp.array(dtype=T)])
+
+            jax_triple = jax_kernel(kernel_instance)
+
+            @jax.jit
+            def f():
+                x = jp.arange(scalar_len, dtype=jp_dtype).reshape(scalar_shape)
+                return jax_triple(x)
+
+            # run on the given device
+            with jax.default_device(wp.device_to_jax(device)):
+                y = f()
+
+            result = np.asarray(y)
+            expected = 3 * np.arange(scalar_len, dtype=np_dtype).reshape(scalar_shape)
+
+            assert_np_equal(result, expected)
+
+
+@unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
+def test_jax_kernel_multiarg(test, device):
+    import jax.numpy as jp
+    from warp.jax_experimental import jax_kernel
+
+    n = 64
+
+    jax_multiarg = jax_kernel(multiarg_kernel)
+
+    @jax.jit
+    def f():
+        a = jp.full(n, 1, dtype=jp.float32)
+        b = jp.full(n, 2, dtype=jp.float32)
+        c = jp.full(n, 3, dtype=jp.float32)
+        return jax_multiarg(a, b, c)
+
+    # run on the given device
+    with jax.default_device(wp.device_to_jax(device)):
+        x, y = f()
+
+    result_x, result_y = np.asarray(x), np.asarray(y)
+    expected_x = np.full(n, 3, dtype=np.float32)
+    expected_y = np.full(n, 5, dtype=np.float32)
+
+    assert_np_equal(result_x, expected_x)
+    assert_np_equal(result_y, expected_y)
+
+
+class TestJax(unittest.TestCase):
+    pass
+
+
+# try adding Jax tests if Jax is installed correctly
+try:
+    # prevent Jax from gobbling up GPU memory
+    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
+
+    import jax
+    import jax.dlpack
+
+    # NOTE: we must enable 64-bit types in Jax to test the full gamut of types
+    jax.config.update("jax_enable_x64", True)
+
+    # check which Warp devices work with Jax
+    # CUDA devices may fail if Jax cannot find a CUDA Toolkit
+    test_devices = get_test_devices()
+    jax_compatible_devices = []
+    jax_compatible_cuda_devices = []
+    for d in test_devices:
+        try:
+            with jax.default_device(wp.device_to_jax(d)):
+                j = jax.numpy.arange(10, dtype=jax.numpy.float32)
+                j += 1
+            jax_compatible_devices.append(d)
+            if d.is_cuda:
+                jax_compatible_cuda_devices.append(d)
+        except Exception as e:
+            print(f"Skipping Jax DLPack tests on device '{d}' due to exception: {e}")
+
+    if jax_compatible_cuda_devices:
+        add_function_test(
+            TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
+        )
+        add_function_test(
+            TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
+        )
+
+except Exception as e:
+    print(f"Skipping Jax tests due to exception: {e}")
+
+
+if __name__ == "__main__":
+    wp.build.clear_kernel_cache()
+    unittest.main(verbosity=2)
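The pattern this new file exercises is: wrap a Warp kernel with warp.jax_experimental.jax_kernel, call the wrapper inside a jitted JAX function, and pin placement with jax.default_device(wp.device_to_jax(...)). A minimal standalone sketch under the same assumptions as the tests (a CUDA device visible to both libraries, jax >= 0.4.25; the kernel below is illustrative):

import jax
import jax.numpy as jp
import numpy as np
import warp as wp
from warp.jax_experimental import jax_kernel

wp.init()

@wp.kernel
def scale_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = 3.0 * x[tid]

jax_scale = jax_kernel(scale_kernel)  # Warp kernel exposed as a JAX-callable op

@jax.jit
def f(x):
    return jax_scale(x)

with jax.default_device(wp.device_to_jax(wp.get_device("cuda:0"))):
    y = f(jp.arange(8, dtype=jp.float32))

print(np.asarray(y))  # [ 0.  3.  6.  9. 12. 15. 18. 21.]

Note that the test file only registers these tests for CUDA devices, so this interop path should be treated as CUDA-only and experimental.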
warp/tests/test_large.py
CHANGED
@@ -81,8 +81,8 @@ def test_large_arrays_slow(test, device):
     # without changes to support how frequently a test may be run
     total_elements = 2**31 + 8

-    #
-    for total_dims in range(
+    # 2-D to 4-D arrays: test zero_, fill_, then zero_ for scalar data types:
+    for total_dims in range(2, 5):
         dim_x = math.ceil(total_elements ** (1 / total_dims))
         shape_tuple = tuple([dim_x] * total_dims)

@@ -99,21 +99,42 @@ def test_large_arrays_slow(test, device):

 def test_large_arrays_fast(test, device):
     # A truncated version of test_large_arrays_slow meant to catch basic errors
-
+
+    # Make is so that a (dim_x, dim_x) array has more than 2**31 elements
+    dim_x = math.ceil(math.sqrt(2**31))

     nptype = np.dtype(np.int8)
     wptype = wp.types.np_dtype_to_warp_type[nptype]

-    a1 = wp.zeros((
-    assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))
-
+    a1 = wp.zeros((dim_x, dim_x), dtype=wptype, device=device)
     a1.fill_(127)
+
     assert_np_equal(a1.numpy(), 127 * np.ones_like(a1.numpy()))

     a1.zero_()
     assert_np_equal(a1.numpy(), np.zeros_like(a1.numpy()))


+def test_large_array_excessive_zeros(test, device):
+    # Tests the allocation of an array with length exceeding 2**31-1 in a dimension
+
+    with test.assertRaisesRegex(
+        ValueError, "Array shapes must not exceed the maximum representable value of a signed 32-bit integer"
+    ):
+        _ = wp.zeros((2**31), dtype=int, device=device)
+
+
+def test_large_array_excessive_numpy(test, device):
+    # Tests the allocation of an array from a numpy array with length exceeding 2**31-1 in a dimension
+
+    large_np_array = np.empty((2**31), dtype=int)
+
+    with test.assertRaisesRegex(
+        ValueError, "Array shapes must not exceed the maximum representable value of a signed 32-bit integer"
+    ):
+        _ = wp.array(large_np_array, device=device)
+
+
 devices = get_test_devices()


@@ -134,6 +155,8 @@ add_function_test(
 )

 add_function_test(TestLarge, "test_large_arrays_fast", test_large_arrays_fast, devices=devices)
+add_function_test(TestLarge, "test_large_array_excessive_zeros", test_large_array_excessive_zeros, devices=devices)
+add_function_test(TestLarge, "test_large_array_excessive_numpy", test_large_array_excessive_numpy, devices=devices)


 if __name__ == "__main__":
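The two new negative tests encode the constraint behind the shape math above: each array dimension must fit in a signed 32-bit integer, while the total element count may exceed 2**31 if spread across dimensions. A minimal sketch of both sides of the limit (CPU device assumed; note the second allocation is ~2 GiB of int8, mirroring test_large_arrays_fast):

import math
import warp as wp

wp.init()

# a single dimension of 2**31 elements is rejected outright
try:
    wp.zeros((2**31,), dtype=int, device="cpu")
except ValueError as e:
    print(e)  # Array shapes must not exceed the maximum representable value of a signed 32-bit integer

# the same element count is fine when no single dimension exceeds 2**31 - 1
dim_x = math.ceil(math.sqrt(2**31))
a = wp.zeros((dim_x, dim_x), dtype=wp.int8, device="cpu")
print(a.shape, a.size)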
warp/tests/test_launch.py
CHANGED
@@ -301,7 +301,30 @@ def test_launch_tuple_args(test, device):
         outputs=(out,),
         device=device,
     )
+    assert_np_equal(out.numpy(), np.array((0, 3, 6, 9)))

+    wp.launch(
+        kernel_mul,
+        dim=len(values),
+        inputs=(
+            values,
+            coeff,
+            out,
+        ),
+        device=device,
+    )
+    assert_np_equal(out.numpy(), np.array((0, 3, 6, 9)))
+
+    wp.launch(
+        kernel_mul,
+        dim=len(values),
+        outputs=(
+            values,
+            coeff,
+            out,
+        ),
+        device=device,
+    )
     assert_np_equal(out.numpy(), np.array((0, 3, 6, 9)))


@@ -323,6 +346,8 @@ add_function_test(TestLaunch, "test_launch_cmd_set_ctype", test_launch_cmd_set_c
 add_function_test(TestLaunch, "test_launch_cmd_set_dim", test_launch_cmd_set_dim, devices=devices)
 add_function_test(TestLaunch, "test_launch_cmd_empty", test_launch_cmd_empty, devices=devices)

+add_function_test(TestLaunch, "test_launch_tuple_args", test_launch_tuple_args, devices=devices)
+

 if __name__ == "__main__":
     wp.build.clear_kernel_cache()
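The added assertions confirm that wp.launch treats tuples and lists interchangeably, and that positional kernel arguments may be supplied entirely through inputs, entirely through outputs, or split across both. A minimal sketch with a hypothetical kernel mirroring the test's kernel_mul:

import numpy as np
import warp as wp

wp.init()

@wp.kernel
def mul(values: wp.array(dtype=int), coeff: int, out: wp.array(dtype=int)):
    tid = wp.tid()
    out[tid] = coeff * values[tid]

values = wp.array(np.arange(4), dtype=int)
out = wp.zeros(4, dtype=int)

# tuples work the same as lists, and all args may go through either parameter
wp.launch(mul, dim=len(values), inputs=(values, 3), outputs=(out,))
wp.launch(mul, dim=len(values), inputs=(values, 3, out))
print(out.numpy())  # [0 3 6 9]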
warp/tests/test_linear_solvers.py
CHANGED
@@ -7,9 +7,10 @@ import unittest
 from warp.optim.linear import preconditioner, cg, bicgstab, gmres
 from warp.tests.unittest_utils import *

-
 wp.init()

+from warp.context import runtime  # noqa: E402
+

 def _check_linear_solve(test, A, b, func, *args, **kwargs):
     # test from zero
@@ -75,6 +76,15 @@ def _make_indefinite_system(n: int, seed: int, dtype, device, spd=False):
     return wp.array(A, dtype=dtype, device=device), wp.array(b, dtype=dtype, device=device)


+def _make_identity_system(n: int, seed: int, dtype, device):
+    rng = np.random.default_rng(seed)
+
+    A = np.eye(n)
+    b = rng.uniform(low=-1.0, high=1.0, size=(n,))
+
+    return wp.array(A, dtype=dtype, device=device), wp.array(b, dtype=dtype, device=device)
+
+
 def test_cg(test, device):
     A, b = _make_spd_system(n=64, seed=123, device=device, dtype=wp.float64)
     M = preconditioner(A, "diag")
@@ -88,6 +98,9 @@ def test_cg(test, device):
     _check_linear_solve(test, A, b, cg, maxiter=1000)
     _check_linear_solve(test, A, b, cg, M=M, maxiter=1000)

+    A, b = _make_identity_system(n=5, seed=321, device=device, dtype=wp.float32)
+    _check_linear_solve(test, A, b, cg, maxiter=30)
+

 def test_bicgstab(test, device):
     A, b = _make_nonsymmetric_system(n=64, seed=123, device=device, dtype=wp.float64)
@@ -111,6 +124,9 @@ def test_bicgstab(test, device):
     _check_linear_solve(test, A, b, bicgstab, M=M, maxiter=1000)
     _check_linear_solve(test, A, b, bicgstab, M=M, maxiter=1000, is_left_preconditioner=True)

+    A, b = _make_identity_system(n=5, seed=321, device=device, dtype=wp.float32)
+    _check_linear_solve(test, A, b, bicgstab, maxiter=30)
+

 def test_gmres(test, device):
     A, b = _make_nonsymmetric_system(n=64, seed=456, device=device, dtype=wp.float64)
@@ -127,6 +143,9 @@ def test_gmres(test, device):
     _check_linear_solve(test, A, b, gmres, M=M, maxiter=1000, tol=1.0e-5)
     _check_linear_solve(test, A, b, gmres, M=M, maxiter=1000, tol=1.0e-5, is_left_preconditioner=True)

+    A, b = _make_identity_system(n=5, seed=123, device=device, dtype=wp.float32)
+    _check_linear_solve(test, A, b, gmres, maxiter=120)
+

 class TestLinearSolvers(unittest.TestCase):
     pass
@@ -134,8 +153,6 @@ class TestLinearSolvers(unittest.TestCase):

 devices = get_test_devices()

-from warp.context import runtime
-
 if not runtime.core.is_cutlass_enabled():
     devices = [d for d in devices if not d.is_cuda]
     print("Skipping CUDA linear solver tests because CUTLASS is not supported in this build")