warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic; more details are available on the package's registry page.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
|
@@ -28,7 +28,7 @@ def test_tile_shared_mem_size(test, device):
|
|
|
28
28
|
|
|
29
29
|
BLOCK_DIM = 256
|
|
30
30
|
|
|
31
|
-
@wp.kernel
|
|
31
|
+
@wp.kernel(module="unique")
|
|
32
32
|
def compute(out: wp.array2d(dtype=float)):
|
|
33
33
|
a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
|
|
34
34
|
b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
|
|
@@ -64,7 +64,7 @@ def test_tile_shared_mem_large(test, device):
|
|
|
64
64
|
BLOCK_DIM = 256
|
|
65
65
|
|
|
66
66
|
# we disable backward kernel gen since 128k is not supported on most architectures
|
|
67
|
-
@wp.kernel(enable_backward=False)
|
|
67
|
+
@wp.kernel(enable_backward=False, module="unique")
|
|
68
68
|
def compute(out: wp.array2d(dtype=float)):
|
|
69
69
|
a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
|
|
70
70
|
b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
|
|
@@ -100,7 +100,7 @@ def test_tile_shared_mem_graph(test, device):
|
|
|
100
100
|
|
|
101
101
|
BLOCK_DIM = 256
|
|
102
102
|
|
|
103
|
-
@wp.kernel
|
|
103
|
+
@wp.kernel(module="unique")
|
|
104
104
|
def compute(out: wp.array2d(dtype=float)):
|
|
105
105
|
a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
|
|
106
106
|
b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
|
|
@@ -110,7 +110,7 @@ def test_tile_shared_mem_graph(test, device):
|
|
|
110
110
|
|
|
111
111
|
out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)
|
|
112
112
|
|
|
113
|
-
|
|
113
|
+
compute.module.load(device)
|
|
114
114
|
|
|
115
115
|
wp.capture_begin(device, force_module_load=False)
|
|
116
116
|
wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)
|
|
@@ -157,7 +157,7 @@ def test_tile_shared_mem_func(test, device):
|
|
|
157
157
|
|
|
158
158
|
return a + b
|
|
159
159
|
|
|
160
|
-
@wp.kernel
|
|
160
|
+
@wp.kernel(module="unique")
|
|
161
161
|
def compute(out: wp.array2d(dtype=float)):
|
|
162
162
|
s = add_tile_small()
|
|
163
163
|
b = add_tile_big()
|
|
@@ -197,7 +197,7 @@ def test_tile_shared_non_aligned(test, device):
|
|
|
197
197
|
b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 3.0
|
|
198
198
|
return a + b
|
|
199
199
|
|
|
200
|
-
@wp.kernel
|
|
200
|
+
@wp.kernel(module="unique")
|
|
201
201
|
def compute(out: wp.array2d(dtype=float)):
|
|
202
202
|
# This test the logic in the stack allocator, which should increment and
|
|
203
203
|
# decrement the stack pointer each time foo() is called
|
|
@@ -224,6 +224,121 @@ def test_tile_shared_non_aligned(test, device):
|
|
|
224
224
|
assert hooks.backward_smem_bytes == expected_required_shared * 2
|
|
225
225
|
|
|
226
226
|
|
|
227
|
+
def test_tile_shared_vec_accumulation(test, device):
    """Accumulate vec3 components from every thread of a tile into one (1, 3) tile.

    Three tiles of BLOCK_DIM threads each read a random index, look up the basis
    vector it selects, and add its x/y/z components into the same three tile cells.
    Row ``t`` of ``output`` should therefore hold how often tile ``t`` picked basis
    vectors 0, 1 and 2.

    The backward pass uses a unit output gradient. It should give each basis vector a
    gradient equal to its total pick count, repeated in all three components.
    """
    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(indices: wp.array(dtype=int), vecs: wp.array(dtype=wp.vec3), output: wp.array2d(dtype=float)):
        tile_id, lane = wp.tid()

        lane_indices = wp.tile_load(indices, shape=BLOCK_DIM, offset=tile_id * BLOCK_DIM)
        vec_id = lane_indices[lane]

        # Every lane adds into the same three cells of this tile.
        acc = wp.tile_zeros(shape=(1, 3), dtype=float)

        acc[0, 0] += vecs[vec_id].x
        acc[0, 1] += vecs[vec_id].y
        acc[0, 2] += vecs[vec_id].z

        wp.tile_store(output, acc, offset=(tile_id, 0))

    basis_vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
    vecs = wp.array(basis_vecs, dtype=wp.vec3, requires_grad=True, device=device)

    rng = np.random.default_rng(42)
    indices_np = rng.integers(0, 3, size=BLOCK_DIM * 3)

    indices = wp.array(indices_np, dtype=int, requires_grad=True, device=device)
    output = wp.zeros(shape=(3, 3), dtype=float, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(compute, dim=3, inputs=[indices, vecs, output], block_dim=BLOCK_DIM, device=device)

    output.grad = wp.ones_like(output)
    tape.backward()

    # Forward reference: row t counts how often each basis vector was picked by tile t.
    per_tile_counts = np.stack([np.bincount(row, minlength=3) for row in indices_np.reshape((3, BLOCK_DIM))])

    # Backward reference: d(sum(output)) / d(vecs[k]) is the total pick count of k in every component.
    total_counts = np.bincount(indices_np, minlength=3)
    expected_grads = np.repeat(total_counts[:, np.newaxis], 3, axis=1)

    assert_np_equal(output.numpy(), per_tile_counts)
    assert_np_equal(vecs.grad.numpy(), expected_grads)
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def test_tile_shared_simple_reduction_add(test, device):
    """Tree-reduce each BLOCK_DIM-element slice of ``x`` with ``+=`` into ``y[i]``.

    Each tiled block loads its slice of ``x`` and halves the active lane count on
    every level, so ``t[0]`` finally holds the block total. That total is stored
    at ``y[i]``.
    """
    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i, j = wp.tid()

        t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)

        k = BLOCK_DIM // 2
        while k > 0:
            if j < k:
                t[j] += t[j + k]
            k //= 2

        wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)

    num_blocks = 4
    N = BLOCK_DIM * num_blocks
    x_np = np.arange(N, dtype=np.float32)
    x = wp.array(x_np, dtype=float, device=device)
    y = wp.zeros(num_blocks, dtype=float, device=device)

    wp.launch_tiled(compute, dim=num_blocks, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)

    # Check every block's result rather than only the grand total. A total-only check
    # would miss results stored to the wrong block or errors that cancel between blocks.
    # All partial sums are integers below 2**24, so float32 results are exact.
    expected = x_np.reshape((num_blocks, BLOCK_DIM)).sum(axis=1)
    assert_np_equal(y.numpy(), expected)
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def test_tile_shared_simple_reduction_sub(test, device):
    """Tree reduction with ``-=`` on a loaded tile; every block's result must be exactly 0.

    With ``x = arange(N)``, block ``i`` holds ``c, c + 1, ..., c + BLOCK_DIM - 1``.
    - First level (``k = BLOCK_DIM // 2``): every active lane becomes ``-(BLOCK_DIM // 2)``.
    - Second level: subtracts equal values and yields 0.
    - Later levels: keep 0.
    """
    BLOCK_DIM = 256

    @wp.kernel(module="unique")
    def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i, j = wp.tid()

        t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)

        k = BLOCK_DIM // 2
        while k > 0:
            if j < k:
                t[j] -= t[j + k]
            k //= 2

        wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)

    num_blocks = 4
    N = BLOCK_DIM * num_blocks
    x_np = np.arange(N, dtype=np.float32)
    x = wp.array(x_np, dtype=float, device=device)
    y = wp.zeros(num_blocks, dtype=float, device=device)

    wp.launch_tiled(compute, dim=num_blocks, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)

    # Assert per block: checking only sum(y) == 0 would accept outputs such as [a, -a, 0, 0].
    assert_np_equal(y.numpy(), np.zeros(num_blocks, dtype=np.float32))
|
|
340
|
+
|
|
341
|
+
|
|
227
342
|
devices = get_cuda_test_devices()
|
|
228
343
|
|
|
229
344
|
|
|
@@ -240,7 +355,21 @@ add_function_test(
|
|
|
240
355
|
add_function_test(TestTileSharedMemory, "test_tile_shared_mem_graph", test_tile_shared_mem_graph, devices=devices)
|
|
241
356
|
add_function_test(TestTileSharedMemory, "test_tile_shared_mem_func", test_tile_shared_mem_func, devices=devices)
|
|
242
357
|
add_function_test(TestTileSharedMemory, "test_tile_shared_non_aligned", test_tile_shared_non_aligned, devices=devices)
|
|
243
|
-
|
|
358
|
+
add_function_test(
|
|
359
|
+
TestTileSharedMemory, "test_tile_shared_vec_accumulation", test_tile_shared_vec_accumulation, devices=devices
|
|
360
|
+
)
|
|
361
|
+
add_function_test(
|
|
362
|
+
TestTileSharedMemory,
|
|
363
|
+
"test_tile_shared_simple_reduction_add",
|
|
364
|
+
test_tile_shared_simple_reduction_add,
|
|
365
|
+
devices=devices,
|
|
366
|
+
)
|
|
367
|
+
add_function_test(
|
|
368
|
+
TestTileSharedMemory,
|
|
369
|
+
"test_tile_shared_simple_reduction_sub",
|
|
370
|
+
test_tile_shared_simple_reduction_sub,
|
|
371
|
+
devices=devices,
|
|
372
|
+
)
|
|
244
373
|
|
|
245
374
|
if __name__ == "__main__":
|
|
246
375
|
wp.clear_kernel_cache()
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
import warp as wp
|
|
21
|
+
from warp.tests.unittest_utils import *
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_sort_kernel(KEY_TYPE, MAX_SORT_LENGTH):
    """Return a tiled kernel that sorts ``MAX_SORT_LENGTH`` key/value pairs in one tile.

    ``KEY_TYPE`` and ``MAX_SORT_LENGTH`` are captured by the kernel closure, so every
    call yields a kernel specialized for that key type and tile length.

    The kernel:
    - loads the keys and the int32 values into shared-memory tiles;
    - reorders both tiles with ``wp.tile_sort``;
    - stores the results to the output arrays.
    """

    @wp.kernel
    def tile_sort_kernel(
        input_keys: wp.array(dtype=KEY_TYPE),
        input_values: wp.array(dtype=wp.int32),
        output_keys: wp.array(dtype=KEY_TYPE),
        output_values: wp.array(dtype=wp.int32),
    ):
        # Stage both arrays in shared-memory tiles.
        key_tile = wp.tile_load(input_keys, shape=MAX_SORT_LENGTH, storage="shared")
        value_tile = wp.tile_load(input_values, shape=MAX_SORT_LENGTH, storage="shared")

        # Sort the tiles in place; the values are permuted together with their keys.
        wp.tile_sort(key_tile, value_tile)

        # Write the sorted tiles back out.
        wp.tile_store(output_keys, key_tile)
        wp.tile_store(output_values, value_tile)

    return tile_sort_kernel
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_tile_sort(test, device):
    """Compare ``wp.tile_sort`` against a NumPy reference for int and float keys.

    One kernel is built per (key dtype, length) pair, with lengths ``2**i + 1`` for
    ``i`` in ``0..10``. Each kernel is launched with block dimensions 32 through 512.
    Values start as ``arange(length)``, so the output values reveal the permutation
    the sort applied to the keys.
    """
    # Forward-declare kernels for more efficient compilation
    kernels = {}
    for dtype in [int, float]:
        for i in range(0, 11):
            length = 2**i + 1
            kernels[(dtype, length)] = create_sort_kernel(dtype, length)

    for (dtype, length), kernel in kernels.items():
        for j in range(5, 10):
            TILE_DIM = 2**j

            rng = np.random.default_rng(42)  # Create a random generator instance

            if dtype == int:
                # Sampling without replacement gives distinct keys.
                np_keys = rng.choice(1000000000, size=length, replace=False)
            else:  # dtype == float
                # Warp's ``float`` is float32; round here so the NumPy reference sorts exactly
                # the keys the device sees. Keeping float64 keys (``astype(float)`` is a no-op)
                # lets keys that collide after rounding be ordered differently by the
                # reference than by the device.
                np_keys = rng.uniform(0, 1000000000, size=length).astype(np.float32)

            np_values = np.arange(length)

            # Generate random keys and iota indexer
            input_keys = wp.array(np_keys, dtype=dtype, device=device)
            input_values = wp.array(np_values, dtype=int, device=device)
            output_keys = wp.zeros_like(input_keys, device=device)
            output_values = wp.zeros_like(input_values, device=device)

            # Execute sorting kernel
            wp.launch_tiled(
                kernel,
                dim=1,
                inputs=[input_keys, input_values, output_keys, output_values],
                block_dim=TILE_DIM,
                device=device,
            )
            wp.synchronize()

            # Sort using NumPy for validation. A stable sort keeps the reference well defined
            # even if float32 rounding produced duplicate keys.
            sorted_indices = np.argsort(np_keys, kind="stable")
            np_sorted_keys = np_keys[sorted_indices]

            out_keys = output_keys.numpy()
            out_values = output_values.numpy()

            # Sorting only permutes keys, so the output must match the reference exactly.
            keys_match = np.array_equal(out_keys, np_sorted_keys)

            # The values must form a permutation of arange(length) that carries every key to
            # its sorted slot. Checking through the keys, rather than against one specific
            # permutation, stays valid when duplicate keys leave their values' relative order
            # unspecified. For distinct keys it is equivalent to an exact comparison.
            # The permutation check runs first so the fancy index below is always in range.
            values_match = np.array_equal(np.sort(out_values), np_values) and np.array_equal(
                np_keys[out_values], np_sorted_keys
            )

            context = f"dtype={dtype.__name__}, TILE_DIM={TILE_DIM}, length={length}"
            test.assertTrue(keys_match, f"Key sorting mismatch for {context}!")
            test.assertTrue(values_match, f"Value sorting mismatch for {context}!")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
devices = get_test_devices()


class TestTileSort(unittest.TestCase):
    """Test case populated with per-device tile_sort tests via add_function_test."""


# Register test_tile_sort once for every available test device.
add_function_test(TestTileSort, "test_tile_sort", test_tile_sort, devices=devices)

if __name__ == "__main__":
    # Start from an empty kernel cache so kernels are rebuilt for this run.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
|
warp/tests/unittest_suites.py
CHANGED
|
@@ -113,17 +113,18 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
113
113
|
from warp.tests.interop.test_dlpack import TestDLPack
|
|
114
114
|
from warp.tests.interop.test_jax import TestJax
|
|
115
115
|
from warp.tests.interop.test_torch import TestTorch
|
|
116
|
+
from warp.tests.sim.test_cloth import TestCloth
|
|
116
117
|
from warp.tests.sim.test_collision import TestCollision
|
|
117
118
|
from warp.tests.sim.test_coloring import TestColoring
|
|
118
119
|
from warp.tests.sim.test_model import TestModel
|
|
119
120
|
from warp.tests.sim.test_sim_grad import TestSimGradients
|
|
120
121
|
from warp.tests.sim.test_sim_kinematics import TestSimKinematics
|
|
121
|
-
from warp.tests.sim.test_vbd import TestVbd
|
|
122
122
|
from warp.tests.test_adam import TestAdam
|
|
123
123
|
from warp.tests.test_arithmetic import TestArithmetic
|
|
124
124
|
from warp.tests.test_array import TestArray
|
|
125
125
|
from warp.tests.test_array_reduce import TestArrayReduce
|
|
126
126
|
from warp.tests.test_atomic import TestAtomic
|
|
127
|
+
from warp.tests.test_atomic_cas import TestAtomicCAS
|
|
127
128
|
from warp.tests.test_bool import TestBool
|
|
128
129
|
from warp.tests.test_builtins_resolution import TestBuiltinsResolution
|
|
129
130
|
from warp.tests.test_closest_point_edge_edge import TestClosestPointEdgeEdgeMethods
|
|
@@ -166,7 +167,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
166
167
|
from warp.tests.test_mat_lite import TestMatLite
|
|
167
168
|
from warp.tests.test_mat_scalar_ops import TestMatScalarOps
|
|
168
169
|
from warp.tests.test_math import TestMath
|
|
169
|
-
from warp.tests.test_mlp import TestMLP
|
|
170
170
|
from warp.tests.test_module_hashing import TestModuleHashing
|
|
171
171
|
from warp.tests.test_modules_lite import TestModuleLite
|
|
172
172
|
from warp.tests.test_noise import TestNoise
|
|
@@ -193,13 +193,18 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
193
193
|
from warp.tests.test_types import TestTypes
|
|
194
194
|
from warp.tests.test_utils import TestUtils
|
|
195
195
|
from warp.tests.test_vec import TestVec
|
|
196
|
+
from warp.tests.test_vec_constructors import TestVecConstructors
|
|
196
197
|
from warp.tests.test_vec_lite import TestVecLite
|
|
197
198
|
from warp.tests.test_vec_scalar_ops import TestVecScalarOps
|
|
198
199
|
from warp.tests.test_verify_fp import TestVerifyFP
|
|
199
200
|
from warp.tests.tile.test_tile import TestTile
|
|
201
|
+
from warp.tests.tile.test_tile_load import TestTileLoad
|
|
200
202
|
from warp.tests.tile.test_tile_mathdx import TestTileMathDx
|
|
203
|
+
from warp.tests.tile.test_tile_matmul import TestTileMatmul
|
|
201
204
|
from warp.tests.tile.test_tile_reduce import TestTileReduce
|
|
202
205
|
from warp.tests.tile.test_tile_shared_memory import TestTileSharedMemory
|
|
206
|
+
from warp.tests.tile.test_tile_sort import TestTileSort
|
|
207
|
+
from warp.tests.tile.test_tile_view import TestTileView
|
|
203
208
|
|
|
204
209
|
test_classes = [
|
|
205
210
|
TestAdam,
|
|
@@ -208,10 +213,12 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
208
213
|
TestArrayReduce,
|
|
209
214
|
TestAsync,
|
|
210
215
|
TestAtomic,
|
|
216
|
+
TestAtomicCAS,
|
|
211
217
|
TestBool,
|
|
212
218
|
TestBuiltinsResolution,
|
|
213
219
|
TestBvh,
|
|
214
220
|
TestClosestPointEdgeEdgeMethods,
|
|
221
|
+
TestCloth,
|
|
215
222
|
TestCodeGen,
|
|
216
223
|
TestCodeGenInstancing,
|
|
217
224
|
TestCollision,
|
|
@@ -262,7 +269,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
262
269
|
TestMeshQueryAABBMethods,
|
|
263
270
|
TestMeshQueryPoint,
|
|
264
271
|
TestMeshQueryRay,
|
|
265
|
-
TestMLP,
|
|
266
272
|
TestModel,
|
|
267
273
|
TestModuleHashing,
|
|
268
274
|
TestModuleLite,
|
|
@@ -292,16 +298,20 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
292
298
|
TestStruct,
|
|
293
299
|
TestTape,
|
|
294
300
|
TestTile,
|
|
301
|
+
TestTileLoad,
|
|
295
302
|
TestTileMathDx,
|
|
303
|
+
TestTileMatmul,
|
|
296
304
|
TestTileReduce,
|
|
297
305
|
TestTileSharedMemory,
|
|
306
|
+
TestTileSort,
|
|
307
|
+
TestTileView,
|
|
298
308
|
TestTorch,
|
|
299
309
|
TestTransientModule,
|
|
300
310
|
TestTriangleClosestPoint,
|
|
301
311
|
TestTypes,
|
|
302
312
|
TestUtils,
|
|
303
|
-
TestVbd,
|
|
304
313
|
TestVec,
|
|
314
|
+
TestVecConstructors,
|
|
305
315
|
TestVecLite,
|
|
306
316
|
TestVecScalarOps,
|
|
307
317
|
TestVerifyFP,
|
|
@@ -350,7 +360,6 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
|
|
|
350
360
|
from warp.tests.test_lvalue import TestLValue
|
|
351
361
|
from warp.tests.test_mat_lite import TestMatLite
|
|
352
362
|
from warp.tests.test_math import TestMath
|
|
353
|
-
from warp.tests.test_mlp import TestMLP
|
|
354
363
|
from warp.tests.test_module_hashing import TestModuleHashing
|
|
355
364
|
from warp.tests.test_modules_lite import TestModuleLite
|
|
356
365
|
from warp.tests.test_noise import TestNoise
|
|
@@ -397,7 +406,6 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
|
|
|
397
406
|
TestMeshQueryAABBMethods,
|
|
398
407
|
TestMeshQueryPoint,
|
|
399
408
|
TestMeshQueryRay,
|
|
400
|
-
TestMLP,
|
|
401
409
|
TestModuleHashing,
|
|
402
410
|
TestModuleLite,
|
|
403
411
|
TestNoise,
|