warp-lang 1.6.2-py3-none-win_amd64.whl → 1.7.1-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of warp-lang has been flagged as potentially problematic.
Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/tests/{test_tile.py → tile/test_tile.py}

@@ -20,8 +20,6 @@ import numpy as np
  import warp as wp
  from warp.tests.unittest_utils import *

- wp.init() # For wp.context.runtime.core.is_mathdx_enabled()
-
  TILE_M = wp.constant(8)
  TILE_N = wp.constant(4)
  TILE_K = wp.constant(8)
@@ -216,7 +214,6 @@ def test_tile_binary_map(test, device):
  assert_np_equal(B_wp.grad.numpy(), B_grad)


- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
  def test_tile_grouped_gemm(test, device):
  @wp.kernel
  def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
@@ -256,60 +253,62 @@ def test_tile_grouped_gemm(test, device):
  assert_np_equal(C_wp.numpy(), C, 1e-6)


- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
- def test_tile_gemm(test, device):
- @wp.kernel
- def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
- # output tile index
- i, j = wp.tid()
+ def test_tile_gemm(dtype):
+ def test(test, device):
+ @wp.kernel
+ def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
+ # output tile index
+ i, j = wp.tid()

- sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
+ sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)

- M = A.shape[0]
- N = B.shape[1]
- K = A.shape[1]
+ M = A.shape[0]
+ N = B.shape[1]
+ K = A.shape[1]

- count = int(K / TILE_K)
+ count = int(K / TILE_K)

- for k in range(0, count):
- a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
- b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
+ for k in range(0, count):
+ a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
+ b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))

- # sum += a*b
- wp.tile_matmul(a, b, sum)
+ # sum += a*b
+ wp.tile_matmul(a, b, sum)

- wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
+ wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))

- M = TILE_M * 7
- K = TILE_K * 6
- N = TILE_N * 5
+ M = TILE_M * 7
+ K = TILE_K * 6
+ N = TILE_N * 5

- rng = np.random.default_rng(42)
- A = rng.random((M, K), dtype=np.float32)
- B = rng.random((K, N), dtype=np.float32)
- C = np.zeros((M, N), dtype=np.float32)
+ rng = np.random.default_rng(42)
+ A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
+ B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
+ C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))

- A_wp = wp.array(A, requires_grad=True, device=device)
- B_wp = wp.array(B, requires_grad=True, device=device)
- C_wp = wp.array(C, requires_grad=True, device=device)
+ A_wp = wp.array(A, requires_grad=True, device=device)
+ B_wp = wp.array(B, requires_grad=True, device=device)
+ C_wp = wp.array(C, requires_grad=True, device=device)

- with wp.Tape() as tape:
- wp.launch_tiled(
- tile_gemm,
- dim=(int(M / TILE_M), int(N / TILE_N)),
- inputs=[A_wp, B_wp, C_wp],
- block_dim=TILE_DIM,
- device=device,
- )
+ with wp.Tape() as tape:
+ wp.launch_tiled(
+ tile_gemm,
+ dim=(int(M / TILE_M), int(N / TILE_N)),
+ inputs=[A_wp, B_wp, C_wp],
+ block_dim=TILE_DIM,
+ device=device,
+ )

- assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-5)
+ assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)

- adj_C = np.ones_like(C)
+ adj_C = np.ones_like(C)

- tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
+ tape.backward(grads={C_wp: wp.array(adj_C, device=device)})

- assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-5)
- assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-5)
+ assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
+ assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
+
+ return test


  @wp.kernel
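The tolerance in this hunk drops from 1.0e-5 to 1.0e-1 because the same test body now also runs in float16 (see the test_tile_gemm_fp16 registration further down). A standalone NumPy sketch, independent of Warp and only reusing the test's matrix sizes, illustrates the size of the float16 rounding error involved:

    import numpy as np

    # same sizes as the test: TILE_M * 7, TILE_K * 6, TILE_N * 5
    M, K, N = 8 * 7, 8 * 6, 4 * 5

    rng = np.random.default_rng(42)
    A = rng.random((M, K))
    B = rng.random((K, N))

    ref = A @ B  # float64 reference

    # rounding the inputs to float16 alone already perturbs the product
    A16 = A.astype(np.float16).astype(np.float64)
    B16 = B.astype(np.float16).astype(np.float64)
    err = np.max(np.abs(A16 @ B16 - ref))

    # typically a few 1e-3 here: well above the old 1.0e-5, well below 1.0e-1
    print(err)

The on-device float16 accumulation can add further error on top of this, so the shared 1.0e-1 bound leaves headroom for all three registered dtypes.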
@@ -550,7 +549,6 @@ def test_tile_transpose(test, device):
  assert_np_equal(output.numpy(), input.numpy().T)


- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
  def test_tile_transpose_matmul(test, device):
  @wp.kernel
  def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
@@ -572,9 +570,36 @@ def test_tile_transpose_matmul(test, device):


  @wp.kernel
- def test_tile_broadcast_add_kernel(
+ def test_tile_broadcast_add_1d_kernel(
+ input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
+ ):
+ a = wp.tile_load(input_a, shape=(10,))
+ b = wp.tile_load(input_b, shape=(1,))
+
+ c = wp.tile_broadcast(b, shape=(10,))
+ d = a + c
+
+ wp.tile_store(output, d)
+
+
+ def test_tile_broadcast_add_1d(test, device):
+ N = 10
+
+ # implicit 1-dim ([1], 1)
+ a = wp.array(np.arange(0, N, dtype=np.float32), device=device)
+ b = wp.array(np.ones(1, dtype=np.float32), device=device)
+ out = wp.zeros((N,), dtype=float, device=device)
+
+ wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
+
+
+ @wp.kernel
+ def test_tile_broadcast_add_2d_kernel(
  input_a: wp.array2d(dtype=float), input_b: wp.array(dtype=float), output: wp.array2d(dtype=float)
  ):
+ # implicit 1-dim ([1], 10)
  a = wp.tile_load(input_a, shape=(10, 10))
  b = wp.tile_load(input_b, shape=10)

@@ -584,7 +609,7 @@ def test_tile_broadcast_add_kernel(
  wp.tile_store(output, d)


- def test_tile_broadcast_add(test, device):
+ def test_tile_broadcast_add_2d(test, device):
  M = 10
  N = 10

@@ -592,7 +617,62 @@ def test_tile_broadcast_add(test, device):
  b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
  out = wp.zeros((M, N), dtype=float, device=device)

- wp.launch_tiled(test_tile_broadcast_add_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+ wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
+
+
+ @wp.kernel
+ def test_tile_broadcast_add_3d_kernel(
+ input_a: wp.array3d(dtype=float), input_b: wp.array3d(dtype=float), output: wp.array3d(dtype=float)
+ ):
+ a = wp.tile_load(input_a, shape=(4, 10, 12))
+ b = wp.tile_load(input_b, shape=(4, 10, 1))
+
+ c = wp.tile_broadcast(b, shape=(4, 10, 12))
+ d = a + c
+
+ wp.tile_store(output, d)
+
+
+ def test_tile_broadcast_add_3d(test, device):
+ M = 4
+ N = 10
+ O = 12
+
+ # explicit 1-dim (M, N, 1) to (M, N, O)
+ a = wp.array(np.ones((M, N, O), dtype=np.float32), device=device)
+ b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
+ out = wp.zeros((M, N, O), dtype=float, device=device)
+
+ wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
+
+
+ @wp.kernel
+ def test_tile_broadcast_add_4d_kernel(
+ input_a: wp.array4d(dtype=float), input_b: wp.array4d(dtype=float), output: wp.array4d(dtype=float)
+ ):
+ a = wp.tile_load(input_a, shape=(4, 10, 5, 6))
+ b = wp.tile_load(input_b, shape=(4, 1, 5, 1))
+ c = wp.tile_broadcast(b, shape=(4, 10, 5, 6))
+ d = a + c
+
+ wp.tile_store(output, d)
+
+
+ def test_tile_broadcast_add_4d(test, device):
+ M = 4
+ N = 10
+ O = 5
+ P = 6
+
+ # explicit 1-dims (M, 1, O, 1) to (M, N, O, P)
+ a = wp.array(np.ones((M, N, O, P), dtype=np.float32), device=device)
+ b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
+ out = wp.zeros((M, N, O, P), dtype=float, device=device)
+
+ wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)

  assert_np_equal(out.numpy(), a.numpy() + b.numpy())

@@ -665,7 +745,7 @@ def test_tile_print(test, device):
  wp.synchronize()


- devices = get_cuda_test_devices()
+ devices = get_test_devices()


  class TestTile(unittest.TestCase):
@@ -677,15 +757,20 @@ add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devi
  add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
  add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
  add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
- add_function_test(TestTile, "test_tile_gemm", test_tile_gemm, devices=devices)
+ add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
+ add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
+ add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
  add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
  add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
  add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
- add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices)
+ add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, check_output=False)
  add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
  add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
  add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
- add_function_test(TestTile, "test_tile_broadcast_add", test_tile_broadcast_add, devices=devices)
+ add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
+ add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
+ add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
+ add_function_test(TestTile, "test_tile_broadcast_add_4d", test_tile_broadcast_add_4d, devices=devices)
  add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
  add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
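The broadcast_add tests registered above now cover 1D through 4D tiles; the 3D and 4D kernels broadcast explicit size-1 axes with wp.tile_broadcast before adding. The expected values the tests assert against follow ordinary NumPy broadcasting rules; a small sketch in plain NumPy (no Warp involved) mirroring the 4D case:

    import numpy as np

    M, N, O, P = 4, 10, 5, 6

    # (M, 1, O, 1) broadcast up to (M, N, O, P), as in test_tile_broadcast_add_4d
    a = np.ones((M, N, O, P), dtype=np.float32)
    b = np.arange(M * O, dtype=np.float32).reshape((M, 1, O, 1))

    expected = a + np.broadcast_to(b, a.shape)

    # NumPy broadcasts the size-1 axes implicitly, so the explicit form matches
    assert np.array_equal(expected, a + b)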
warp/tests/{test_tile_load.py → tile/test_tile_load.py}

@@ -184,6 +184,96 @@ def test_tile_load_unaligned(test, device):
  assert_np_equal(input.grad.numpy(), expected_grad)


+ @wp.kernel
+ def tile_load_aligned_small_kernel(
+ input: wp.array2d(dtype=float),
+ output: wp.array2d(dtype=float),
+ ):
+ t = wp.tile_load(input, shape=(3, 3), offset=(0, 0), storage="shared")
+ wp.tile_store(output, t, offset=(0, 0))
+
+
+ # regression test for tiles that are smaller than sizeof(float4) in that last
+ # dimension but are aligned to float4. Did trigger the fast float4 path by accident.
+ def test_tile_load_aligned_small(test, device):
+ rng = np.random.default_rng(42)
+
+ shape = [TILE_M, TILE_N]
+
+ input = wp.array(rng.random(shape), dtype=float, requires_grad=True, device=device)
+ output = wp.zeros(shape, dtype=float, device=device)
+
+ wp.launch_tiled(
+ tile_load_aligned_small_kernel,
+ dim=[1],
+ inputs=[input, output],
+ block_dim=TILE_DIM,
+ device=device,
+ )
+
+ # zeros except for the 3x3 tile at 0, 0
+ assert_np_equal(output.numpy()[3:, :], np.zeros((TILE_M - 3, TILE_N)))
+ assert_np_equal(output.numpy()[:, 3:], np.zeros((TILE_M, TILE_N - 3)))
+
+ # check output elements
+ assert_np_equal(output.numpy()[:3, :3], input.numpy()[:3, :3])
+
+
+ TILE_WIDTH = 5
+ TILE_OFFSET_X = 0
+ TILE_OFFSET_Y = 8
+
+
+ @wp.kernel
+ def test_tile_load_aligned_offset_unaligned_size_kernel(
+ input: wp.array2d(dtype=float),
+ output: wp.array2d(dtype=float),
+ ):
+ # Load a 5x5 tile from the input array starting at offset (0,8)
+ # and store it in shared memory
+ tile = wp.tile_load(input, shape=(TILE_WIDTH, TILE_WIDTH), offset=(TILE_OFFSET_X, TILE_OFFSET_Y), storage="shared")
+
+ # Store the loaded tile back to the output array at the same offset
+ wp.tile_store(output, tile, offset=(TILE_OFFSET_X, TILE_OFFSET_Y))
+
+
+ def test_tile_load_aligned_offset_unaligned_size(test, device):
+ """Test loading a tile with aligned offset but unaligned size."""
+
+ rng = np.random.default_rng(42)
+ array_shape = [TILE_N, TILE_M]
+
+ input_array = wp.array(rng.random(array_shape), dtype=float, requires_grad=True, device=device)
+ output_array = wp.zeros(array_shape, dtype=float, device=device)
+
+ wp.launch_tiled(
+ test_tile_load_aligned_offset_unaligned_size_kernel,
+ dim=[1],
+ inputs=[input_array, output_array],
+ block_dim=TILE_DIM,
+ device=device,
+ )
+
+ # Region before the tile offset should be zeros
+ assert_np_equal(output_array.numpy()[:TILE_WIDTH, :TILE_OFFSET_Y], np.zeros((TILE_WIDTH, TILE_OFFSET_Y)))
+
+ # Region where the tile was loaded/stored should match input
+ assert_np_equal(
+ output_array.numpy()[:TILE_WIDTH, TILE_OFFSET_Y : TILE_OFFSET_Y + TILE_WIDTH],
+ input_array.numpy()[:TILE_WIDTH, TILE_OFFSET_Y : TILE_OFFSET_Y + TILE_WIDTH],
+ )
+
+ # Region after the tile should be zeros
+ remaining_width = TILE_M - (TILE_OFFSET_Y + TILE_WIDTH)
+ assert_np_equal(
+ output_array.numpy()[:TILE_WIDTH, TILE_OFFSET_Y + TILE_WIDTH :], np.zeros((TILE_WIDTH, remaining_width))
+ )
+
+ # Rows below the tile should all be zeros
+ remaining_height = TILE_N - TILE_WIDTH
+ assert_np_equal(output_array.numpy()[TILE_WIDTH:, :], np.zeros((remaining_height, TILE_M)))
+
+
  # ----------------------------------------------------------------------------------------

  TILE_SIZE = 4
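For reference, the expected output of the aligned-small case can also be built as a single array rather than checked region by region. A sketch in plain NumPy; the 16 x 8 shape is only illustrative and stands in for the module's TILE_M and TILE_N constants:

    import numpy as np

    rows, cols = 16, 8  # illustrative only; the test uses TILE_M, TILE_N

    rng = np.random.default_rng(42)
    inp = rng.random((rows, cols)).astype(np.float32)

    # everything stays zero except the 3x3 block copied from the input at (0, 0)
    expected = np.zeros_like(inp)
    expected[:3, :3] = inp[:3, :3]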
@@ -376,7 +466,7 @@ def test_tile_load_fortran(test, device):
  assert_array_equal(B_wp.grad, A_wp.grad)


- devices = get_cuda_test_devices()
+ devices = get_test_devices()


  class TestTileLoad(unittest.TestCase):
@@ -388,6 +478,13 @@ add_function_test(TestTileLoad, "test_tile_load_2d", test_tile_load(tile_load_2d
  add_function_test(TestTileLoad, "test_tile_load_3d", test_tile_load(tile_load_3d_kernel, 3), devices=devices)
  add_function_test(TestTileLoad, "test_tile_load_4d", test_tile_load(tile_load_4d_kernel, 4), devices=devices)
  add_function_test(TestTileLoad, "test_tile_load_unaligned", test_tile_load_unaligned, devices=devices)
+ add_function_test(TestTileLoad, "test_tile_load_aligned_small", test_tile_load_aligned_small, devices=devices)
+ add_function_test(
+ TestTileLoad,
+ "test_tile_load_aligned_offset_unaligned_size",
+ test_tile_load_aligned_offset_unaligned_size,
+ devices=devices,
+ )

  add_function_test(TestTileLoad, "test_tile_extract_1d", test_tile_extract(tile_extract_1d_kernel, 1), devices=devices)
  add_function_test(TestTileLoad, "test_tile_extract_2d", test_tile_extract(tile_extract_2d_kernel, 2), devices=devices)
warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py}

@@ -92,6 +92,7 @@ def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dt
  wp.tile_store(gy, xy)


+ @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
  def test_tile_math_fft(test, device, wp_dtype):
  np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
  np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]
@@ -172,31 +173,33 @@ def test_tile_math_cholesky(test, device):
  # TODO: implement and test backward pass


- devices = get_cuda_test_devices()
+ all_devices = get_test_devices()
+ cuda_devices = get_cuda_test_devices()


- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
  class TestTileMathDx(unittest.TestCase):
  pass


  # check_output=False so we can enable libmathdx's logging without failing the tests
- add_function_test(TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=devices, check_output=False)
  add_function_test(
- TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=devices, check_output=False
+ TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=all_devices, check_output=False
+ )
+ add_function_test(
+ TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=all_devices, check_output=False
  )
  add_function_test(
  TestTileMathDx,
  "test_tile_math_fft_vec2f",
  functools.partial(test_tile_math_fft, wp_dtype=wp.vec2f),
- devices=devices,
+ devices=cuda_devices,
  check_output=False,
  )
  add_function_test(
  TestTileMathDx,
  "test_tile_math_fft_vec2d",
  functools.partial(test_tile_math_fft, wp_dtype=wp.vec2d),
- devices=devices,
+ devices=cuda_devices,
  check_output=False,
  )

warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py}

@@ -22,11 +22,6 @@ import warp.examples
  import warp.optim
  from warp.tests.unittest_utils import *

- wp.init()
-
- # needs to be constant for the whole module
- NUM_THREADS = 32
-

  def create_layer(rng, dim_in, dim_hid, dtype=float):
  w = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
@@ -45,10 +40,12 @@ def create_array(rng, dim_in, dim_hid, dtype=float):
  return a


- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
  def test_multi_layer_nn(test, device):
  import torch as tc

+ if device.is_cuda and not wp.context.runtime.core.is_mathdx_enabled():
+ test.skipTest("Skipping test on CUDA device without MathDx (tolerance)")
+
  NUM_FREQ = wp.constant(8)

  DIM_IN = wp.constant(4 * NUM_FREQ) # sin,cos for both x,y at each frequency
@@ -60,7 +57,13 @@
  BATCH_SIZE = min(512, int((IMG_WIDTH * IMG_HEIGHT) / 8))

+ if device.is_cpu:
+ NUM_THREADS = 1
+ else:
+ NUM_THREADS = 32
+
  dtype = wp.float16
+ npdtype = wp.types.warp_type_to_np_dtype[dtype]

  @wp.func
  def relu(x: dtype):
@@ -74,7 +77,7 @@
  def zero(loss: wp.array(dtype=float)):
  loss[0] = 0.0

- @wp.kernel
+ @wp.kernel(module="unique")
  def compute(
  batches: wp.array(dtype=int),
  input: wp.array2d(dtype=dtype),
@@ -170,7 +173,9 @@
  input = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_IN, dtype=dtype)
  output = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_OUT)

- reference_np = np.load(os.path.join(os.path.dirname(__file__), "assets/pixel.npy"), allow_pickle=True) / 255.0
+ reference_np = (
+ np.load(os.path.join(os.path.dirname(__file__), "..", "assets", "pixel.npy"), allow_pickle=True) / 255.0
+ )
  reference = wp.array(reference_np, dtype=float)

  assert reference.shape[1] == IMG_WIDTH * IMG_HEIGHT
@@ -232,7 +237,7 @@
  z_np = np.maximum(weights_3.numpy() @ z_np + bias_3.numpy(), 0.0)

  # test numpy forward
- assert_np_equal(output.numpy()[:, indices], z_np, tol=1.0e-2)
+ assert_np_equal(output.numpy()[:, indices].astype(npdtype), z_np, tol=1.0e-2)

  # torch
  input_tc = tc.tensor(input.numpy()[:, indices], requires_grad=True, device=torch_device)
@@ -260,7 +265,9 @@
  l_tc.backward()

  # test torch
- assert_np_equal(z_tc.cpu().detach().numpy(), output.numpy()[:, indices], tol=1.0e-2)
+ assert_np_equal(
+ z_tc.cpu().detach().numpy(), output.numpy()[:, indices].astype(npdtype), tol=1.0e-2
+ )
  assert_np_equal(weights_0.grad.numpy(), weights_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
  assert_np_equal(bias_0.grad.numpy(), bias_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
  assert_np_equal(weights_1.grad.numpy(), weights_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
@@ -277,7 +284,6 @@
  test.assertLess(loss.numpy()[0], 0.002)


- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
  def test_single_layer_nn(test, device):
  import torch as tc

@@ -287,11 +293,16 @@

  NUM_BLOCKS = 56

+ if device.is_cpu:
+ NUM_THREADS = 1
+ else:
+ NUM_THREADS = 32
+
  @wp.func
  def relu(x: float):
  return wp.max(x, 0.0)

- @wp.kernel
+ @wp.kernel(module="unique")
  def compute(
  input: wp.array2d(dtype=float),
  weights: wp.array2d(dtype=float),
@@ -353,7 +364,7 @@
  import torch

  # check which Warp devices work with Torch
- # CUDA devices may fail if Torch was not compiled with CUDA support
  torch_compatible_devices = []
  torch_compatible_cuda_devices = []

@@ -372,7 +382,7 @@
  "test_single_layer_nn",
  test_single_layer_nn,
  check_output=False,
- devices=torch_compatible_cuda_devices,
+ devices=torch_compatible_devices,
  )
  add_function_test(
  TestTileMLP,
@@ -388,4 +398,5 @@

  if __name__ == "__main__":
  wp.clear_kernel_cache()
+ wp.clear_lto_cache()
  unittest.main(verbosity=2, failfast=True)
warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py}

@@ -176,6 +176,64 @@ def test_tile_reduce_custom(test, device):
  test.assertAlmostEqual(prod_wp[i], prod_np, places=4)


+ @wp.struct
+ class KeyValue:
+ key: wp.int32
+ value: wp.float32
+
+
+ @wp.func
+ def kv_max(a: KeyValue, b: KeyValue) -> KeyValue:
+ return wp.where(a.value < b.value, b, a)
+
+
+ @wp.kernel
+ def initialize_key_value(values: wp.array2d(dtype=wp.float32), keyvalues: wp.array2d(dtype=KeyValue)):
+ batch, idx = wp.tid()
+ keyvalues[batch, idx] = KeyValue(idx, values[batch, idx])
+
+
+ @wp.kernel(enable_backward=False)
+ def tile_reduce_custom_struct_kernel(values: wp.array2d(dtype=KeyValue), res: wp.array(dtype=KeyValue)):
+ # output tile index
+ i = wp.tid()
+
+ t = wp.tile_load(values, shape=(1, TILE_DIM), offset=(i, 0))
+
+ max_el = wp.tile_reduce(kv_max, t)
+ wp.tile_store(res, max_el, offset=i)
+
+
+ def test_tile_reduce_custom_struct(test, device):
+ batch_count = 56
+
+ N = TILE_DIM
+
+ rng = np.random.default_rng(42)
+ input = rng.random((batch_count, N), dtype=np.float32)
+
+ input_wp = wp.array(input, dtype=wp.float32, device=device)
+ keyvalues_wp = wp.empty(input_wp.shape, dtype=KeyValue, device=device)
+
+ wp.launch(initialize_key_value, dim=[batch_count, N], inputs=[input_wp], outputs=[keyvalues_wp], device=device)
+
+ output_wp = wp.empty(batch_count, dtype=KeyValue, device=device)
+
+ wp.launch_tiled(
+ tile_reduce_custom_struct_kernel,
+ dim=[batch_count],
+ inputs=[keyvalues_wp],
+ outputs=[output_wp],
+ block_dim=TILE_DIM,
+ device=device,
+ )
+
+ prod_wp = np.array([k for k, v in output_wp.numpy()])
+ expected = np.argmax(input, axis=1)
+
+ assert_np_equal(prod_wp, expected)
+
+
  @wp.kernel
  def tile_grouped_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
  # output tile index
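The struct reduction added above is effectively a per-row argmax: kv_max keeps whichever KeyValue holds the larger value, so reducing a row of (index, value) pairs yields the index of its maximum, which is what the final assert compares against np.argmax. A pure-Python sketch of the same idea, outside Warp:

    import functools

    import numpy as np

    def kv_max(a, b):
        # (key, value) pairs; keep the pair with the larger value
        return b if a[1] < b[1] else a

    rng = np.random.default_rng(42)
    row = rng.random(64).astype(np.float32)

    key, value = functools.reduce(kv_max, enumerate(row))
    assert key == int(np.argmax(row))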
@@ -365,7 +423,7 @@ def test_tile_arange(test, device):
  assert_np_equal(output.numpy()[4], np.arange(17, 0, -1))


- devices = get_cuda_test_devices()
+ devices = get_test_devices()


  class TestTileReduce(unittest.TestCase):
@@ -376,6 +434,7 @@ add_function_test(TestTileReduce, "test_tile_reduce_sum", test_tile_reduce_sum,
  add_function_test(TestTileReduce, "test_tile_reduce_min", test_tile_reduce_min, devices=devices)
  add_function_test(TestTileReduce, "test_tile_reduce_max", test_tile_reduce_max, devices=devices)
  add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_custom, devices=devices)
+ add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
  add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_sum, devices=devices)
  add_function_test(TestTileReduce, "test_tile_reduce_simt", test_tile_reduce_simt, devices=devices)
  add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devices)
warp/tests/{test_tile_view.py → tile/test_tile_view.py}

@@ -155,7 +155,7 @@ def test_tile_view_offset(test, device):
  assert_np_equal(a.grad.numpy(), np.ones_like(a.numpy()))


- devices = get_cuda_test_devices()
+ devices = get_test_devices()


  class TestTileView(unittest.TestCase):
warp/tests/unittest_serial.py

@@ -23,6 +23,7 @@ def run_suite() -> bool:
  """Run a test suite"""

  # force rebuild of all kernels
+ wp.clear_lto_cache()
  wp.clear_kernel_cache()
  print("Cleared Warp kernel cache")
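The same pairing appears in the tile tests' __main__ blocks earlier in this diff. A minimal usage sketch for forcing a full rebuild before a run, assuming only the two cache-clearing helpers used in this diff:

    import warp as wp

    # drop cached LTO artifacts and compiled kernel binaries so modules rebuild
    wp.clear_lto_cache()
    wp.clear_kernel_cache()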