warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (153) hide show
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,8 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import itertools
17
+
16
18
  import numpy as np
17
19
 
18
20
  import warp as wp
@@ -48,6 +50,63 @@ def color_lattice_grid(num_x, num_y):
48
50
  return color_groups
49
51
 
50
52
 
53
+ def create_lattice_grid(N):
54
+ size = 10
55
+ position = (0, 0)
56
+
57
+ X = np.linspace(-0.5 * size + position[0], 0.5 * size + position[0], N)
58
+ Y = np.linspace(-0.5 * size + position[1], 0.5 * size + position[1], N)
59
+
60
+ X, Y = np.meshgrid(X, Y)
61
+
62
+ Z = []
63
+ for _i in range(N):
64
+ Z.append(np.linspace(0, size, N))
65
+
66
+ Z = np.array(Z)
67
+
68
+ vs = []
69
+ for i, j in itertools.product(range(N), range(N)):
70
+ vs.append(wp.vec3((X[i, j], Y[i, j], Z[i, j])))
71
+
72
+ fs = []
73
+ for i, j in itertools.product(range(0, N - 1), range(0, N - 1)):
74
+ vId = j + i * N
75
+
76
+ if (j + i) % 2:
77
+ fs.extend(
78
+ [
79
+ vId,
80
+ vId + N + 1,
81
+ vId + 1,
82
+ ]
83
+ )
84
+ fs.extend(
85
+ [
86
+ vId,
87
+ vId + N,
88
+ vId + N + 1,
89
+ ]
90
+ )
91
+ else:
92
+ fs.extend(
93
+ [
94
+ vId,
95
+ vId + N,
96
+ vId + 1,
97
+ ]
98
+ )
99
+ fs.extend(
100
+ [
101
+ vId + N,
102
+ vId + N + 1,
103
+ vId + 1,
104
+ ]
105
+ )
106
+
107
+ return vs, fs
108
+
109
+
51
110
  def test_coloring_corner_case(test, device):
52
111
  builder_1 = wp.sim.ModelBuilder()
53
112
  builder_1.color()
@@ -91,7 +150,7 @@ def test_coloring_trimesh(test, device):
91
150
  edge_indices_cpu = wp.array(model.edge_indices.numpy()[:, 2:], dtype=int, device="cpu")
92
151
 
93
152
  # coloring without bending
94
- num_colors_greedy = wp.context.runtime.core.graph_coloring(
153
+ num_colors_greedy = wp.context.runtime.core.wp_graph_coloring(
95
154
  model.particle_count,
96
155
  edge_indices_cpu.__ctype__(),
97
156
  ColoringAlgorithm.GREEDY.value,
@@ -104,7 +163,7 @@ def test_coloring_trimesh(test, device):
104
163
  device="cpu",
105
164
  )
106
165
 
107
- num_colors_mcs = wp.context.runtime.core.graph_coloring(
166
+ num_colors_mcs = wp.context.runtime.core.wp_graph_coloring(
108
167
  model.particle_count,
109
168
  edge_indices_cpu.__ctype__(),
110
169
  ColoringAlgorithm.MCS.value,
@@ -119,13 +178,13 @@ def test_coloring_trimesh(test, device):
119
178
 
120
179
  # coloring with bending
121
180
  edge_indices_cpu_with_bending = construct_trimesh_graph_edges(model.edge_indices, True)
122
- num_colors_greedy = wp.context.runtime.core.graph_coloring(
181
+ num_colors_greedy = wp.context.runtime.core.wp_graph_coloring(
123
182
  model.particle_count,
124
183
  edge_indices_cpu_with_bending.__ctype__(),
125
184
  ColoringAlgorithm.GREEDY.value,
126
185
  particle_colors.__ctype__(),
127
186
  )
128
- wp.context.runtime.core.balance_coloring(
187
+ wp.context.runtime.core.wp_balance_coloring(
129
188
  model.particle_count,
130
189
  edge_indices_cpu_with_bending.__ctype__(),
131
190
  num_colors_greedy,
@@ -139,13 +198,13 @@ def test_coloring_trimesh(test, device):
139
198
  device="cpu",
140
199
  )
141
200
 
142
- num_colors_mcs = wp.context.runtime.core.graph_coloring(
201
+ num_colors_mcs = wp.context.runtime.core.wp_graph_coloring(
143
202
  model.particle_count,
144
203
  edge_indices_cpu_with_bending.__ctype__(),
145
204
  ColoringAlgorithm.MCS.value,
146
205
  particle_colors.__ctype__(),
147
206
  )
148
- max_min_ratio = wp.context.runtime.core.balance_coloring(
207
+ max_min_ratio = wp.context.runtime.core.wp_balance_coloring(
149
208
  model.particle_count,
150
209
  edge_indices_cpu_with_bending.__ctype__(),
151
210
  num_colors_mcs,
@@ -164,6 +223,22 @@ def test_coloring_trimesh(test, device):
164
223
  color_sizes = np.array([c.shape[0] for c in color_categories_balanced], dtype=np.float32)
165
224
  test.assertTrue(np.max(color_sizes) / np.min(color_sizes) <= max_min_ratio)
166
225
 
226
+ # test if the color balance can quit from equilibrium
227
+ builder = wp.sim.ModelBuilder()
228
+
229
+ vs, fs = create_lattice_grid(100)
230
+ builder.add_cloth_mesh(
231
+ pos=wp.vec3(0.0, 0.0, 0.0),
232
+ rot=wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), 0.0),
233
+ scale=1.0,
234
+ vertices=vs,
235
+ indices=fs,
236
+ vel=wp.vec3(0.0, 0.0, 0.0),
237
+ density=0.02,
238
+ )
239
+
240
+ builder.color(include_bending=True)
241
+
167
242
 
168
243
  @unittest.skipUnless(USD_AVAILABLE, "Requires usd-core")
169
244
  def test_combine_coloring(test, device):
@@ -263,7 +338,7 @@ class TestColoring(unittest.TestCase):
263
338
  pass
264
339
 
265
340
 
266
- add_function_test(TestColoring, "test_coloring_trimesh", test_coloring_trimesh, devices=devices)
341
+ add_function_test(TestColoring, "test_coloring_trimesh", test_coloring_trimesh, devices=devices, check_output=False)
267
342
  add_function_test(TestColoring, "test_combine_coloring", test_combine_coloring, devices=devices)
268
343
  add_function_test(TestColoring, "test_coloring_corner_case", test_coloring_corner_case, devices=devices)
269
344
 
warp/tests/test_array.py CHANGED
@@ -2902,10 +2902,8 @@ def test_direct_from_numpy(test, device):
2902
2902
 
2903
2903
 
2904
2904
  @wp.kernel
2905
- def kernel_array_from_ptr(
2906
- ptr: wp.uint64,
2907
- ):
2908
- arr = wp.array(ptr=ptr, shape=(2, 3), dtype=wp.float32)
2905
+ def kernel_array_from_ptr(arr_orig: wp.array2d(dtype=wp.float32)):
2906
+ arr = wp.array(ptr=arr_orig.ptr, shape=(2, 3), dtype=wp.float32)
2909
2907
  arr[0, 0] = 1.0
2910
2908
  arr[0, 1] = 2.0
2911
2909
  arr[0, 2] = 3.0
@@ -2913,7 +2911,56 @@ def kernel_array_from_ptr(
2913
2911
 
2914
2912
  def test_kernel_array_from_ptr(test, device):
2915
2913
  arr = wp.zeros(shape=(2, 3), dtype=wp.float32, device=device)
2916
- wp.launch(kernel_array_from_ptr, dim=(1,), inputs=(arr.ptr,), device=device)
2914
+ wp.launch(kernel_array_from_ptr, dim=(1,), inputs=(arr,), device=device)
2915
+ assert_np_equal(arr.numpy(), np.array(((1.0, 2.0, 3.0), (0.0, 0.0, 0.0))))
2916
+
2917
+
2918
+ @wp.struct
2919
+ class MyStruct:
2920
+ a: wp.float32
2921
+ b: wp.float32
2922
+ c: wp.float32
2923
+
2924
+
2925
+ @wp.kernel
2926
+ def kernel_array_from_ptr_struct(arr_orig: wp.array(dtype=MyStruct)):
2927
+ arr = wp.array(ptr=arr_orig.ptr, shape=(2,), dtype=MyStruct)
2928
+ arr[0].a = 1.0
2929
+ arr[0].b = 2.0
2930
+ arr[0].c = 3.0
2931
+ arr[1].a = 4.0
2932
+ arr[1].b = 5.0
2933
+ arr[1].c = 6.0
2934
+
2935
+
2936
+ def test_kernel_array_from_ptr_struct(test, device):
2937
+ arr = wp.zeros(shape=(2,), dtype=MyStruct, device=device)
2938
+ wp.launch(kernel_array_from_ptr_struct, dim=(1,), inputs=(arr,), device=device)
2939
+ arr_np = arr.numpy()
2940
+ expected = np.zeros_like(arr_np)
2941
+ expected[0] = (1.0, 2.0, 3.0)
2942
+ expected[1] = (4.0, 5.0, 6.0)
2943
+ assert_np_equal(arr_np, expected)
2944
+
2945
+
2946
+ @wp.kernel
2947
+ def kernel_array_from_ptr_variable_shape(
2948
+ ptr: wp.uint64,
2949
+ shape_x: int,
2950
+ shape_y: int,
2951
+ ):
2952
+ arr = wp.array(ptr=ptr, shape=(shape_x, shape_y), dtype=wp.float32)
2953
+ arr[0, 0] = 1.0
2954
+ arr[0, 1] = 2.0
2955
+ if shape_y > 2:
2956
+ arr[0, 2] = 3.0
2957
+
2958
+
2959
+ def test_kernel_array_from_ptr_variable_shape(test, device):
2960
+ arr = wp.zeros(shape=(2, 3), dtype=wp.float32, device=device)
2961
+ wp.launch(kernel_array_from_ptr_variable_shape, dim=(1,), inputs=(arr.ptr, 2, 2), device=device)
2962
+ assert_np_equal(arr.numpy(), np.array(((1.0, 2.0, 0.0), (0.0, 0.0, 0.0))))
2963
+ wp.launch(kernel_array_from_ptr_variable_shape, dim=(1,), inputs=(arr.ptr, 2, 3), device=device)
2917
2964
  assert_np_equal(arr.numpy(), np.array(((1.0, 2.0, 3.0), (0.0, 0.0, 0.0))))
2918
2965
 
2919
2966
 
@@ -3185,6 +3232,10 @@ add_function_test(TestArray, "test_array_inplace_diff_ops", test_array_inplace_d
3185
3232
  add_function_test(TestArray, "test_array_inplace_non_diff_ops", test_array_inplace_non_diff_ops, devices=devices)
3186
3233
  add_function_test(TestArray, "test_direct_from_numpy", test_direct_from_numpy, devices=["cpu"])
3187
3234
  add_function_test(TestArray, "test_kernel_array_from_ptr", test_kernel_array_from_ptr, devices=devices)
3235
+ add_function_test(TestArray, "test_kernel_array_from_ptr_struct", test_kernel_array_from_ptr_struct, devices=devices)
3236
+ add_function_test(
3237
+ TestArray, "test_kernel_array_from_ptr_variable_shape", test_kernel_array_from_ptr_variable_shape, devices=devices
3238
+ )
3188
3239
 
3189
3240
  add_function_test(TestArray, "test_array_from_int32_domain", test_array_from_int32_domain, devices=devices)
3190
3241
  add_function_test(TestArray, "test_array_from_int64_domain", test_array_from_int64_domain, devices=devices)
warp/tests/test_assert.py CHANGED
@@ -245,6 +245,59 @@ class TestAssertDebug(unittest.TestCase):
245
245
  self.assertRegex(output, r"Assertion failed: .*assert value == 1.*Array element must be 1")
246
246
 
247
247
 
248
+ class TestAssertModeSwitch(unittest.TestCase):
249
+ """Test that switching from release mode to debug mode rebuilds the module with assertions enabled."""
250
+
251
+ @classmethod
252
+ def setUpClass(cls):
253
+ cls._saved_mode = wp.config.mode
254
+ cls._saved_mode_module = wp.get_module_options()["mode"]
255
+ cls._saved_cache_kernels = wp.config.cache_kernels
256
+
257
+ # Don't set any mode initially - use whatever the default is
258
+ wp.config.cache_kernels = False
259
+
260
+ @classmethod
261
+ def tearDownClass(cls):
262
+ wp.config.mode = cls._saved_mode
263
+ wp.set_module_options({"mode": cls._saved_mode_module})
264
+ wp.config.cache_kernels = cls._saved_cache_kernels
265
+
266
+ def test_switch_to_debug_mode(self):
267
+ """Test that switching from release mode to debug mode rebuilds the module with assertions enabled."""
268
+ with wp.ScopedDevice("cpu"):
269
+ # Create an array that will trigger an assertion
270
+ input_array = wp.zeros(1, dtype=int)
271
+
272
+ # In default mode, this should not assert
273
+ capture = StdErrCapture()
274
+ capture.begin()
275
+ wp.launch(expect_ones, input_array.shape, inputs=[input_array])
276
+ output = capture.end()
277
+
278
+ # Should not have any assertion output in release mode
279
+ self.assertEqual(output, "", f"Kernel should not print anything to stderr in release mode, got {output}")
280
+
281
+ # Now switch to debug mode and have it compile a new kernel
282
+ wp.config.mode = "debug"
283
+
284
+ @wp.kernel
285
+ def expect_ones_debug(a: wp.array(dtype=int)):
286
+ i = wp.tid()
287
+ assert a[i] == 1
288
+
289
+ # In debug mode, this should assert
290
+ capture = StdErrCapture()
291
+ capture.begin()
292
+ wp.launch(expect_ones_debug, input_array.shape, inputs=[input_array])
293
+ output = capture.end()
294
+
295
+ # Should have assertion output in debug mode
296
+ # Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
297
+ if output != "" or sys.platform != "win32":
298
+ self.assertRegex(output, r"Assertion failed: .*assert a\[i\] == 1")
299
+
300
+
248
301
  if __name__ == "__main__":
249
302
  wp.clear_kernel_cache()
250
303
  unittest.main(verbosity=2)
@@ -19,54 +19,63 @@ import numpy as np
19
19
  import warp as wp
20
20
  from warp.tests.unittest_utils import *
21
21
 
22
+ kernel_cache = {}
23
+
24
+
25
+ def getkernel(func, suffix=""):
26
+ key = func.__name__ + "_" + suffix
27
+ if key not in kernel_cache:
28
+ kernel_cache[key] = wp.Kernel(func=func, key=key)
29
+ return kernel_cache[key]
30
+
31
+
32
+ def test_atomic_cas(test, device, dtype, register_kernels=False):
33
+ warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
34
+ n = 100
35
+ counter = wp.array([0], dtype=warp_type, device=device)
36
+ lock = wp.array([0], dtype=warp_type, device=device)
22
37
 
23
- def create_spinlock_test(dtype):
24
38
  @wp.func
25
- def spinlock_acquire(lock: wp.array(dtype=dtype)):
39
+ def spinlock_acquire_1d(lock: wp.array(dtype=warp_type)):
26
40
  # Try to acquire the lock by setting it to 1 if it's 0
27
- while wp.atomic_cas(lock, 0, dtype(0), dtype(1)) == 1:
41
+ while wp.atomic_cas(lock, 0, warp_type(0), warp_type(1)) == 1:
28
42
  pass
29
43
 
30
44
  @wp.func
31
- def spinlock_release(lock: wp.array(dtype=dtype)):
45
+ def spinlock_release_1d(lock: wp.array(dtype=warp_type)):
32
46
  # Release the lock by setting it back to 0
33
- wp.atomic_exch(lock, 0, dtype(0))
47
+ wp.atomic_exch(lock, 0, warp_type(0))
34
48
 
35
49
  @wp.func
36
- def volatile_read(ptr: wp.array(dtype=dtype), index: int):
37
- value = wp.atomic_exch(ptr, index, dtype(0))
50
+ def volatile_read_1d(ptr: wp.array(dtype=warp_type), index: int):
51
+ value = wp.atomic_exch(ptr, index, warp_type(0))
38
52
  wp.atomic_exch(ptr, index, value)
39
53
  return value
40
54
 
41
- @wp.kernel
42
- def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype)):
55
+ def test_spinlock_counter_1d(counter: wp.array(dtype=warp_type), lock: wp.array(dtype=warp_type)):
43
56
  # Try to acquire the lock
44
- spinlock_acquire(lock)
57
+ spinlock_acquire_1d(lock)
45
58
 
46
59
  # Critical section - increment counter
47
60
  # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile
48
61
 
49
62
  # Work around since warp arrays cannot be marked as volatile
50
- value = volatile_read(counter, 0)
51
- counter[0] = value + dtype(1)
63
+ value = volatile_read_1d(counter, 0)
64
+ counter[0] = value + warp_type(1)
52
65
 
53
66
  # Release the lock
54
- spinlock_release(lock)
55
-
56
- return test_spinlock_counter
67
+ spinlock_release_1d(lock)
57
68
 
69
+ kernel = getkernel(test_spinlock_counter_1d, suffix=dtype.__name__)
58
70
 
59
- def test_atomic_cas(test, device, warp_type, numpy_type):
60
- n = 100
61
- counter = wp.array([0], dtype=warp_type, device=device)
62
- lock = wp.array([0], dtype=warp_type, device=device)
71
+ if register_kernels:
72
+ return
63
73
 
64
- test_spinlock_counter = create_spinlock_test(warp_type)
65
- wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)
74
+ wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)
66
75
 
67
76
  # Verify counter reached n
68
77
  counter_np = counter.numpy()
69
- expected = np.array([n], dtype=numpy_type)
78
+ expected = np.array([n], dtype=dtype)
70
79
 
71
80
  if not np.array_equal(counter_np, expected):
72
81
  print(f"Counter mismatch: expected {expected}, got {counter_np}")
@@ -74,53 +83,53 @@ def test_atomic_cas(test, device, warp_type, numpy_type):
74
83
  assert_np_equal(counter_np, expected)
75
84
 
76
85
 
77
- def create_spinlock_test_2d(dtype):
86
+ def test_atomic_cas_2d(test, device, dtype, register_kernels=False):
87
+ warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
88
+ n = 100
89
+ counter = wp.array([0], dtype=warp_type, device=device)
90
+ lock = wp.zeros(shape=(1, 1), dtype=warp_type, device=device)
91
+
78
92
  @wp.func
79
- def spinlock_acquire(lock: wp.array(dtype=dtype, ndim=2)):
93
+ def spinlock_acquire_2d(lock: wp.array2d(dtype=warp_type)):
80
94
  # Try to acquire the lock by setting it to 1 if it's 0
81
- while wp.atomic_cas(lock, 0, 0, dtype(0), dtype(1)) == 1:
95
+ while wp.atomic_cas(lock, 0, 0, warp_type(0), warp_type(1)) == 1:
82
96
  pass
83
97
 
84
98
  @wp.func
85
- def spinlock_release(lock: wp.array(dtype=dtype, ndim=2)):
99
+ def spinlock_release_2d(lock: wp.array2d(dtype=warp_type)):
86
100
  # Release the lock by setting it back to 0
87
- wp.atomic_exch(lock, 0, 0, dtype(0))
101
+ wp.atomic_exch(lock, 0, 0, warp_type(0))
88
102
 
89
103
  @wp.func
90
- def volatile_read(ptr: wp.array(dtype=dtype), index: int):
91
- value = wp.atomic_exch(ptr, index, dtype(0))
104
+ def volatile_read_2d(ptr: wp.array(dtype=warp_type), index: int):
105
+ value = wp.atomic_exch(ptr, index, warp_type(0))
92
106
  wp.atomic_exch(ptr, index, value)
93
107
  return value
94
108
 
95
- @wp.kernel
96
- def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype, ndim=2)):
109
+ def test_spinlock_counter_2d(counter: wp.array(dtype=warp_type), lock: wp.array2d(dtype=warp_type)):
97
110
  # Try to acquire the lock
98
- spinlock_acquire(lock)
111
+ spinlock_acquire_2d(lock)
99
112
 
100
113
  # Critical section - increment counter
101
114
  # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile
102
115
 
103
116
  # Work around since warp arrays cannot be marked as volatile
104
- value = volatile_read(counter, 0)
105
- counter[0] = value + dtype(1)
117
+ value = volatile_read_2d(counter, 0)
118
+ counter[0] = value + warp_type(1)
106
119
 
107
120
  # Release the lock
108
- spinlock_release(lock)
109
-
110
- return test_spinlock_counter
121
+ spinlock_release_2d(lock)
111
122
 
123
+ kernel = getkernel(test_spinlock_counter_2d, suffix=dtype.__name__)
112
124
 
113
- def test_atomic_cas_2d(test, device, warp_type, numpy_type):
114
- n = 100
115
- counter = wp.array([0], dtype=warp_type, device=device)
116
- lock = wp.zeros(shape=(1, 1), dtype=warp_type, device=device)
125
+ if register_kernels:
126
+ return
117
127
 
118
- test_spinlock_counter = create_spinlock_test_2d(warp_type)
119
- wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)
128
+ wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)
120
129
 
121
130
  # Verify counter reached n
122
131
  counter_np = counter.numpy()
123
- expected = np.array([n], dtype=numpy_type)
132
+ expected = np.array([n], dtype=dtype)
124
133
 
125
134
  if not np.array_equal(counter_np, expected):
126
135
  print(f"Counter mismatch: expected {expected}, got {counter_np}")
@@ -128,53 +137,53 @@ def test_atomic_cas_2d(test, device, warp_type, numpy_type):
128
137
  assert_np_equal(counter_np, expected)
129
138
 
130
139
 
131
- def create_spinlock_test_3d(dtype):
140
+ def test_atomic_cas_3d(test, device, dtype, register_kernels=False):
141
+ warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
142
+ n = 100
143
+ counter = wp.array([0], dtype=warp_type, device=device)
144
+ lock = wp.zeros(shape=(1, 1, 1), dtype=warp_type, device=device)
145
+
132
146
  @wp.func
133
- def spinlock_acquire(lock: wp.array(dtype=dtype, ndim=3)):
147
+ def spinlock_acquire_3d(lock: wp.array3d(dtype=warp_type)):
134
148
  # Try to acquire the lock by setting it to 1 if it's 0
135
- while wp.atomic_cas(lock, 0, 0, 0, dtype(0), dtype(1)) == 1:
149
+ while wp.atomic_cas(lock, 0, 0, 0, warp_type(0), warp_type(1)) == 1:
136
150
  pass
137
151
 
138
152
  @wp.func
139
- def spinlock_release(lock: wp.array(dtype=dtype, ndim=3)):
153
+ def spinlock_release_3d(lock: wp.array3d(dtype=warp_type)):
140
154
  # Release the lock by setting it back to 0
141
- wp.atomic_exch(lock, 0, 0, 0, dtype(0))
155
+ wp.atomic_exch(lock, 0, 0, 0, warp_type(0))
142
156
 
143
157
  @wp.func
144
- def volatile_read(ptr: wp.array(dtype=dtype), index: int):
145
- value = wp.atomic_exch(ptr, index, dtype(0))
158
+ def volatile_read_3d(ptr: wp.array(dtype=warp_type), index: int):
159
+ value = wp.atomic_exch(ptr, index, warp_type(0))
146
160
  wp.atomic_exch(ptr, index, value)
147
161
  return value
148
162
 
149
- @wp.kernel
150
- def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype, ndim=3)):
163
+ def test_spinlock_counter_3d(counter: wp.array(dtype=warp_type), lock: wp.array3d(dtype=warp_type)):
151
164
  # Try to acquire the lock
152
- spinlock_acquire(lock)
165
+ spinlock_acquire_3d(lock)
153
166
 
154
167
  # Critical section - increment counter
155
168
  # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile
156
169
 
157
170
  # Work around since warp arrays cannot be marked as volatile
158
- value = volatile_read(counter, 0)
159
- counter[0] = value + dtype(1)
171
+ value = volatile_read_3d(counter, 0)
172
+ counter[0] = value + warp_type(1)
160
173
 
161
174
  # Release the lock
162
- spinlock_release(lock)
175
+ spinlock_release_3d(lock)
163
176
 
164
- return test_spinlock_counter
177
+ kernel = getkernel(test_spinlock_counter_3d, suffix=dtype.__name__)
165
178
 
179
+ if register_kernels:
180
+ return
166
181
 
167
- def test_atomic_cas_3d(test, device, warp_type, numpy_type):
168
- n = 100
169
- counter = wp.array([0], dtype=warp_type, device=device)
170
- lock = wp.zeros(shape=(1, 1, 1), dtype=warp_type, device=device)
171
-
172
- test_spinlock_counter = create_spinlock_test_3d(warp_type)
173
- wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)
182
+ wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)
174
183
 
175
184
  # Verify counter reached n
176
185
  counter_np = counter.numpy()
177
- expected = np.array([n], dtype=numpy_type)
186
+ expected = np.array([n], dtype=dtype)
178
187
 
179
188
  if not np.array_equal(counter_np, expected):
180
189
  print(f"Counter mismatch: expected {expected}, got {counter_np}")
@@ -218,17 +227,53 @@ def create_spinlock_test_4d(dtype):
218
227
  return test_spinlock_counter
219
228
 
220
229
 
221
- def test_atomic_cas_4d(test, device, warp_type, numpy_type):
230
+ def test_atomic_cas_4d(test, device, dtype, register_kernels=False):
231
+ warp_type = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
222
232
  n = 100
223
233
  counter = wp.array([0], dtype=warp_type, device=device)
224
234
  lock = wp.zeros(shape=(1, 1, 1, 1), dtype=warp_type, device=device)
225
235
 
226
- test_spinlock_counter = create_spinlock_test_4d(warp_type)
227
- wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)
236
+ @wp.func
237
+ def spinlock_acquire_4d(lock: wp.array4d(dtype=warp_type)):
238
+ # Try to acquire the lock by setting it to 1 if it's 0
239
+ while wp.atomic_cas(lock, 0, 0, 0, 0, warp_type(0), warp_type(1)) == 1:
240
+ pass
241
+
242
+ @wp.func
243
+ def spinlock_release_4d(lock: wp.array4d(dtype=warp_type)):
244
+ # Release the lock by setting it back to 0
245
+ wp.atomic_exch(lock, 0, 0, 0, 0, warp_type(0))
246
+
247
+ @wp.func
248
+ def volatile_read_4d(ptr: wp.array(dtype=warp_type), index: int):
249
+ value = wp.atomic_exch(ptr, index, warp_type(0))
250
+ wp.atomic_exch(ptr, index, value)
251
+ return value
252
+
253
+ def test_spinlock_counter_4d(counter: wp.array(dtype=warp_type), lock: wp.array4d(dtype=warp_type)):
254
+ # Try to acquire the lock
255
+ spinlock_acquire_4d(lock)
256
+
257
+ # Critical section - increment counter
258
+ # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile
259
+
260
+ # Work around since warp arrays cannot be marked as volatile
261
+ value = volatile_read_4d(counter, 0)
262
+ counter[0] = value + warp_type(1)
263
+
264
+ # Release the lock
265
+ spinlock_release_4d(lock)
266
+
267
+ kernel = getkernel(test_spinlock_counter_4d, suffix=dtype.__name__)
268
+
269
+ if register_kernels:
270
+ return
271
+
272
+ wp.launch(kernel, dim=n, inputs=[counter, lock], device=device)
228
273
 
229
274
  # Verify counter reached n
230
275
  counter_np = counter.numpy()
231
- expected = np.array([n], dtype=numpy_type)
276
+ expected = np.array([n], dtype=dtype)
232
277
 
233
278
  if not np.array_equal(counter_np, expected):
234
279
  print(f"Counter mismatch: expected {expected}, got {counter_np}")
@@ -244,54 +289,22 @@ class TestAtomicCAS(unittest.TestCase):
244
289
 
245
290
 
246
291
  # Test all supported types
247
- test_types = [
248
- (wp.int32, np.int32),
249
- (wp.uint32, np.uint32),
250
- (wp.int64, np.int64),
251
- (wp.uint64, np.uint64),
252
- (wp.float32, np.float32),
253
- (wp.float64, np.float64),
254
- ]
255
-
256
- for warp_type, numpy_type in test_types:
257
- type_name = warp_type.__name__
258
- add_function_test(
259
- TestAtomicCAS,
260
- f"test_cas_{type_name}",
261
- test_atomic_cas,
262
- devices=devices,
263
- warp_type=warp_type,
264
- numpy_type=numpy_type,
265
- )
292
+ np_test_types = (np.int32, np.uint32, np.int64, np.uint64, np.float32, np.float64)
266
293
 
294
+ for dtype in np_test_types:
295
+ type_name = dtype.__name__
296
+ add_function_test_register_kernel(
297
+ TestAtomicCAS, f"test_cas_{type_name}", test_atomic_cas, devices=devices, dtype=dtype
298
+ )
267
299
  # Add 2D test for each type
268
- add_function_test(
269
- TestAtomicCAS,
270
- f"test_cas_2d_{type_name}",
271
- test_atomic_cas_2d,
272
- devices=devices,
273
- warp_type=warp_type,
274
- numpy_type=numpy_type,
300
+ add_function_test_register_kernel(
301
+ TestAtomicCAS, f"test_cas_2d_{type_name}", test_atomic_cas_2d, devices=devices, dtype=dtype
275
302
  )
276
-
277
- # Add 3D test for each type
278
- add_function_test(
279
- TestAtomicCAS,
280
- f"test_cas_3d_{type_name}",
281
- test_atomic_cas_3d,
282
- devices=devices,
283
- warp_type=warp_type,
284
- numpy_type=numpy_type,
303
+ add_function_test_register_kernel(
304
+ TestAtomicCAS, f"test_cas_3d_{type_name}", test_atomic_cas_3d, devices=devices, dtype=dtype
285
305
  )
286
-
287
- # Add 4D test for each type
288
- add_function_test(
289
- TestAtomicCAS,
290
- f"test_cas_4d_{type_name}",
291
- test_atomic_cas_4d,
292
- devices=devices,
293
- warp_type=warp_type,
294
- numpy_type=numpy_type,
306
+ add_function_test_register_kernel(
307
+ TestAtomicCAS, f"test_cas_4d_{type_name}", test_atomic_cas_4d, devices=devices, dtype=dtype
295
308
  )
296
309
 
297
310
  if __name__ == "__main__":
@@ -756,6 +756,7 @@ def test_multiple_return_values(test, device):
756
756
  test_multiple_return_values_quat_to_axis_angle_kernel,
757
757
  dim=1,
758
758
  inputs=(q, expected_axis, expected_angle),
759
+ device=device,
759
760
  )
760
761
 
761
762
  # fmt: off
@@ -791,9 +792,9 @@ def test_multiple_return_values(test, device):
791
792
 
792
793
  test.assertAlmostEqual(V[0][0], expected_V[0][0], places=5)
793
794
  test.assertAlmostEqual(V[0][1], expected_V[0][1], places=5)
794
- test.assertAlmostEqual(V[0][2], expected_V[0][2], places=5)
795
+ test.assertAlmostEqual(V[0][2], expected_V[0][2], places=4) # precision issue on ARM64 (GH-905)
795
796
  test.assertAlmostEqual(V[1][0], expected_V[1][0], places=5)
796
- test.assertAlmostEqual(V[1][1], expected_V[1][1], places=5)
797
+ test.assertAlmostEqual(V[1][1], expected_V[1][1], places=4) # precision issue on ARM64 (GH-905)
797
798
  test.assertAlmostEqual(V[1][2], expected_V[1][2], places=5)
798
799
  test.assertAlmostEqual(V[2][0], expected_V[2][0], places=5)
799
800
  test.assertAlmostEqual(V[2][1], expected_V[2][1], places=5)