warp-lang 1.6.2-py3-none-win_amd64.whl → 1.7.1-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0

warp/tests/{test_tile.py → tile/test_tile.py}

@@ -20,8 +20,6 @@ import numpy as np
 import warp as wp
 from warp.tests.unittest_utils import *

-wp.init() # For wp.context.runtime.core.is_mathdx_enabled()
-
 TILE_M = wp.constant(8)
 TILE_N = wp.constant(4)
 TILE_K = wp.constant(8)

@@ -216,7 +214,6 @@ def test_tile_binary_map(test, device):
     assert_np_equal(B_wp.grad.numpy(), B_grad)


-@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
 def test_tile_grouped_gemm(test, device):
     @wp.kernel
     def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):

@@ -256,60 +253,62 @@ def test_tile_grouped_gemm(test, device):
     assert_np_equal(C_wp.numpy(), C, 1e-6)


-
-def
-
-
-
-
+def test_tile_gemm(dtype):
+    def test(test, device):
+        @wp.kernel
+        def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
+            # output tile index
+            i, j = wp.tid()

-
+            sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)

-
-
-
+            M = A.shape[0]
+            N = B.shape[1]
+            K = A.shape[1]

-
+            count = int(K / TILE_K)

-
-
-
+            for k in range(0, count):
+                a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
+                b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))

-
-
+                # sum += a*b
+                wp.tile_matmul(a, b, sum)

-
+            wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))

-
-
-
+        M = TILE_M * 7
+        K = TILE_K * 6
+        N = TILE_N * 5

-
-
-
-
+        rng = np.random.default_rng(42)
+        A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
+        B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
+        C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))

-
-
-
+        A_wp = wp.array(A, requires_grad=True, device=device)
+        B_wp = wp.array(B, requires_grad=True, device=device)
+        C_wp = wp.array(C, requires_grad=True, device=device)

-
-
-
-
-
-
-
-
+        with wp.Tape() as tape:
+            wp.launch_tiled(
+                tile_gemm,
+                dim=(int(M / TILE_M), int(N / TILE_N)),
+                inputs=[A_wp, B_wp, C_wp],
+                block_dim=TILE_DIM,
+                device=device,
+            )

-
+        assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)

-
+        adj_C = np.ones_like(C)

-
+        tape.backward(grads={C_wp: wp.array(adj_C, device=device)})

-
-
+        assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
+        assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
+
+    return test


 @wp.kernel

@@ -550,7 +549,6 @@ def test_tile_transpose(test, device):
     assert_np_equal(output.numpy(), input.numpy().T)


-@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
 def test_tile_transpose_matmul(test, device):
     @wp.kernel
     def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):

@@ -572,9 +570,36 @@ def test_tile_transpose_matmul(test, device):


 @wp.kernel
-def
+def test_tile_broadcast_add_1d_kernel(
+    input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
+):
+    a = wp.tile_load(input_a, shape=(10,))
+    b = wp.tile_load(input_b, shape=(1,))
+
+    c = wp.tile_broadcast(b, shape=(10,))
+    d = a + c
+
+    wp.tile_store(output, d)
+
+
+def test_tile_broadcast_add_1d(test, device):
+    N = 10
+
+    # implicit 1-dim ([1], 1)
+    a = wp.array(np.arange(0, N, dtype=np.float32), device=device)
+    b = wp.array(np.ones(1, dtype=np.float32), device=device)
+    out = wp.zeros((N,), dtype=float, device=device)
+
+    wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+
+    assert_np_equal(out.numpy(), a.numpy() + b.numpy())
+
+
+@wp.kernel
+def test_tile_broadcast_add_2d_kernel(
     input_a: wp.array2d(dtype=float), input_b: wp.array(dtype=float), output: wp.array2d(dtype=float)
 ):
+    # implicit 1-dim ([1], 10)
     a = wp.tile_load(input_a, shape=(10, 10))
     b = wp.tile_load(input_b, shape=10)

@@ -584,7 +609,7 @@ def test_tile_broadcast_add_kernel(
     wp.tile_store(output, d)


-def
+def test_tile_broadcast_add_2d(test, device):
     M = 10
     N = 10

@@ -592,7 +617,62 @@ def test_tile_broadcast_add(test, device):
     b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
     out = wp.zeros((M, N), dtype=float, device=device)

-    wp.launch_tiled(
+    wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+
+    assert_np_equal(out.numpy(), a.numpy() + b.numpy())
+
+
+@wp.kernel
+def test_tile_broadcast_add_3d_kernel(
+    input_a: wp.array3d(dtype=float), input_b: wp.array3d(dtype=float), output: wp.array3d(dtype=float)
+):
+    a = wp.tile_load(input_a, shape=(4, 10, 12))
+    b = wp.tile_load(input_b, shape=(4, 10, 1))
+
+    c = wp.tile_broadcast(b, shape=(4, 10, 12))
+    d = a + c
+
+    wp.tile_store(output, d)
+
+
+def test_tile_broadcast_add_3d(test, device):
+    M = 4
+    N = 10
+    O = 12
+
+    # explicit 1-dim (M, N, 1) to (M, N, O)
+    a = wp.array(np.ones((M, N, O), dtype=np.float32), device=device)
+    b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
+    out = wp.zeros((M, N, O), dtype=float, device=device)
+
+    wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+    assert_np_equal(out.numpy(), a.numpy() + b.numpy())
+
+
+@wp.kernel
+def test_tile_broadcast_add_4d_kernel(
+    input_a: wp.array4d(dtype=float), input_b: wp.array4d(dtype=float), output: wp.array4d(dtype=float)
+):
+    a = wp.tile_load(input_a, shape=(4, 10, 5, 6))
+    b = wp.tile_load(input_b, shape=(4, 1, 5, 1))
+    c = wp.tile_broadcast(b, shape=(4, 10, 5, 6))
+    d = a + c
+
+    wp.tile_store(output, d)
+
+
+def test_tile_broadcast_add_4d(test, device):
+    M = 4
+    N = 10
+    O = 5
+    P = 6
+
+    # explicit 1-dims (M, 1, O, 1) to (M, N, O, P)
+    a = wp.array(np.ones((M, N, O, P), dtype=np.float32), device=device)
+    b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
+    out = wp.zeros((M, N, O, P), dtype=float, device=device)
+
+    wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)

     assert_np_equal(out.numpy(), a.numpy() + b.numpy())

@@ -665,7 +745,7 @@ def test_tile_print(test, device):
     wp.synchronize()


-devices =
+devices = get_test_devices()


 class TestTile(unittest.TestCase):

@@ -677,15 +757,20 @@ add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devi
 add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
 add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
 add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
-add_function_test(TestTile, "
+add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
+add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
+add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
 add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
 add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
 add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
-add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices)
+add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, check_output=False)
 add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
 add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
 add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
-add_function_test(TestTile, "
+add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
+add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
+add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
+add_function_test(TestTile, "test_tile_broadcast_add_4d", test_tile_broadcast_add_4d, devices=devices)
 add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
 add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
 add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
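
For reference, the dtype-parameterized GEMM test added above boils down to the standalone sketch below. It is not code shipped in the wheel: the TILE_DIM value and the float64 specialization are assumptions made here for illustration, and the kernel simply mirrors the tile_load / tile_matmul / tile_store pattern exercised by the test (the shipped test wraps the same kernel in a factory so it can be registered once per dtype).

import numpy as np
import warp as wp

TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
TILE_K = wp.constant(8)
TILE_DIM = 64  # assumed thread-block size; the test module defines its own constant


@wp.kernel
def tile_gemm_f64(A: wp.array2d(dtype=wp.float64), B: wp.array2d(dtype=wp.float64), C: wp.array2d(dtype=wp.float64)):
    # one logical (i, j) index per output tile
    i, j = wp.tid()
    acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float64)
    K = A.shape[1]
    count = int(K / TILE_K)
    for k in range(0, count):
        a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
        b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
        wp.tile_matmul(a, b, acc)  # acc += a @ b
    wp.tile_store(C, acc, offset=(i * TILE_M, j * TILE_N))


M, K, N = TILE_M * 7, TILE_K * 6, TILE_N * 5
rng = np.random.default_rng(42)
A = wp.array(rng.random((M, K)), dtype=wp.float64)
B = wp.array(rng.random((K, N)), dtype=wp.float64)
C = wp.zeros((M, N), dtype=wp.float64)

wp.launch_tiled(tile_gemm_f64, dim=(int(M / TILE_M), int(N / TILE_N)), inputs=[A, B, C], block_dim=TILE_DIM)
np.testing.assert_allclose(C.numpy(), A.numpy() @ B.numpy(), rtol=1e-6)
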

warp/tests/{test_tile_load.py → tile/test_tile_load.py}

@@ -184,6 +184,96 @@ def test_tile_load_unaligned(test, device):
     assert_np_equal(input.grad.numpy(), expected_grad)


+@wp.kernel
+def tile_load_aligned_small_kernel(
+    input: wp.array2d(dtype=float),
+    output: wp.array2d(dtype=float),
+):
+    t = wp.tile_load(input, shape=(3, 3), offset=(0, 0), storage="shared")
+    wp.tile_store(output, t, offset=(0, 0))
+
+
+# regression test for tiles that are smaller than sizeof(float4) in that last
+# dimension but are aligned to float4. Did trigger the fast float4 path by accident.
+def test_tile_load_aligned_small(test, device):
+    rng = np.random.default_rng(42)
+
+    shape = [TILE_M, TILE_N]
+
+    input = wp.array(rng.random(shape), dtype=float, requires_grad=True, device=device)
+    output = wp.zeros(shape, dtype=float, device=device)
+
+    wp.launch_tiled(
+        tile_load_aligned_small_kernel,
+        dim=[1],
+        inputs=[input, output],
+        block_dim=TILE_DIM,
+        device=device,
+    )
+
+    # zeros except for the 3x3 tile at 0, 0
+    assert_np_equal(output.numpy()[3:, :], np.zeros((TILE_M - 3, TILE_N)))
+    assert_np_equal(output.numpy()[:, 3:], np.zeros((TILE_M, TILE_N - 3)))
+
+    # check output elements
+    assert_np_equal(output.numpy()[:3, :3], input.numpy()[:3, :3])
+
+
+TILE_WIDTH = 5
+TILE_OFFSET_X = 0
+TILE_OFFSET_Y = 8
+
+
+@wp.kernel
+def test_tile_load_aligned_offset_unaligned_size_kernel(
+    input: wp.array2d(dtype=float),
+    output: wp.array2d(dtype=float),
+):
+    # Load a 5x5 tile from the input array starting at offset (0,8)
+    # and store it in shared memory
+    tile = wp.tile_load(input, shape=(TILE_WIDTH, TILE_WIDTH), offset=(TILE_OFFSET_X, TILE_OFFSET_Y), storage="shared")
+
+    # Store the loaded tile back to the output array at the same offset
+    wp.tile_store(output, tile, offset=(TILE_OFFSET_X, TILE_OFFSET_Y))
+
+
+def test_tile_load_aligned_offset_unaligned_size(test, device):
+    """Test loading a tile with aligned offset but unaligned size."""
+
+    rng = np.random.default_rng(42)
+    array_shape = [TILE_N, TILE_M]
+
+    input_array = wp.array(rng.random(array_shape), dtype=float, requires_grad=True, device=device)
+    output_array = wp.zeros(array_shape, dtype=float, device=device)
+
+    wp.launch_tiled(
+        test_tile_load_aligned_offset_unaligned_size_kernel,
+        dim=[1],
+        inputs=[input_array, output_array],
+        block_dim=TILE_DIM,
+        device=device,
+    )
+
+    # Region before the tile offset should be zeros
+    assert_np_equal(output_array.numpy()[:TILE_WIDTH, :TILE_OFFSET_Y], np.zeros((TILE_WIDTH, TILE_OFFSET_Y)))
+
+    # Region where the tile was loaded/stored should match input
+    assert_np_equal(
+        output_array.numpy()[:TILE_WIDTH, TILE_OFFSET_Y : TILE_OFFSET_Y + TILE_WIDTH],
+        input_array.numpy()[:TILE_WIDTH, TILE_OFFSET_Y : TILE_OFFSET_Y + TILE_WIDTH],
+    )
+
+    # Region after the tile should be zeros
+    remaining_width = TILE_M - (TILE_OFFSET_Y + TILE_WIDTH)
+    assert_np_equal(
+        output_array.numpy()[:TILE_WIDTH, TILE_OFFSET_Y + TILE_WIDTH :], np.zeros((TILE_WIDTH, remaining_width))
+    )
+
+    # Rows below the tile should all be zeros
+    remaining_height = TILE_N - TILE_WIDTH
+    assert_np_equal(output_array.numpy()[TILE_WIDTH:, :], np.zeros((remaining_height, TILE_M)))
+
+
 # ----------------------------------------------------------------------------------------

 TILE_SIZE = 4

@@ -376,7 +466,7 @@ def test_tile_load_fortran(test, device):
     assert_array_equal(B_wp.grad, A_wp.grad)


-devices =
+devices = get_test_devices()


 class TestTileLoad(unittest.TestCase):

@@ -388,6 +478,13 @@ add_function_test(TestTileLoad, "test_tile_load_2d", test_tile_load(tile_load_2d
 add_function_test(TestTileLoad, "test_tile_load_3d", test_tile_load(tile_load_3d_kernel, 3), devices=devices)
 add_function_test(TestTileLoad, "test_tile_load_4d", test_tile_load(tile_load_4d_kernel, 4), devices=devices)
 add_function_test(TestTileLoad, "test_tile_load_unaligned", test_tile_load_unaligned, devices=devices)
+add_function_test(TestTileLoad, "test_tile_load_aligned_small", test_tile_load_aligned_small, devices=devices)
+add_function_test(
+    TestTileLoad,
+    "test_tile_load_aligned_offset_unaligned_size",
+    test_tile_load_aligned_offset_unaligned_size,
+    devices=devices,
+)

 add_function_test(TestTileLoad, "test_tile_extract_1d", test_tile_extract(tile_extract_1d_kernel, 1), devices=devices)
 add_function_test(TestTileLoad, "test_tile_extract_2d", test_tile_extract(tile_extract_2d_kernel, 2), devices=devices)
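
The two regression tests added above pin down tile_load/tile_store behaviour for tiles that are smaller than the backing array (and, in the first case, smaller than a float4 in the last dimension): only the addressed region is read and written. A minimal sketch of that access pattern follows; the kernel name and the array sizes are chosen here for illustration and are not the test's module constants.

import numpy as np
import warp as wp


@wp.kernel
def copy_subtile(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # load a 5x5 tile starting at column 8 and write it back to the same place
    t = wp.tile_load(src, shape=(5, 5), offset=(0, 8), storage="shared")
    wp.tile_store(dst, t, offset=(0, 8))


src = wp.array(np.random.rand(8, 16), dtype=float)
dst = wp.zeros((8, 16), dtype=float)
wp.launch_tiled(copy_subtile, dim=[1], inputs=[src, dst], block_dim=64)

# only the 5x5 region at (0, 8) is written; everything else stays zero
assert np.allclose(dst.numpy()[:5, 8:13], src.numpy()[:5, 8:13])
assert np.count_nonzero(dst.numpy()) <= 25
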

warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py}

@@ -92,6 +92,7 @@ def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dt
     wp.tile_store(gy, xy)


+@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
 def test_tile_math_fft(test, device, wp_dtype):
     np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
     np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]

@@ -172,31 +173,33 @@ def test_tile_math_cholesky(test, device):
     # TODO: implement and test backward pass


-
+all_devices = get_test_devices()
+cuda_devices = get_cuda_test_devices()


-@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
 class TestTileMathDx(unittest.TestCase):
     pass


 # check_output=False so we can enable libmathdx's logging without failing the tests
-add_function_test(TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=devices, check_output=False)
 add_function_test(
-    TestTileMathDx, "
+    TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=all_devices, check_output=False
+)
+add_function_test(
+    TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=all_devices, check_output=False
 )
 add_function_test(
     TestTileMathDx,
     "test_tile_math_fft_vec2f",
     functools.partial(test_tile_math_fft, wp_dtype=wp.vec2f),
-    devices=
+    devices=cuda_devices,
     check_output=False,
 )
 add_function_test(
     TestTileMathDx,
     "test_tile_math_fft_vec2d",
     functools.partial(test_tile_math_fft, wp_dtype=wp.vec2d),
-    devices=
+    devices=cuda_devices,
     check_output=False,
 )


warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py}

@@ -22,11 +22,6 @@ import warp.examples
 import warp.optim
 from warp.tests.unittest_utils import *

-wp.init()
-
-# needs to be constant for the whole module
-NUM_THREADS = 32
-

 def create_layer(rng, dim_in, dim_hid, dtype=float):
     w = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))

@@ -45,10 +40,12 @@ def create_array(rng, dim_in, dim_hid, dtype=float):
     return a


-@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
 def test_multi_layer_nn(test, device):
     import torch as tc

+    if device.is_cuda and not wp.context.runtime.core.is_mathdx_enabled():
+        test.skipTest("Skipping test on CUDA device without MathDx (tolerance)")
+
     NUM_FREQ = wp.constant(8)

     DIM_IN = wp.constant(4 * NUM_FREQ) # sin,cos for both x,y at each frequency

@@ -60,7 +57,13 @@ def test_multi_layer_nn(test, device):

     BATCH_SIZE = min(512, int((IMG_WIDTH * IMG_HEIGHT) / 8))

+    if device.is_cpu:
+        NUM_THREADS = 1
+    else:
+        NUM_THREADS = 32
+
     dtype = wp.float16
+    npdtype = wp.types.warp_type_to_np_dtype[dtype]

     @wp.func
     def relu(x: dtype):

@@ -74,7 +77,7 @@ def test_multi_layer_nn(test, device):
     def zero(loss: wp.array(dtype=float)):
         loss[0] = 0.0

-    @wp.kernel
+    @wp.kernel(module="unique")
     def compute(
         batches: wp.array(dtype=int),
         input: wp.array2d(dtype=dtype),

@@ -170,7 +173,9 @@ def test_multi_layer_nn(test, device):
     input = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_IN, dtype=dtype)
     output = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_OUT)

-    reference_np =
+    reference_np = (
+        np.load(os.path.join(os.path.dirname(__file__), "..", "assets", "pixel.npy"), allow_pickle=True) / 255.0
+    )
     reference = wp.array(reference_np, dtype=float)

     assert reference.shape[1] == IMG_WIDTH * IMG_HEIGHT

@@ -232,7 +237,7 @@ def test_multi_layer_nn(test, device):
     z_np = np.maximum(weights_3.numpy() @ z_np + bias_3.numpy(), 0.0)

     # test numpy forward
-    assert_np_equal(output.numpy()[:, indices], z_np, tol=1.0e-2)
+    assert_np_equal(output.numpy()[:, indices].astype(npdtype), z_np, tol=1.0e-2)

     # torch
     input_tc = tc.tensor(input.numpy()[:, indices], requires_grad=True, device=torch_device)

@@ -260,7 +265,9 @@ def test_multi_layer_nn(test, device):
     l_tc.backward()

     # test torch
-    assert_np_equal(
+    assert_np_equal(
+        z_tc.cpu().detach().numpy(), output.numpy()[:, indices].astype(npdtype), tol=1.0e-2
+    )
     assert_np_equal(weights_0.grad.numpy(), weights_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
     assert_np_equal(bias_0.grad.numpy(), bias_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
     assert_np_equal(weights_1.grad.numpy(), weights_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)

@@ -277,7 +284,6 @@ def test_multi_layer_nn(test, device):
     test.assertLess(loss.numpy()[0], 0.002)


-@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
 def test_single_layer_nn(test, device):
     import torch as tc


@@ -287,11 +293,16 @@ def test_single_layer_nn(test, device):

     NUM_BLOCKS = 56

+    if device.is_cpu:
+        NUM_THREADS = 1
+    else:
+        NUM_THREADS = 32
+
     @wp.func
     def relu(x: float):
         return wp.max(x, 0.0)

-    @wp.kernel
+    @wp.kernel(module="unique")
     def compute(
         input: wp.array2d(dtype=float),
         weights: wp.array2d(dtype=float),

@@ -353,7 +364,6 @@ try:
     import torch

     # check which Warp devices work with Torch
-    # CUDA devices may fail if Torch was not compiled with CUDA support
     torch_compatible_devices = []
     torch_compatible_cuda_devices = []


@@ -372,7 +382,7 @@ try:
         "test_single_layer_nn",
         test_single_layer_nn,
         check_output=False,
-        devices=
+        devices=torch_compatible_devices,
     )
     add_function_test(
         TestTileMLP,

@@ -388,4 +398,5 @@ except Exception as e:

 if __name__ == "__main__":
     wp.clear_kernel_cache()
+    wp.clear_lto_cache()
     unittest.main(verbosity=2, failfast=True)
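
Both MLP tests above now decorate the kernels they build inside the test body with @wp.kernel(module="unique"), which appears to give each closure-specialised kernel its own module so differently specialised copies (different dtype, NUM_THREADS, and so on) do not collide in the kernel cache. A rough sketch of that pattern, under those assumptions; make_scale_kernel is a hypothetical helper, not part of the package.

import warp as wp


def make_scale_kernel(factor: float):
    # each call produces a kernel specialised on the captured `factor`;
    # module="unique" keeps these generated copies in separate modules
    @wp.kernel(module="unique")
    def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i = wp.tid()
        y[i] = x[i] * factor

    return scale


double = make_scale_kernel(2.0)
x = wp.array([1.0, 2.0, 3.0], dtype=float)
y = wp.zeros(3, dtype=float)
wp.launch(double, dim=3, inputs=[x, y])
print(y.numpy())  # [2. 4. 6.]
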

warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py}

@@ -176,6 +176,64 @@ def test_tile_reduce_custom(test, device):
         test.assertAlmostEqual(prod_wp[i], prod_np, places=4)


+@wp.struct
+class KeyValue:
+    key: wp.int32
+    value: wp.float32
+
+
+@wp.func
+def kv_max(a: KeyValue, b: KeyValue) -> KeyValue:
+    return wp.where(a.value < b.value, b, a)
+
+
+@wp.kernel
+def initialize_key_value(values: wp.array2d(dtype=wp.float32), keyvalues: wp.array2d(dtype=KeyValue)):
+    batch, idx = wp.tid()
+    keyvalues[batch, idx] = KeyValue(idx, values[batch, idx])
+
+
+@wp.kernel(enable_backward=False)
+def tile_reduce_custom_struct_kernel(values: wp.array2d(dtype=KeyValue), res: wp.array(dtype=KeyValue)):
+    # output tile index
+    i = wp.tid()
+
+    t = wp.tile_load(values, shape=(1, TILE_DIM), offset=(i, 0))
+
+    max_el = wp.tile_reduce(kv_max, t)
+    wp.tile_store(res, max_el, offset=i)
+
+
+def test_tile_reduce_custom_struct(test, device):
+    batch_count = 56
+
+    N = TILE_DIM
+
+    rng = np.random.default_rng(42)
+    input = rng.random((batch_count, N), dtype=np.float32)
+
+    input_wp = wp.array(input, dtype=wp.float32, device=device)
+    keyvalues_wp = wp.empty(input_wp.shape, dtype=KeyValue, device=device)
+
+    wp.launch(initialize_key_value, dim=[batch_count, N], inputs=[input_wp], outputs=[keyvalues_wp], device=device)
+
+    output_wp = wp.empty(batch_count, dtype=KeyValue, device=device)
+
+    wp.launch_tiled(
+        tile_reduce_custom_struct_kernel,
+        dim=[batch_count],
+        inputs=[keyvalues_wp],
+        outputs=[output_wp],
+        block_dim=TILE_DIM,
+        device=device,
+    )
+
+    prod_wp = np.array([k for k, v in output_wp.numpy()])
+    expected = np.argmax(input, axis=1)
+
+    assert_np_equal(prod_wp, expected)
+
+
 @wp.kernel
 def tile_grouped_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
     # output tile index

@@ -365,7 +423,7 @@ def test_tile_arange(test, device):
     assert_np_equal(output.numpy()[4], np.arange(17, 0, -1))


-devices =
+devices = get_test_devices()


 class TestTileReduce(unittest.TestCase):

@@ -376,6 +434,7 @@ add_function_test(TestTileReduce, "test_tile_reduce_sum", test_tile_reduce_sum,
 add_function_test(TestTileReduce, "test_tile_reduce_min", test_tile_reduce_min, devices=devices)
 add_function_test(TestTileReduce, "test_tile_reduce_max", test_tile_reduce_max, devices=devices)
 add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_custom, devices=devices)
+add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
 add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_sum, devices=devices)
 add_function_test(TestTileReduce, "test_tile_reduce_simt", test_tile_reduce_simt, devices=devices)
 add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devices)