warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (192) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -28,7 +28,7 @@ def test_tile_shared_mem_size(test, device):
28
28
 
29
29
  BLOCK_DIM = 256
30
30
 
31
- @wp.kernel
31
+ @wp.kernel(module="unique")
32
32
  def compute(out: wp.array2d(dtype=float)):
33
33
  a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
34
34
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
@@ -64,7 +64,7 @@ def test_tile_shared_mem_large(test, device):
64
64
  BLOCK_DIM = 256
65
65
 
66
66
  # we disable backward kernel gen since 128k is not supported on most architectures
67
- @wp.kernel(enable_backward=False)
67
+ @wp.kernel(enable_backward=False, module="unique")
68
68
  def compute(out: wp.array2d(dtype=float)):
69
69
  a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
70
70
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
@@ -100,7 +100,7 @@ def test_tile_shared_mem_graph(test, device):
100
100
 
101
101
  BLOCK_DIM = 256
102
102
 
103
- @wp.kernel
103
+ @wp.kernel(module="unique")
104
104
  def compute(out: wp.array2d(dtype=float)):
105
105
  a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
106
106
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
@@ -110,7 +110,7 @@ def test_tile_shared_mem_graph(test, device):
110
110
 
111
111
  out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)
112
112
 
113
- wp.load_module(device=device)
113
+ compute.module.load(device)
114
114
 
115
115
  wp.capture_begin(device, force_module_load=False)
116
116
  wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)
@@ -157,7 +157,7 @@ def test_tile_shared_mem_func(test, device):
157
157
 
158
158
  return a + b
159
159
 
160
- @wp.kernel
160
+ @wp.kernel(module="unique")
161
161
  def compute(out: wp.array2d(dtype=float)):
162
162
  s = add_tile_small()
163
163
  b = add_tile_big()
@@ -197,7 +197,7 @@ def test_tile_shared_non_aligned(test, device):
197
197
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 3.0
198
198
  return a + b
199
199
 
200
- @wp.kernel
200
+ @wp.kernel(module="unique")
201
201
  def compute(out: wp.array2d(dtype=float)):
202
202
  # This test the logic in the stack allocator, which should increment and
203
203
  # decrement the stack pointer each time foo() is called
@@ -224,6 +224,121 @@ def test_tile_shared_non_aligned(test, device):
224
224
  assert hooks.backward_smem_bytes == expected_required_shared * 2
225
225
 
226
226
 
227
+ def test_tile_shared_vec_accumulation(test, device):
228
+ BLOCK_DIM = 256
229
+
230
+ @wp.kernel(module="unique")
231
+ def compute(indices: wp.array(dtype=int), vecs: wp.array(dtype=wp.vec3), output: wp.array2d(dtype=float)):
232
+ i, j = wp.tid()
233
+
234
+ idx_tile = wp.tile_load(indices, shape=BLOCK_DIM, offset=i * BLOCK_DIM)
235
+ idx = idx_tile[j]
236
+
237
+ s = wp.tile_zeros(shape=(1, 3), dtype=float)
238
+
239
+ s[0, 0] += vecs[idx].x
240
+ s[0, 1] += vecs[idx].y
241
+ s[0, 2] += vecs[idx].z
242
+
243
+ wp.tile_store(output, s, offset=(i, 0))
244
+
245
+ N = BLOCK_DIM * 3
246
+
247
+ basis_vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
248
+ vecs = wp.array(basis_vecs, dtype=wp.vec3, requires_grad=True, device=device)
249
+
250
+ rng = np.random.default_rng(42)
251
+ indices_np = rng.integers(0, 3, size=N)
252
+
253
+ indices = wp.array(indices_np, dtype=int, requires_grad=True, device=device)
254
+
255
+ output = wp.zeros(shape=(3, 3), dtype=float, requires_grad=True, device=device)
256
+
257
+ tape = wp.Tape()
258
+ with tape:
259
+ wp.launch_tiled(compute, dim=3, inputs=[indices, vecs, output], block_dim=BLOCK_DIM, device=device)
260
+
261
+ output.grad = wp.ones_like(output)
262
+
263
+ tape.backward()
264
+
265
+ n0 = np.count_nonzero(indices_np == 0)
266
+ n1 = np.count_nonzero(indices_np == 1)
267
+ n2 = np.count_nonzero(indices_np == 2)
268
+ true_grads = np.array([[n0, n0, n0], [n1, n1, n1], [n2, n2, n2]])
269
+
270
+ indices_np = indices_np.reshape((3, BLOCK_DIM))
271
+
272
+ def compute_row(idx):
273
+ n0 = np.count_nonzero(indices_np[idx, :] == 0)
274
+ n1 = np.count_nonzero(indices_np[idx, :] == 1)
275
+ n2 = np.count_nonzero(indices_np[idx, :] == 2)
276
+ return np.array([1, 0, 0]) * n0 + np.array([0, 1, 0]) * n1 + np.array([0, 0, 1]) * n2
277
+
278
+ row_0 = compute_row(0)
279
+ row_1 = compute_row(1)
280
+ row_2 = compute_row(2)
281
+
282
+ true_vecs = np.stack([row_0, row_1, row_2])
283
+
284
+ assert_np_equal(output.numpy(), true_vecs)
285
+ assert_np_equal(vecs.grad.numpy(), true_grads)
286
+
287
+
288
+ def test_tile_shared_simple_reduction_add(test, device):
289
+ BLOCK_DIM = 256
290
+
291
+ @wp.kernel(module="unique")
292
+ def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
293
+ i, j = wp.tid()
294
+
295
+ t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)
296
+
297
+ k = BLOCK_DIM // 2
298
+ while k > 0:
299
+ if j < k:
300
+ t[j] += t[j + k]
301
+ k //= 2
302
+
303
+ wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)
304
+
305
+ N = BLOCK_DIM * 4
306
+ x_np = np.arange(N, dtype=np.float32)
307
+ x = wp.array(x_np, dtype=float, device=device)
308
+ y = wp.zeros(4, dtype=float, device=device)
309
+
310
+ wp.launch_tiled(compute, dim=4, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)
311
+
312
+ assert_np_equal(np.sum(y.numpy()), np.sum(x_np))
313
+
314
+
315
+ def test_tile_shared_simple_reduction_sub(test, device):
316
+ BLOCK_DIM = 256
317
+
318
+ @wp.kernel(module="unique")
319
+ def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
320
+ i, j = wp.tid()
321
+
322
+ t = wp.tile_load(x, shape=BLOCK_DIM, offset=BLOCK_DIM * i)
323
+
324
+ k = BLOCK_DIM // 2
325
+ while k > 0:
326
+ if j < k:
327
+ t[j] -= t[j + k]
328
+ k //= 2
329
+
330
+ wp.tile_store(y, wp.tile_view(t, offset=(0,), shape=(1,)), i)
331
+
332
+ N = BLOCK_DIM * 4
333
+ x_np = np.arange(N, dtype=np.float32)
334
+ x = wp.array(x_np, dtype=float, device=device)
335
+ y = wp.zeros(4, dtype=float, device=device)
336
+
337
+ wp.launch_tiled(compute, dim=4, inputs=[x], outputs=[y], block_dim=BLOCK_DIM, device=device)
338
+
339
+ assert_np_equal(np.sum(y.numpy()), 0.0)
340
+
341
+
227
342
  devices = get_cuda_test_devices()
228
343
 
229
344
 
@@ -240,7 +355,21 @@ add_function_test(
240
355
  add_function_test(TestTileSharedMemory, "test_tile_shared_mem_graph", test_tile_shared_mem_graph, devices=devices)
241
356
  add_function_test(TestTileSharedMemory, "test_tile_shared_mem_func", test_tile_shared_mem_func, devices=devices)
242
357
  add_function_test(TestTileSharedMemory, "test_tile_shared_non_aligned", test_tile_shared_non_aligned, devices=devices)
243
-
358
+ add_function_test(
359
+ TestTileSharedMemory, "test_tile_shared_vec_accumulation", test_tile_shared_vec_accumulation, devices=devices
360
+ )
361
+ add_function_test(
362
+ TestTileSharedMemory,
363
+ "test_tile_shared_simple_reduction_add",
364
+ test_tile_shared_simple_reduction_add,
365
+ devices=devices,
366
+ )
367
+ add_function_test(
368
+ TestTileSharedMemory,
369
+ "test_tile_shared_simple_reduction_sub",
370
+ test_tile_shared_simple_reduction_sub,
371
+ devices=devices,
372
+ )
244
373
 
245
374
  if __name__ == "__main__":
246
375
  wp.clear_kernel_cache()
@@ -0,0 +1,121 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
+ def create_sort_kernel(KEY_TYPE, MAX_SORT_LENGTH):
25
+ @wp.kernel
26
+ def tile_sort_kernel(
27
+ input_keys: wp.array(dtype=KEY_TYPE),
28
+ input_values: wp.array(dtype=wp.int32),
29
+ output_keys: wp.array(dtype=KEY_TYPE),
30
+ output_values: wp.array(dtype=wp.int32),
31
+ ):
32
+ # Load input into shared memory
33
+ keys = wp.tile_load(input_keys, shape=MAX_SORT_LENGTH, storage="shared")
34
+ values = wp.tile_load(input_values, shape=MAX_SORT_LENGTH, storage="shared")
35
+
36
+ # Perform in-place sorting
37
+ wp.tile_sort(keys, values)
38
+
39
+ # Store sorted shared memory into output arrays
40
+ wp.tile_store(output_keys, keys)
41
+ wp.tile_store(output_values, values)
42
+
43
+ return tile_sort_kernel
44
+
45
+
46
+ def test_tile_sort(test, device):
47
+ # Forward-declare kernels for more efficient compilation
48
+ kernels = {}
49
+ for dtype in [int, float]:
50
+ for i in range(0, 11):
51
+ length = 2**i + 1
52
+ kernels[(dtype, length)] = create_sort_kernel(dtype, length)
53
+
54
+ for (dtype, length), kernel in kernels.items():
55
+ for j in range(5, 10):
56
+ TILE_DIM = 2**j
57
+
58
+ rng = np.random.default_rng(42) # Create a random generator instance
59
+
60
+ if dtype == int:
61
+ np_keys = rng.choice(1000000000, size=length, replace=False)
62
+ else: # dtype == float
63
+ np_keys = rng.uniform(0, 1000000000, size=length).astype(dtype)
64
+
65
+ np_values = np.arange(length)
66
+
67
+ # Generate random keys and iota indexer
68
+ input_keys = wp.array(np_keys, dtype=dtype, device=device)
69
+ input_values = wp.array(np_values, dtype=int, device=device)
70
+ output_keys = wp.zeros_like(input_keys, device=device)
71
+ output_values = wp.zeros_like(input_values, device=device)
72
+
73
+ # Execute sorting kernel
74
+ wp.launch_tiled(
75
+ kernel,
76
+ dim=1,
77
+ inputs=[input_keys, input_values, output_keys, output_values],
78
+ block_dim=TILE_DIM,
79
+ device=device,
80
+ )
81
+ wp.synchronize()
82
+
83
+ # Sort using NumPy for validation
84
+ sorted_indices = np.argsort(np_keys)
85
+ np_sorted_keys = np_keys[sorted_indices]
86
+ np_sorted_values = np_values[sorted_indices]
87
+
88
+ if dtype == int:
89
+ keys_match = np.array_equal(output_keys.numpy(), np_sorted_keys)
90
+ else: # dtype == float
91
+ keys_match = np.allclose(output_keys.numpy(), np_sorted_keys, atol=1e-6) # Use tolerance for floats
92
+
93
+ values_match = np.array_equal(output_values.numpy(), np_sorted_values)
94
+
95
+ if not keys_match or not values_match:
96
+ print(f"Test failed for dtype={dtype}, TILE_DIM={TILE_DIM}, length={length}")
97
+ print("")
98
+ print(output_keys.numpy())
99
+ print(np_sorted_keys)
100
+ print("")
101
+ print(output_values.numpy())
102
+ print(np_sorted_values)
103
+ print("")
104
+
105
+ # Validate results
106
+ test.assertTrue(keys_match, f"Key sorting mismatch for dtype={dtype}!")
107
+ test.assertTrue(values_match, f"Value sorting mismatch for dtype={dtype}!")
108
+
109
+
110
+ devices = get_test_devices()
111
+
112
+
113
+ class TestTileSort(unittest.TestCase):
114
+ pass
115
+
116
+
117
+ add_function_test(TestTileSort, "test_tile_sort", test_tile_sort, devices=devices)
118
+
119
+ if __name__ == "__main__":
120
+ wp.clear_kernel_cache()
121
+ unittest.main(verbosity=2, failfast=True)
@@ -113,17 +113,18 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
113
113
  from warp.tests.interop.test_dlpack import TestDLPack
114
114
  from warp.tests.interop.test_jax import TestJax
115
115
  from warp.tests.interop.test_torch import TestTorch
116
+ from warp.tests.sim.test_cloth import TestCloth
116
117
  from warp.tests.sim.test_collision import TestCollision
117
118
  from warp.tests.sim.test_coloring import TestColoring
118
119
  from warp.tests.sim.test_model import TestModel
119
120
  from warp.tests.sim.test_sim_grad import TestSimGradients
120
121
  from warp.tests.sim.test_sim_kinematics import TestSimKinematics
121
- from warp.tests.sim.test_vbd import TestVbd
122
122
  from warp.tests.test_adam import TestAdam
123
123
  from warp.tests.test_arithmetic import TestArithmetic
124
124
  from warp.tests.test_array import TestArray
125
125
  from warp.tests.test_array_reduce import TestArrayReduce
126
126
  from warp.tests.test_atomic import TestAtomic
127
+ from warp.tests.test_atomic_cas import TestAtomicCAS
127
128
  from warp.tests.test_bool import TestBool
128
129
  from warp.tests.test_builtins_resolution import TestBuiltinsResolution
129
130
  from warp.tests.test_closest_point_edge_edge import TestClosestPointEdgeEdgeMethods
@@ -166,7 +167,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
166
167
  from warp.tests.test_mat_lite import TestMatLite
167
168
  from warp.tests.test_mat_scalar_ops import TestMatScalarOps
168
169
  from warp.tests.test_math import TestMath
169
- from warp.tests.test_mlp import TestMLP
170
170
  from warp.tests.test_module_hashing import TestModuleHashing
171
171
  from warp.tests.test_modules_lite import TestModuleLite
172
172
  from warp.tests.test_noise import TestNoise
@@ -193,13 +193,18 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
193
193
  from warp.tests.test_types import TestTypes
194
194
  from warp.tests.test_utils import TestUtils
195
195
  from warp.tests.test_vec import TestVec
196
+ from warp.tests.test_vec_constructors import TestVecConstructors
196
197
  from warp.tests.test_vec_lite import TestVecLite
197
198
  from warp.tests.test_vec_scalar_ops import TestVecScalarOps
198
199
  from warp.tests.test_verify_fp import TestVerifyFP
199
200
  from warp.tests.tile.test_tile import TestTile
201
+ from warp.tests.tile.test_tile_load import TestTileLoad
200
202
  from warp.tests.tile.test_tile_mathdx import TestTileMathDx
203
+ from warp.tests.tile.test_tile_matmul import TestTileMatmul
201
204
  from warp.tests.tile.test_tile_reduce import TestTileReduce
202
205
  from warp.tests.tile.test_tile_shared_memory import TestTileSharedMemory
206
+ from warp.tests.tile.test_tile_sort import TestTileSort
207
+ from warp.tests.tile.test_tile_view import TestTileView
203
208
 
204
209
  test_classes = [
205
210
  TestAdam,
@@ -208,10 +213,12 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
208
213
  TestArrayReduce,
209
214
  TestAsync,
210
215
  TestAtomic,
216
+ TestAtomicCAS,
211
217
  TestBool,
212
218
  TestBuiltinsResolution,
213
219
  TestBvh,
214
220
  TestClosestPointEdgeEdgeMethods,
221
+ TestCloth,
215
222
  TestCodeGen,
216
223
  TestCodeGenInstancing,
217
224
  TestCollision,
@@ -262,7 +269,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
262
269
  TestMeshQueryAABBMethods,
263
270
  TestMeshQueryPoint,
264
271
  TestMeshQueryRay,
265
- TestMLP,
266
272
  TestModel,
267
273
  TestModuleHashing,
268
274
  TestModuleLite,
@@ -292,16 +298,20 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
292
298
  TestStruct,
293
299
  TestTape,
294
300
  TestTile,
301
+ TestTileLoad,
295
302
  TestTileMathDx,
303
+ TestTileMatmul,
296
304
  TestTileReduce,
297
305
  TestTileSharedMemory,
306
+ TestTileSort,
307
+ TestTileView,
298
308
  TestTorch,
299
309
  TestTransientModule,
300
310
  TestTriangleClosestPoint,
301
311
  TestTypes,
302
312
  TestUtils,
303
- TestVbd,
304
313
  TestVec,
314
+ TestVecConstructors,
305
315
  TestVecLite,
306
316
  TestVecScalarOps,
307
317
  TestVerifyFP,
@@ -350,7 +360,6 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
350
360
  from warp.tests.test_lvalue import TestLValue
351
361
  from warp.tests.test_mat_lite import TestMatLite
352
362
  from warp.tests.test_math import TestMath
353
- from warp.tests.test_mlp import TestMLP
354
363
  from warp.tests.test_module_hashing import TestModuleHashing
355
364
  from warp.tests.test_modules_lite import TestModuleLite
356
365
  from warp.tests.test_noise import TestNoise
@@ -397,7 +406,6 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
397
406
  TestMeshQueryAABBMethods,
398
407
  TestMeshQueryPoint,
399
408
  TestMeshQueryRay,
400
- TestMLP,
401
409
  TestModuleHashing,
402
410
  TestModuleLite,
403
411
  TestNoise,