warp-lang 1.7.2__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -106,6 +106,39 @@ def test_tile_reduce_min(test, device):
106
106
  test.assertAlmostEqual(min_wp[i], min_np, places=4)
107
107
 
108
108
 
109
+ @wp.kernel
110
+ def tile_argmin_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=int)):
111
+ # output tile index
112
+ i = wp.tid()
113
+
114
+ a = wp.tile_load(input[i], shape=TILE_DIM)
115
+ m = wp.tile_argmin(a)
116
+
117
+ wp.tile_store(output, m, offset=i)
118
+
119
+
120
+ def test_tile_reduce_argmin(test, device):
121
+ batch_count = 56
122
+
123
+ N = TILE_DIM
124
+
125
+ rng = np.random.default_rng(42)
126
+ input = rng.random((batch_count, N), dtype=np.float32)
127
+
128
+ input_wp = wp.array(input, requires_grad=True, device=device)
129
+ output_wp = wp.zeros(batch_count, dtype=wp.int32, requires_grad=True, device=device)
130
+
131
+ with wp.Tape() as tape:
132
+ wp.launch_tiled(
133
+ tile_argmin_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
134
+ )
135
+
136
+ argmin_wp = output_wp.numpy()
137
+ for i in range(batch_count):
138
+ argmin_np = np.argmin(input[i])
139
+ test.assertAlmostEqual(argmin_wp[i], argmin_np, places=4)
140
+
141
+
109
142
  @wp.kernel
110
143
  def tile_max_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
111
144
  # output tile index
@@ -139,6 +172,39 @@ def test_tile_reduce_max(test, device):
139
172
  test.assertAlmostEqual(max_wp[i], max_np, places=4)
140
173
 
141
174
 
175
+ @wp.kernel
176
+ def tile_argmax_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=int)):
177
+ # output tile index
178
+ i = wp.tid()
179
+
180
+ a = wp.tile_load(input[i], shape=TILE_DIM)
181
+ m = wp.tile_argmax(a)
182
+
183
+ wp.tile_store(output, m, offset=i)
184
+
185
+
186
+ def test_tile_reduce_argmax(test, device):
187
+ batch_count = 56
188
+
189
+ N = TILE_DIM
190
+
191
+ rng = np.random.default_rng(42)
192
+ input = rng.random((batch_count, N), dtype=np.float32)
193
+
194
+ input_wp = wp.array(input, requires_grad=True, device=device)
195
+ output_wp = wp.zeros(batch_count, dtype=wp.int32, requires_grad=True, device=device)
196
+
197
+ with wp.Tape() as tape:
198
+ wp.launch_tiled(
199
+ tile_argmax_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
200
+ )
201
+
202
+ argmax_wp = output_wp.numpy()
203
+ for i in range(batch_count):
204
+ argmax_np = np.argmax(input[i])
205
+ test.assertAlmostEqual(argmax_wp[i], argmax_np, places=4)
206
+
207
+
142
208
  @wp.kernel
143
209
  def tile_reduce_custom_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
144
210
  # output tile index
@@ -176,6 +242,79 @@ def test_tile_reduce_custom(test, device):
176
242
  test.assertAlmostEqual(prod_wp[i], prod_np, places=4)
177
243
 
178
244
 
245
+ def create_tile_scan_inclusive_kernel(tile_dim: int):
246
+ @wp.kernel
247
+ def tile_scan_inclusive_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
248
+ i = wp.tid()
249
+ t = wp.tile_load(input[i], shape=tile_dim)
250
+ t = wp.tile_scan_inclusive(t)
251
+ wp.tile_store(output[i], t)
252
+
253
+ return tile_scan_inclusive_kernel
254
+
255
+
256
+ def test_tile_scan_inclusive(test, device):
257
+ batch_count = 56
258
+ N = 1234
259
+
260
+ rng = np.random.default_rng(42)
261
+ input = rng.random((batch_count, N), dtype=np.float32)
262
+
263
+ input_wp = wp.array2d(input, requires_grad=True, device=device)
264
+ output_wp = wp.zeros_like(input_wp, requires_grad=True, device=device)
265
+
266
+ with wp.Tape() as tape:
267
+ wp.launch_tiled(
268
+ create_tile_scan_inclusive_kernel(N),
269
+ dim=[batch_count],
270
+ inputs=[input_wp, output_wp],
271
+ block_dim=TILE_DIM,
272
+ device=device,
273
+ )
274
+
275
+ scan_wp = output_wp.numpy()
276
+ for i in range(batch_count):
277
+ scan_np = np.cumsum(input[i])
278
+ np.testing.assert_allclose(scan_wp[i], scan_np, rtol=1e-5, atol=1e-6)
279
+
280
+
281
+ def create_tile_scan_exclusive_kernel(tile_dim: int):
282
+ @wp.kernel
283
+ def tile_scan_exclusive_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
284
+ i = wp.tid()
285
+ t = wp.tile_load(input[i], shape=tile_dim)
286
+ t = wp.tile_scan_exclusive(t)
287
+ wp.tile_store(output[i], t)
288
+
289
+ return tile_scan_exclusive_kernel
290
+
291
+
292
+ def test_tile_scan_exclusive(test, device):
293
+ batch_count = 56
294
+ N = 1234
295
+
296
+ rng = np.random.default_rng(42)
297
+ input = rng.random((batch_count, N), dtype=np.float32)
298
+
299
+ input_wp = wp.array2d(input, requires_grad=True, device=device)
300
+ output_wp = wp.zeros_like(input_wp, requires_grad=True, device=device)
301
+
302
+ with wp.Tape() as tape:
303
+ wp.launch_tiled(
304
+ create_tile_scan_exclusive_kernel(N),
305
+ dim=[batch_count],
306
+ inputs=[input_wp, output_wp],
307
+ block_dim=TILE_DIM,
308
+ device=device,
309
+ )
310
+
311
+ scan_wp = output_wp.numpy()
312
+ for i in range(batch_count):
313
+ scan_np = np.zeros(N, dtype=np.float32)
314
+ scan_np[1:] = np.cumsum(input[i][:-1])
315
+ np.testing.assert_allclose(scan_wp[i], scan_np, rtol=1e-5, atol=1e-6)
316
+
317
+
179
318
  @wp.struct
180
319
  class KeyValue:
181
320
  key: wp.int32
@@ -423,7 +562,75 @@ def test_tile_arange(test, device):
423
562
  assert_np_equal(output.numpy()[4], np.arange(17, 0, -1))
424
563
 
425
564
 
565
+ @wp.kernel
566
+ def tile_strided_loop_kernel(arr: wp.array(dtype=float), max_val: wp.array(dtype=float)):
567
+ tid, lane = wp.tid()
568
+
569
+ num_threads = wp.block_dim()
570
+
571
+ thread_max = wp.float32(-wp.inf)
572
+
573
+ length = arr.shape[0]
574
+ upper = ((length + num_threads - 1) // num_threads) * num_threads
575
+ for el_id in range(lane, upper, num_threads):
576
+ if el_id < length:
577
+ val = arr[el_id]
578
+ else:
579
+ val = wp.float32(-wp.inf)
580
+
581
+ t = wp.tile(val)
582
+ local_max = wp.tile_max(t)
583
+
584
+ thread_max = wp.max(thread_max, local_max[0])
585
+
586
+ if lane == 0:
587
+ max_val[0] = thread_max
588
+
589
+
590
+ def test_tile_strided_loop(test, device):
591
+ N = 5 # Length of array
592
+
593
+ rng = np.random.default_rng(42)
594
+ input = rng.random(N, dtype=np.float32)
595
+
596
+ input_wp = wp.array(input, device=device)
597
+ output_wp = wp.zeros(1, dtype=wp.float32, device=device)
598
+
599
+ wp.launch_tiled(
600
+ tile_strided_loop_kernel,
601
+ dim=[1],
602
+ inputs=[input_wp, output_wp],
603
+ device=device,
604
+ block_dim=128,
605
+ )
606
+
607
+ max_wp = output_wp.numpy()
608
+ max_np = np.max(input)
609
+ test.assertAlmostEqual(max_wp[0], max_np, places=4)
610
+
611
+
612
+ @wp.kernel
613
+ def test_tile_reduce_matrix_kernel(y: wp.array(dtype=wp.mat33)):
614
+ i = wp.tid()
615
+ I = wp.identity(3, dtype=wp.float32)
616
+ m = wp.float32(i) * I
617
+
618
+ t = wp.tile(m, preserve_type=True)
619
+ sum = wp.tile_reduce(wp.add, t)
620
+
621
+ wp.tile_store(y, sum)
622
+
623
+
624
+ def test_tile_reduce_matrix(test, device):
625
+ y = wp.zeros(shape=1, dtype=wp.mat33, device=device)
626
+
627
+ wp.launch(test_tile_reduce_matrix_kernel, dim=TILE_DIM, inputs=[], outputs=[y], block_dim=TILE_DIM, device=device)
628
+
629
+ assert_np_equal(y.numpy().squeeze(), np.eye(3, dtype=np.float32) * 2016.0)
630
+
631
+
426
632
  devices = get_test_devices()
633
+ cuda_devices = get_cuda_test_devices()
427
634
 
428
635
 
429
636
  class TestTileReduce(unittest.TestCase):
@@ -433,6 +640,8 @@ class TestTileReduce(unittest.TestCase):
433
640
  add_function_test(TestTileReduce, "test_tile_reduce_sum", test_tile_reduce_sum, devices=devices)
434
641
  add_function_test(TestTileReduce, "test_tile_reduce_min", test_tile_reduce_min, devices=devices)
435
642
  add_function_test(TestTileReduce, "test_tile_reduce_max", test_tile_reduce_max, devices=devices)
643
+ add_function_test(TestTileReduce, "test_tile_reduce_argmin", test_tile_reduce_argmin, devices=devices)
644
+ add_function_test(TestTileReduce, "test_tile_reduce_argmax", test_tile_reduce_argmax, devices=devices)
436
645
  add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_custom, devices=devices)
437
646
  add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
438
647
  add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_sum, devices=devices)
@@ -441,6 +650,10 @@ add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devi
441
650
  add_function_test(TestTileReduce, "test_tile_arange", test_tile_arange, devices=devices)
442
651
  add_function_test(TestTileReduce, "test_tile_untile_scalar", test_tile_untile_scalar, devices=devices)
443
652
  add_function_test(TestTileReduce, "test_tile_untile_vector", test_tile_untile_vector, devices=devices)
653
+ add_function_test(TestTileReduce, "test_tile_strided_loop", test_tile_strided_loop, devices=devices)
654
+ add_function_test(TestTileReduce, "test_tile_scan_inclusive", test_tile_scan_inclusive, devices=devices)
655
+ add_function_test(TestTileReduce, "test_tile_scan_exclusive", test_tile_scan_exclusive, devices=devices)
656
+ add_function_test(TestTileReduce, "test_tile_reduce_matrix", test_tile_reduce_matrix, devices=cuda_devices)
444
657
 
445
658
  if __name__ == "__main__":
446
659
  wp.clear_kernel_cache()
@@ -224,6 +224,121 @@ def test_tile_shared_non_aligned(test, device):
224
224
  assert hooks.backward_smem_bytes == expected_required_shared * 2
225
225
 
226
226
 
227
+ def test_tile_shared_vec_accumulation(test, device):
228
+ BLOCK_DIM = 64
229
+
230
+ @wp.kernel
231
+ def compute(indices: wp.array(dtype=int), vecs: wp.array(dtype=wp.vec3), output: wp.array2d(dtype=float)):
232
+ i, j = wp.tid()
233
+
234
+ idx_tile = wp.tile_load(indices, shape=BLOCK_DIM, offset=i * BLOCK_DIM)
235
+ idx = idx_tile[j]
236
+
237
+ s = wp.tile_zeros(shape=(1, 3), dtype=float)
238
+
239
+ s[0, 0] += vecs[idx].x
240
+ s[0, 1] += vecs[idx].y
241
+ s[0, 2] += vecs[idx].z
242
+
243
+ wp.tile_store(output, s, offset=(i, 0))
244
+
245
+ N = BLOCK_DIM * 3
246
+
247
+ basis_vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
248
+ vecs = wp.array(basis_vecs, dtype=wp.vec3, requires_grad=True, device=device)
249
+
250
+ rng = np.random.default_rng(42)
251
+ indices_np = rng.integers(0, 3, size=N)
252
+
253
+ indices = wp.array(indices_np, dtype=int, requires_grad=True, device=device)
254
+
255
+ output = wp.zeros(shape=(3, 3), dtype=float, requires_grad=True, device=device)
256
+
257
+ tape = wp.Tape()
258
+ with tape:
259
+ wp.launch_tiled(compute, dim=3, inputs=[indices, vecs, output], block_dim=BLOCK_DIM, device=device)
260
+
261
+ output.grad = wp.ones_like(output)
262
+
263
+ tape.backward()
264
+
265
+ n0 = np.count_nonzero(indices_np == 0)
266
+ n1 = np.count_nonzero(indices_np == 1)
267
+ n2 = np.count_nonzero(indices_np == 2)
268
+ true_grads = np.array([[n0, n0, n0], [n1, n1, n1], [n2, n2, n2]])
269
+
270
+ indices_np = indices_np.reshape((3, BLOCK_DIM))
271
+
272
+ def compute_row(idx):
273
+ n0 = np.count_nonzero(indices_np[idx, :] == 0)
274
+ n1 = np.count_nonzero(indices_np[idx, :] == 1)
275
+ n2 = np.count_nonzero(indices_np[idx, :] == 2)
276
+ return np.array([1, 0, 0]) * n0 + np.array([0, 1, 0]) * n1 + np.array([0, 0, 1]) * n2
277
+
278
+ row_0 = compute_row(0)
279
+ row_1 = compute_row(1)
280
+ row_2 = compute_row(2)
281
+
282
+ true_vecs = np.stack([row_0, row_1, row_2])
283
+
284
+ assert_np_equal(output.numpy(), true_vecs)
285
+ assert_np_equal(vecs.grad.numpy(), true_grads)
286
+
287
+
288
+ def test_tile_shared_simple_reduction_add(test, device):
289
+ BLOCK_DIM = 64
290
+
291
+ @wp.kernel
292
+ def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
293
+ i, j = wp.tid()
294
+
295
+ t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)
296
+
297
+ k = BLOCK_DIM // 2
298
+ while k > 0:
299
+ if j < k:
300
+ t[j] += t[j + k]
301
+ k //= 2
302
+
303
+ wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)
304
+
305
+ N = BLOCK_DIM * 4
306
+ x_np = np.arange(N, dtype=np.float32)
307
+ x = wp.array(x_np, dtype=float, device=device)
308
+ y = wp.zeros(4, dtype=float, device=device)
309
+
310
+ wp.launch_tiled(compute, dim=4, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)
311
+
312
+ assert_np_equal(np.sum(y.numpy()), np.sum(x_np))
313
+
314
+
315
+ def test_tile_shared_simple_reduction_sub(test, device):
316
+ BLOCK_DIM = 64
317
+
318
+ @wp.kernel
319
+ def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
320
+ i, j = wp.tid()
321
+
322
+ t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)
323
+
324
+ k = BLOCK_DIM // 2
325
+ while k > 0:
326
+ if j < k:
327
+ t[j] -= t[j + k]
328
+ k //= 2
329
+
330
+ wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)
331
+
332
+ N = BLOCK_DIM * 4
333
+ x_np = np.arange(N, dtype=np.float32)
334
+ x = wp.array(x_np, dtype=float, device=device)
335
+ y = wp.zeros(4, dtype=float, device=device)
336
+
337
+ wp.launch_tiled(compute, dim=4, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)
338
+
339
+ assert_np_equal(np.sum(y.numpy()), 0.0)
340
+
341
+
227
342
  devices = get_cuda_test_devices()
228
343
 
229
344
 
@@ -240,7 +355,21 @@ add_function_test(
240
355
  add_function_test(TestTileSharedMemory, "test_tile_shared_mem_graph", test_tile_shared_mem_graph, devices=devices)
241
356
  add_function_test(TestTileSharedMemory, "test_tile_shared_mem_func", test_tile_shared_mem_func, devices=devices)
242
357
  add_function_test(TestTileSharedMemory, "test_tile_shared_non_aligned", test_tile_shared_non_aligned, devices=devices)
243
-
358
+ add_function_test(
359
+ TestTileSharedMemory, "test_tile_shared_vec_accumulation", test_tile_shared_vec_accumulation, devices=devices
360
+ )
361
+ add_function_test(
362
+ TestTileSharedMemory,
363
+ "test_tile_shared_simple_reduction_add",
364
+ test_tile_shared_simple_reduction_add,
365
+ devices=devices,
366
+ )
367
+ add_function_test(
368
+ TestTileSharedMemory,
369
+ "test_tile_shared_simple_reduction_sub",
370
+ test_tile_shared_simple_reduction_sub,
371
+ devices=devices,
372
+ )
244
373
 
245
374
  if __name__ == "__main__":
246
375
  wp.clear_kernel_cache()
@@ -0,0 +1,117 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
+ def create_sort_kernel(KEY_TYPE, MAX_SORT_LENGTH):
25
+ @wp.kernel
26
+ def tile_sort_kernel(
27
+ input_keys: wp.array(dtype=KEY_TYPE),
28
+ input_values: wp.array(dtype=wp.int32),
29
+ output_keys: wp.array(dtype=KEY_TYPE),
30
+ output_values: wp.array(dtype=wp.int32),
31
+ ):
32
+ # Load input into shared memory
33
+ keys = wp.tile_load(input_keys, shape=MAX_SORT_LENGTH, storage="shared")
34
+ values = wp.tile_load(input_values, shape=MAX_SORT_LENGTH, storage="shared")
35
+
36
+ # Perform in-place sorting
37
+ wp.tile_sort(keys, values)
38
+
39
+ # Store sorted shared memory into output arrays
40
+ wp.tile_store(output_keys, keys)
41
+ wp.tile_store(output_values, values)
42
+
43
+ return tile_sort_kernel
44
+
45
+
46
+ def test_tile_sort(test, device):
47
+ for dtype in [int, float]: # Loop over int and float keys
48
+ for j in range(5, 10):
49
+ TILE_DIM = 2**j
50
+ for i in range(0, 11): # Start from 1 to avoid zero-length cases
51
+ length = 2**i + 1
52
+
53
+ rng = np.random.default_rng(42) # Create a random generator instance
54
+
55
+ if dtype == int:
56
+ np_keys = rng.choice(1000000000, size=length, replace=False)
57
+ else: # dtype == float
58
+ np_keys = rng.uniform(0, 1000000000, size=length)
59
+
60
+ np_values = np.arange(length)
61
+
62
+ # Generate random keys and iota indexer
63
+ input_keys = wp.array(np_keys, dtype=dtype, device=device)
64
+ input_values = wp.array(np_values, dtype=int, device=device)
65
+ output_keys = wp.zeros_like(input_keys, device=device)
66
+ output_values = wp.zeros_like(input_values, device=device)
67
+
68
+ # Execute sorting kernel
69
+ kernel = create_sort_kernel(dtype, length)
70
+ wp.launch_tiled(
71
+ kernel,
72
+ dim=1,
73
+ inputs=[input_keys, input_values, output_keys, output_values],
74
+ block_dim=TILE_DIM,
75
+ device=device,
76
+ )
77
+ wp.synchronize()
78
+
79
+ # Sort using NumPy for validation
80
+ sorted_indices = np.argsort(np_keys)
81
+ np_sorted_keys = np_keys[sorted_indices]
82
+ np_sorted_values = np_values[sorted_indices]
83
+
84
+ if dtype == int:
85
+ keys_match = np.array_equal(output_keys.numpy(), np_sorted_keys)
86
+ else: # dtype == float
87
+ keys_match = np.allclose(output_keys.numpy(), np_sorted_keys, atol=1e-6) # Use tolerance for floats
88
+
89
+ values_match = np.array_equal(output_values.numpy(), np_sorted_values)
90
+
91
+ if not keys_match or not values_match:
92
+ print(f"Test failed for dtype={dtype}, TILE_DIM={TILE_DIM}, length={length}")
93
+ print("")
94
+ print(output_keys.numpy())
95
+ print(np_sorted_keys)
96
+ print("")
97
+ print(output_values.numpy())
98
+ print(np_sorted_values)
99
+ print("")
100
+
101
+ # Validate results
102
+ assert keys_match, f"Key sorting mismatch for dtype={dtype}!"
103
+ assert values_match, f"Value sorting mismatch for dtype={dtype}!"
104
+
105
+
106
+ devices = get_test_devices()
107
+
108
+
109
+ class TestTileSort(unittest.TestCase):
110
+ pass
111
+
112
+
113
+ add_function_test(TestTileSort, "test_tile_sort", test_tile_sort, devices=devices)
114
+
115
+ if __name__ == "__main__":
116
+ wp.clear_kernel_cache()
117
+ unittest.main(verbosity=2, failfast=True)
@@ -113,17 +113,18 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
113
113
  from warp.tests.interop.test_dlpack import TestDLPack
114
114
  from warp.tests.interop.test_jax import TestJax
115
115
  from warp.tests.interop.test_torch import TestTorch
116
+ from warp.tests.sim.test_cloth import TestCloth
116
117
  from warp.tests.sim.test_collision import TestCollision
117
118
  from warp.tests.sim.test_coloring import TestColoring
118
119
  from warp.tests.sim.test_model import TestModel
119
120
  from warp.tests.sim.test_sim_grad import TestSimGradients
120
121
  from warp.tests.sim.test_sim_kinematics import TestSimKinematics
121
- from warp.tests.sim.test_vbd import TestVbd
122
122
  from warp.tests.test_adam import TestAdam
123
123
  from warp.tests.test_arithmetic import TestArithmetic
124
124
  from warp.tests.test_array import TestArray
125
125
  from warp.tests.test_array_reduce import TestArrayReduce
126
126
  from warp.tests.test_atomic import TestAtomic
127
+ from warp.tests.test_atomic_cas import TestAtomicCAS
127
128
  from warp.tests.test_bool import TestBool
128
129
  from warp.tests.test_builtins_resolution import TestBuiltinsResolution
129
130
  from warp.tests.test_closest_point_edge_edge import TestClosestPointEdgeEdgeMethods
@@ -166,7 +167,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
166
167
  from warp.tests.test_mat_lite import TestMatLite
167
168
  from warp.tests.test_mat_scalar_ops import TestMatScalarOps
168
169
  from warp.tests.test_math import TestMath
169
- from warp.tests.test_mlp import TestMLP
170
170
  from warp.tests.test_module_hashing import TestModuleHashing
171
171
  from warp.tests.test_modules_lite import TestModuleLite
172
172
  from warp.tests.test_noise import TestNoise
@@ -208,10 +208,12 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
208
208
  TestArrayReduce,
209
209
  TestAsync,
210
210
  TestAtomic,
211
+ TestAtomicCAS,
211
212
  TestBool,
212
213
  TestBuiltinsResolution,
213
214
  TestBvh,
214
215
  TestClosestPointEdgeEdgeMethods,
216
+ TestCloth,
215
217
  TestCodeGen,
216
218
  TestCodeGenInstancing,
217
219
  TestCollision,
@@ -262,7 +264,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
262
264
  TestMeshQueryAABBMethods,
263
265
  TestMeshQueryPoint,
264
266
  TestMeshQueryRay,
265
- TestMLP,
266
267
  TestModel,
267
268
  TestModuleHashing,
268
269
  TestModuleLite,
@@ -300,7 +301,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
300
301
  TestTriangleClosestPoint,
301
302
  TestTypes,
302
303
  TestUtils,
303
- TestVbd,
304
304
  TestVec,
305
305
  TestVecLite,
306
306
  TestVecScalarOps,
@@ -350,7 +350,6 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
350
350
  from warp.tests.test_lvalue import TestLValue
351
351
  from warp.tests.test_mat_lite import TestMatLite
352
352
  from warp.tests.test_math import TestMath
353
- from warp.tests.test_mlp import TestMLP
354
353
  from warp.tests.test_module_hashing import TestModuleHashing
355
354
  from warp.tests.test_modules_lite import TestModuleLite
356
355
  from warp.tests.test_noise import TestNoise
@@ -397,7 +396,6 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
397
396
  TestMeshQueryAABBMethods,
398
397
  TestMeshQueryPoint,
399
398
  TestMeshQueryRay,
400
- TestMLP,
401
399
  TestModuleHashing,
402
400
  TestModuleLite,
403
401
  TestNoise,