warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/tests/test_array.py
CHANGED
|
@@ -412,7 +412,7 @@ def test_slicing(test, device):
|
|
|
412
412
|
assert_array_equal(wp_arr[:5], wp.array(np_arr[:5], dtype=int, device=device))
|
|
413
413
|
assert_array_equal(wp_arr[1:5], wp.array(np_arr[1:5], dtype=int, device=device))
|
|
414
414
|
assert_array_equal(wp_arr[-9:-5:1], wp.array(np_arr[-9:-5:1], dtype=int, device=device))
|
|
415
|
-
assert_array_equal(wp_arr[:5,], wp.array(np_arr[:5], dtype=int, device=device))
|
|
415
|
+
assert_array_equal(wp_arr[:5,], wp.array(np_arr[:5], dtype=int, device=device))
|
|
416
416
|
|
|
417
417
|
|
|
418
418
|
def test_view(test, device):
|
|
@@ -2370,6 +2370,257 @@ def test_array_from_cai(test, device):
|
|
|
2370
2370
|
assert_np_equal(arr_warp.numpy(), np.array([[2, 1, 1], [1, 0, 0], [1, 0, 0]]))
|
|
2371
2371
|
|
|
2372
2372
|
|
|
2373
|
+
def test_array_from_data(test, device):
    """Check re-wrapping an existing warp array via ``wp.array(src, dtype=..., shape=...)``.

    Three groups of cases are covered:

    * scalar sources, keeping (``dtype=Any``) or restating the dtype, with and
      without a new ``shape``;
    * trailing scalar dimensions packed into ``wp.vec3`` / ``wp.mat22`` elements,
      with and without an explicit ``shape``;
    * the same packing applied to sliced, non-contiguous (strided) sources.

    Every check uses ``test.assertEqual`` rather than a bare ``assert`` so a failure
    reports both values and the checks are not stripped when Python runs with ``-O``.
    """

    def seq(count, *shape):
        # float32 values 0 .. count-1, optionally reshaped to `shape`.
        values = np.arange(count, dtype=np.float32)
        return values.reshape(shape) if shape else values

    with wp.ScopedDevice(device):
        # =========================================
        # scalars, reshaping

        data = seq(12, 3, 4)
        src = wp.array(data)

        test.assertEqual(src.device, device)

        for dtype in (Any, wp.float32):
            for shape in (None, (3, 4), (12,), (3, 2, 2)):
                with test.subTest(msg=f"scalar, dtype={dtype}, shape={shape}"):
                    dst = wp.array(src, dtype=dtype, shape=shape)
                    test.assertEqual(dst.device, src.device)
                    # dtype=Any keeps the source dtype; shape=None keeps the source shape.
                    test.assertEqual(dst.dtype, src.dtype if dtype is Any else dtype)
                    expected_shape = src.shape if shape is None else shape
                    test.assertEqual(dst.shape, expected_shape)
                    assert_np_equal(dst.numpy(), data.reshape(expected_shape))

        # =========================================
        # vectors and matrices: trailing scalar dimensions are packed into the element type

        # NumPy shape of a single element of each packed dtype.
        elem_shape = {wp.vec3: (3,), wp.mat22: (2, 2)}

        # (subtest message, source data, optional source slice, dtype, `shape` argument, expected warp shape)
        cases = (
            # vectors, reshaping
            ("vector, single", seq(3), None, wp.vec3, None, (1,)),
            ("vector, multiple in 1d", seq(12), None, wp.vec3, None, (4,)),
            ("vector, singles in 2d", seq(12, 4, 3), None, wp.vec3, None, (4,)),
            ("vector, multiples in 2d", seq(24, 4, 6), None, wp.vec3, None, (4, 2)),
            ("vector, singles in 2d, reshape", seq(12, 4, 3), None, wp.vec3, (2, 2), (2, 2)),
            ("vector, multiples in 2d, reshape", seq(24, 4, 6), None, wp.vec3, (2, 2, 2), (2, 2, 2)),
            # matrices, reshaping
            ("matrix, single in 2d", seq(4, 2, 2), None, wp.mat22, None, (1,)),
            ("matrix, single in 1d", seq(4), None, wp.mat22, None, (1,)),
            ("matrix, multiples in 1d", seq(12), None, wp.mat22, None, (3,)),
            ("matrix, multiples in 1d, reshape", seq(16), None, wp.mat22, (4,), (4,)),
            ("matrix, multiples in 2d", seq(12, 3, 4), None, wp.mat22, None, (3,)),
            ("matrix, multiples in 2d, reshape", seq(16, 4, 4), None, wp.mat22, (2, 2), (2, 2)),
            ("matrix, multiples in 3d", seq(12, 3, 2, 2), None, wp.mat22, None, (3,)),
            ("matrix, multiples in 3d, reshape", seq(16, 4, 2, 2), None, wp.mat22, (2, 2), (2, 2)),
            # vectors and matrices in strided arrays
            ("vector, singles in 2d, strided", seq(20, 4, 5), np.s_[:, 2:], wp.vec3, None, (4,)),
            ("vector, multiples in 2d, strided", seq(14, 2, 7), np.s_[:, 1:], wp.vec3, None, (2, 2)),
            ("matrix, multiples in 2d, strided", seq(15, 3, 5), np.s_[:, 1:], wp.mat22, None, (3,)),
            ("matrix, multiples in 3d, strided", seq(18, 3, 3, 2), np.s_[:, 1:], wp.mat22, None, (3,)),
        )

        for msg, data, view, dtype, shape, expected_shape in cases:
            with test.subTest(msg=msg):
                src = wp.array(data)
                host = data
                if view is not None:
                    # Slicing the warp array yields a non-contiguous (strided) source;
                    # the same slice on the NumPy data gives the expected values.
                    src = src[view]
                    host = data[view]
                dst = wp.array(src, dtype=dtype, shape=shape)
                test.assertEqual(dst.dtype, dtype)
                test.assertEqual(dst.shape, expected_shape)
                assert_np_equal(dst.numpy(), host.reshape(expected_shape + elem_shape[dtype]))
|
|
2622
|
+
|
|
2623
|
+
|
|
2373
2624
|
@wp.kernel
|
|
2374
2625
|
def inplace_add_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
|
|
2375
2626
|
i = wp.tid()
|
|
@@ -2604,7 +2855,7 @@ def test_array_inplace_non_diff_ops(test, device):
|
|
|
2604
2855
|
wp.launch(inplace_div_1d, N, inputs=[x1, y1], device=device)
|
|
2605
2856
|
assert_np_equal(x1.numpy(), np.full(N, fill_value=2.0, dtype=float))
|
|
2606
2857
|
|
|
2607
|
-
for dtype in wp.types.non_atomic_types
|
|
2858
|
+
for dtype in (*wp.types.non_atomic_types, wp.vec2b, wp.vec2ub, wp.vec2s, wp.vec2us, uint16vec3):
|
|
2608
2859
|
x = wp.full(N, value=0, dtype=dtype, device=device)
|
|
2609
2860
|
y = wp.full(N, value=1, dtype=dtype, device=device)
|
|
2610
2861
|
|
|
@@ -2943,6 +3194,7 @@ add_function_test(TestArray, "test_alloc_strides", test_alloc_strides, devices=d
|
|
|
2943
3194
|
add_function_test(TestArray, "test_casting", test_casting, devices=devices)
|
|
2944
3195
|
add_function_test(TestArray, "test_array_len", test_array_len, devices=devices)
|
|
2945
3196
|
add_function_test(TestArray, "test_cuda_interface_conversion", test_cuda_interface_conversion, devices=devices)
|
|
3197
|
+
add_function_test(TestArray, "test_array_from_data", test_array_from_data, devices=devices)
|
|
2946
3198
|
|
|
2947
3199
|
try:
|
|
2948
3200
|
import torch
|
warp/tests/test_array_reduce.py
CHANGED
|
@@ -28,7 +28,7 @@ def make_test_array_sum(dtype):
|
|
|
28
28
|
def test_array_sum(test, device):
|
|
29
29
|
rng = np.random.default_rng(123)
|
|
30
30
|
|
|
31
|
-
cols = wp.types.
|
|
31
|
+
cols = wp.types.type_size(dtype)
|
|
32
32
|
|
|
33
33
|
values_np = rng.random(size=(N, cols))
|
|
34
34
|
values = wp.array(values_np, device=device, dtype=dtype)
|
|
@@ -77,7 +77,7 @@ def make_test_array_inner(dtype):
|
|
|
77
77
|
def test_array_inner(test, device):
|
|
78
78
|
rng = np.random.default_rng(123)
|
|
79
79
|
|
|
80
|
-
cols = wp.types.
|
|
80
|
+
cols = wp.types.type_size(dtype)
|
|
81
81
|
|
|
82
82
|
a_np = rng.random(size=(N, cols))
|
|
83
83
|
b_np = rng.random(size=(N, cols))
|
warp/tests/test_assert.py
CHANGED
|
@@ -245,6 +245,59 @@ class TestAssertDebug(unittest.TestCase):
|
|
|
245
245
|
self.assertRegex(output, r"Assertion failed: .*assert value == 1.*Array element must be 1")
|
|
246
246
|
|
|
247
247
|
|
|
248
|
+
class TestAssertModeSwitch(unittest.TestCase):
    """Check that changing ``wp.config.mode`` to ``"debug"`` takes effect for newly defined kernels.

    A launch in the starting mode must write nothing to stderr. After switching
    ``wp.config.mode`` to ``"debug"``, a freshly defined kernel whose ``assert``
    fails must report the assertion on stderr.
    """

    @classmethod
    def setUpClass(cls):
        # Snapshot the global mode, the module-level "mode" option and the kernel-cache
        # flag so tearDownClass can put them back.
        cls._saved_mode = wp.config.mode
        cls._saved_mode_module = wp.get_module_options()["mode"]
        cls._saved_cache_kernels = wp.config.cache_kernels

        # Don't set any mode initially - use whatever the default is.
        # NOTE(review): test_switch_to_debug_mode expects this default to produce no
        # assertion output (i.e. to behave like "release"); confirm this holds for every
        # way the suite is launched.
        # Kernel caching is disabled, presumably so the debug build is compiled fresh
        # rather than loaded from a previously cached binary.
        wp.config.cache_kernels = False

    @classmethod
    def tearDownClass(cls):
        # Restore the settings captured in setUpClass (the test switches wp.config.mode to "debug").
        wp.config.mode = cls._saved_mode
        wp.set_module_options({"mode": cls._saved_mode_module})
        wp.config.cache_kernels = cls._saved_cache_kernels

    def test_switch_to_debug_mode(self):
        """Launch in the default mode and expect silence, then switch to debug and expect an assertion message."""
        with wp.ScopedDevice("cpu"):
            # Zero-filled input: an enabled `assert a[i] == 1` fails for every element.
            input_array = wp.zeros(1, dtype=int)

            # In default mode, this should not assert.
            # `expect_ones` is a kernel defined elsewhere in this file (not visible here).
            capture = StdErrCapture()
            capture.begin()
            wp.launch(expect_ones, input_array.shape, inputs=[input_array])
            output = capture.end()

            # Should not have any assertion output in release mode
            self.assertEqual(output, "", f"Kernel should not print anything to stderr in release mode, got {output}")

            # Now switch to debug mode and have it compile a new kernel
            # (restored by tearDownClass).
            wp.config.mode = "debug"

            @wp.kernel
            def expect_ones_debug(a: wp.array(dtype=int)):
                i = wp.tid()
                assert a[i] == 1

            # In debug mode, this should assert
            capture = StdErrCapture()
            capture.begin()
            wp.launch(expect_ones_debug, input_array.shape, inputs=[input_array])
            output = capture.end()

            # Should have assertion output in debug mode.
            # Older Windows C runtimes sometimes fail to flush the assertion output, so an
            # empty stderr capture is tolerated on win32. Any non-empty output, and all
            # output on other platforms, must contain the assertion message.
            if output != "" or sys.platform != "win32":
                self.assertRegex(output, r"Assertion failed: .*assert a\[i\] == 1")
|
|
299
|
+
|
|
300
|
+
|
|
248
301
|
if __name__ == "__main__":
|
|
249
302
|
wp.clear_kernel_cache()
|
|
250
303
|
unittest.main(verbosity=2)
|
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
import unittest
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
import warp as wp
|
|
20
|
+
from warp.tests.unittest_utils import *
|
|
21
|
+
|
|
22
|
+
# Memo of kernels built by getkernel(), keyed by "<function name>_<suffix>", so each
# (function, suffix) pair is turned into a wp.Kernel only once per process.
kernel_cache = {}


def getkernel(func, suffix=""):
    """Return the memoized ``wp.Kernel`` for ``func``, building it on first request.

    The cache key is ``func.__name__ + "_" + suffix``. The same string is also
    used as the kernel's ``key``, so different suffixes (e.g. one per dtype)
    produce distinct kernels from the same Python function.
    """
    key = "_".join((func.__name__, suffix))

    try:
        return kernel_cache[key]
    except KeyError:
        pass  # first request for this key: build and memoize below

    kernel = wp.Kernel(func=func, key=key)
    kernel_cache[key] = kernel
    return kernel
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def test_atomic_cas(test, device, dtype, register_kernels=False):
    """Increment a shared counter from ``n`` threads, guarded by a CAS spin lock stored in a 1D array.

    Args:
        test: The test case running the check (unused on the registration path).
        device: Device to allocate on and launch on.
        dtype: NumPy scalar type; mapped to the matching warp type for both the
            counter and the lock.
        register_kernels: When True, only build/look up the kernel via ``getkernel``
            and return, without allocating arrays or launching.
    """
    warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    @wp.func
    def spinlock_acquire_1d(lock: wp.array(dtype=warp_type)):
        # atomic_cas returns the previous value: 1 means another thread holds the lock, so spin.
        while wp.atomic_cas(lock, 0, warp_type(0), warp_type(1)) == 1:
            pass

    @wp.func
    def spinlock_release_1d(lock: wp.array(dtype=warp_type)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, warp_type(0))

    @wp.func
    def volatile_read_1d(ptr: wp.array(dtype=warp_type), index: int):
        # Read ptr[index] through atomics (swap in 0, then write the old value back) as a
        # stand-in for a volatile load, since warp arrays cannot be marked volatile.
        value = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, value)
        return value

    def test_spinlock_counter_1d(counter: wp.array(dtype=warp_type), lock: wp.array(dtype=warp_type)):
        spinlock_acquire_1d(lock)

        # Critical section. A plain `counter[0] = counter[0] + 1` gives wrong results here
        # because the read is not volatile; read through atomics instead.
        # NOTE(review): there is no explicit fence between this store and the releasing
        # atomic_exch; ordering relies on warp/CUDA atomic semantics - confirm.
        value = volatile_read_1d(counter, 0)
        counter[0] = value + warp_type(1)

        spinlock_release_1d(lock)

    kernel = getkernel(test_spinlock_counter_1d, suffix=dtype.__name__)

    if register_kernels:
        return

    # Allocate only when actually launching; the registration path never uses these arrays.
    n = 100
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.array([0], dtype=warp_type, device=device)

    wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)

    # Each of the n threads increments the counter exactly once while holding the lock.
    assert_np_equal(counter.numpy(), np.array([n], dtype=dtype))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def test_atomic_cas_2d(test, device, dtype, register_kernels=False):
    """Increment a shared counter from ``n`` threads, guarded by a CAS spin lock stored in a 1x1 2D array.

    Exercises the 2D-indexed overloads of ``wp.atomic_cas`` / ``wp.atomic_exch``.

    Args:
        test: The test case running the check (unused on the registration path).
        device: Device to allocate on and launch on.
        dtype: NumPy scalar type; mapped to the matching warp type for both the
            counter and the lock.
        register_kernels: When True, only build/look up the kernel via ``getkernel``
            and return, without allocating arrays or launching.
    """
    warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    @wp.func
    def spinlock_acquire_2d(lock: wp.array2d(dtype=warp_type)):
        # atomic_cas returns the previous value: 1 means another thread holds the lock, so spin.
        while wp.atomic_cas(lock, 0, 0, warp_type(0), warp_type(1)) == 1:
            pass

    @wp.func
    def spinlock_release_2d(lock: wp.array2d(dtype=warp_type)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, warp_type(0))

    @wp.func
    def volatile_read_2d(ptr: wp.array(dtype=warp_type), index: int):
        # Reads the 1D counter (only the lock is 2D): swap in 0, then write the old value
        # back, as a stand-in for a volatile load since warp arrays cannot be marked volatile.
        value = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, value)
        return value

    def test_spinlock_counter_2d(counter: wp.array(dtype=warp_type), lock: wp.array2d(dtype=warp_type)):
        spinlock_acquire_2d(lock)

        # Critical section. A plain `counter[0] = counter[0] + 1` gives wrong results here
        # because the read is not volatile; read through atomics instead.
        # NOTE(review): there is no explicit fence between this store and the releasing
        # atomic_exch; ordering relies on warp/CUDA atomic semantics - confirm.
        value = volatile_read_2d(counter, 0)
        counter[0] = value + warp_type(1)

        spinlock_release_2d(lock)

    kernel = getkernel(test_spinlock_counter_2d, suffix=dtype.__name__)

    if register_kernels:
        return

    # Allocate only when actually launching; the registration path never uses these arrays.
    n = 100
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1), dtype=warp_type, device=device)

    wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)

    # Each of the n threads increments the counter exactly once while holding the lock.
    assert_np_equal(counter.numpy(), np.array([n], dtype=dtype))
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_atomic_cas_3d(test, device, dtype, register_kernels=False):
    """Increment a shared counter from ``n`` threads, guarded by a CAS spin lock stored in a 1x1x1 3D array.

    Exercises the 3D-indexed overloads of ``wp.atomic_cas`` / ``wp.atomic_exch``.

    Args:
        test: The test case running the check (unused on the registration path).
        device: Device to allocate on and launch on.
        dtype: NumPy scalar type; mapped to the matching warp type for both the
            counter and the lock.
        register_kernels: When True, only build/look up the kernel via ``getkernel``
            and return, without allocating arrays or launching.
    """
    warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

    @wp.func
    def spinlock_acquire_3d(lock: wp.array3d(dtype=warp_type)):
        # atomic_cas returns the previous value: 1 means another thread holds the lock, so spin.
        while wp.atomic_cas(lock, 0, 0, 0, warp_type(0), warp_type(1)) == 1:
            pass

    @wp.func
    def spinlock_release_3d(lock: wp.array3d(dtype=warp_type)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, 0, warp_type(0))

    @wp.func
    def volatile_read_3d(ptr: wp.array(dtype=warp_type), index: int):
        # Reads the 1D counter (only the lock is 3D): swap in 0, then write the old value
        # back, as a stand-in for a volatile load since warp arrays cannot be marked volatile.
        value = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, value)
        return value

    def test_spinlock_counter_3d(counter: wp.array(dtype=warp_type), lock: wp.array3d(dtype=warp_type)):
        spinlock_acquire_3d(lock)

        # Critical section. A plain `counter[0] = counter[0] + 1` gives wrong results here
        # because the read is not volatile; read through atomics instead.
        # NOTE(review): there is no explicit fence between this store and the releasing
        # atomic_exch; ordering relies on warp/CUDA atomic semantics - confirm.
        value = volatile_read_3d(counter, 0)
        counter[0] = value + warp_type(1)

        spinlock_release_3d(lock)

    kernel = getkernel(test_spinlock_counter_3d, suffix=dtype.__name__)

    if register_kernels:
        return

    # Allocate only when actually launching; the registration path never uses these arrays.
    n = 100
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1, 1), dtype=warp_type, device=device)

    wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)

    # Each of the n threads increments the counter exactly once while holding the lock.
    assert_np_equal(counter.numpy(), np.array([n], dtype=dtype))
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def create_spinlock_test_4d(dtype):
    """Build a kernel that increments ``counter[0]`` under a CAS spin lock held in a 4D array.

    Args:
        dtype: Warp scalar type used as the element type of both arrays. It is also
            called as ``dtype(0)`` / ``dtype(1)`` to form the lock and increment values.

    Returns:
        The ``@wp.kernel`` ``test_spinlock_counter(counter, lock)``. ``counter`` is 1D,
        and ``lock`` is 4D with the lock flag at index ``[0, 0, 0, 0]``.
    """

    @wp.func
    def spinlock_acquire(lock: wp.array(dtype=dtype, ndim=4)):
        # Try to acquire the lock by setting it to 1 if it's 0.
        # atomic_cas returns the previous value, so 1 means "still held by someone else": keep spinning.
        while wp.atomic_cas(lock, 0, 0, 0, 0, dtype(0), dtype(1)) == 1:
            pass

    @wp.func
    def spinlock_release(lock: wp.array(dtype=dtype, ndim=4)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, 0, 0, dtype(0))

    @wp.func
    def volatile_read(ptr: wp.array(dtype=dtype), index: int):
        # Read ptr[index] through atomics: swap in 0, then write the old value back.
        # Stands in for a volatile load, since warp arrays cannot be marked volatile.
        # The slot briefly holds 0, which is safe only while the caller holds the lock.
        value = wp.atomic_exch(ptr, index, dtype(0))
        wp.atomic_exch(ptr, index, value)
        return value

    @wp.kernel
    def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype, ndim=4)):
        # Try to acquire the lock
        spinlock_acquire(lock)

        # Critical section - increment counter.
        # A plain `counter[0] = counter[0] + 1` gives wrong results here because the
        # read is not volatile; work around it by reading through atomics.
        # NOTE(review): there is no explicit fence between this store and the releasing
        # atomic_exch; ordering relies on warp/CUDA atomic semantics - confirm.
        value = volatile_read(counter, 0)
        counter[0] = value + dtype(1)

        # Release the lock
        spinlock_release(lock)

    return test_spinlock_counter
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def test_atomic_cas_4d(test, device, dtype, register_kernels=False):
    """Check that a spinlock stored in a 4D array serializes a shared counter increment.

    ``thread_count`` threads each take the lock at ``lock[0, 0, 0, 0]``, add one to
    ``counter[0]`` and release the lock. The final counter must equal ``thread_count``.

    Args:
        test: Test-case object supplied by the registration helper (unused here).
        device: Device used for the array allocations and the kernel launch.
        dtype: NumPy scalar type, mapped to the matching Warp scalar type.
        register_kernels: When true, only build/look up the kernel and return
            before launching it.
    """
    warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    thread_count = 100

    # One-element counter plus a single-element 4D lock (0 = free, 1 = held).
    counter = wp.array([0], dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1, 1, 1), dtype=warp_type, device=device)

    @wp.func
    def spinlock_acquire_4d(lock: wp.array4d(dtype=warp_type)):
        # Spin until the CAS swaps the lock from 0 (free) to 1 (held).
        while wp.atomic_cas(lock, 0, 0, 0, 0, warp_type(0), warp_type(1)) == 1:
            pass

    @wp.func
    def spinlock_release_4d(lock: wp.array4d(dtype=warp_type)):
        # Mark the lock free again.
        wp.atomic_exch(lock, 0, 0, 0, 0, warp_type(0))

    @wp.func
    def volatile_read_4d(ptr: wp.array(dtype=warp_type), index: int):
        # Swap-and-restore read standing in for a volatile load, which Warp arrays
        # cannot express. The slot briefly holds 0, so callers must own the lock.
        observed = wp.atomic_exch(ptr, index, warp_type(0))
        wp.atomic_exch(ptr, index, observed)
        return observed

    def test_spinlock_counter_4d(counter: wp.array(dtype=warp_type), lock: wp.array4d(dtype=warp_type)):
        spinlock_acquire_4d(lock)
        # A plain ``counter[0] = counter[0] + 1`` gives wrong results here because the
        # load is not volatile, so the read goes through the atomic helper instead.
        counter[0] = volatile_read_4d(counter, 0) + warp_type(1)
        spinlock_release_4d(lock)

    kernel = getkernel(test_spinlock_counter_4d, suffix=dtype.__name__)
    if register_kernels:
        return

    wp.launch(kernel, dim=thread_count, inputs=[counter, lock], device=device)

    # Every thread must have incremented the counter exactly once.
    result = counter.numpy()
    expected = np.full(1, thread_count, dtype=dtype)
    if not np.array_equal(result, expected):
        print(f"Counter mismatch: expected {expected}, got {result}")
    assert_np_equal(result, expected)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# Devices every registered CAS test is run on. get_test_devices() is imported
# from outside this chunk; presumably it lists the available CPU/CUDA devices (confirm).
devices = get_test_devices()
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class TestAtomicCAS(unittest.TestCase):
    """Container for the atomic compare-and-swap spinlock tests.

    The class body defines no tests itself. Test methods are attached
    programmatically via ``add_function_test_register_kernel``.
    """
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
# Register every CAS test variant (1D, 2D, 3D and 4D lock arrays) for each supported
# scalar type. Generated names look like test_cas_<type> or test_cas_<N>d_<type>.
np_test_types = (np.int32, np.uint32, np.int64, np.uint64, np.float32, np.float64)

# (name template, test function) pairs, registered in this order for every dtype.
# Underscore-prefixed so test collectors do not mistake them for tests.
_CAS_VARIANTS = (
    ("test_cas_{}", test_atomic_cas),
    ("test_cas_2d_{}", test_atomic_cas_2d),
    ("test_cas_3d_{}", test_atomic_cas_3d),
    ("test_cas_4d_{}", test_atomic_cas_4d),
)

for dtype in np_test_types:
    type_name = dtype.__name__
    for _name_template, _cas_test in _CAS_VARIANTS:
        add_function_test_register_kernel(
            TestAtomicCAS, _name_template.format(type_name), _cas_test, devices=devices, dtype=dtype
        )
|
|
309
|
+
|
|
310
|
+
# Script entry point: clear Warp's kernel cache so kernels are rebuilt from
# source, then run the registered tests verbosely.
if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
|