warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (153) hide show
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,179 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+ TILE_M = wp.constant(8)
25
+ TILE_N = wp.constant(4)
26
+ TILE_K = wp.constant(8)
27
+
28
+ # num threads per-tile
29
+ TILE_DIM = 64
30
+
31
+
32
+ @wp.kernel
33
+ def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
34
+ # output tile index
35
+ i = wp.tid()
36
+
37
+ a = wp.tile_load(A[i], shape=(TILE_M, TILE_K))
38
+ b = wp.tile_load(B[i], shape=(TILE_K, TILE_N))
39
+
40
+ sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
41
+
42
+ wp.tile_matmul(a, b, sum)
43
+
44
+ wp.tile_store(C[i], sum)
45
+
46
+
47
+ def test_tile_grouped_gemm(test, device):
48
+ batch_count = 56
49
+
50
+ M = TILE_M
51
+ N = TILE_N
52
+ K = TILE_K
53
+
54
+ rng = np.random.default_rng(42)
55
+ A = rng.random((batch_count, M, K), dtype=np.float32)
56
+ B = rng.random((batch_count, K, N), dtype=np.float32)
57
+ C = A @ B
58
+
59
+ A_wp = wp.array(A, requires_grad=True, device=device)
60
+ B_wp = wp.array(B, requires_grad=True, device=device)
61
+ C_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)
62
+
63
+ with wp.Tape() as tape:
64
+ wp.launch_tiled(
65
+ tile_grouped_gemm, dim=[batch_count], inputs=[A_wp, B_wp, C_wp], block_dim=TILE_DIM, device=device
66
+ )
67
+
68
+ # TODO: 32 mismatched elements
69
+ assert_np_equal(C_wp.numpy(), C, 1e-6)
70
+
71
+
72
+ @wp.kernel
73
+ def tile_gemm(A: wp.array2d(dtype=Any), B: wp.array2d(dtype=Any), C: wp.array2d(dtype=Any)):
74
+ # output tile index
75
+ i, j = wp.tid()
76
+
77
+ sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=A.dtype)
78
+
79
+ M = A.shape[0]
80
+ N = B.shape[1]
81
+ K = A.shape[1]
82
+
83
+ count = int(K / TILE_K)
84
+
85
+ for k in range(0, count):
86
+ a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
87
+ b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
88
+
89
+ # sum += a*b
90
+ wp.tile_matmul(a, b, sum)
91
+
92
+ wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
93
+
94
+
95
+ wp.overload(
96
+ tile_gemm, {"A": wp.array2d(dtype=wp.float16), "B": wp.array2d(dtype=wp.float16), "C": wp.array2d(dtype=wp.float16)}
97
+ )
98
+ wp.overload(
99
+ tile_gemm, {"A": wp.array2d(dtype=wp.float32), "B": wp.array2d(dtype=wp.float32), "C": wp.array2d(dtype=wp.float32)}
100
+ )
101
+ wp.overload(
102
+ tile_gemm, {"A": wp.array2d(dtype=wp.float64), "B": wp.array2d(dtype=wp.float64), "C": wp.array2d(dtype=wp.float64)}
103
+ )
104
+
105
+
106
+ def test_tile_gemm(dtype):
107
+ def test(test, device):
108
+ M = TILE_M * 7
109
+ K = TILE_K * 6
110
+ N = TILE_N * 5
111
+
112
+ rng = np.random.default_rng(42)
113
+ A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
114
+ B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
115
+ C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
116
+
117
+ A_wp = wp.array(A, requires_grad=True, device=device)
118
+ B_wp = wp.array(B, requires_grad=True, device=device)
119
+ C_wp = wp.array(C, requires_grad=True, device=device)
120
+
121
+ with wp.Tape() as tape:
122
+ wp.launch_tiled(
123
+ tile_gemm,
124
+ dim=(int(M / TILE_M), int(N / TILE_N)),
125
+ inputs=[A_wp, B_wp, C_wp],
126
+ block_dim=TILE_DIM,
127
+ device=device,
128
+ )
129
+
130
+ assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)
131
+
132
+ adj_C = np.ones_like(C)
133
+
134
+ tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
135
+
136
+ assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
137
+ assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
138
+
139
+ return test
140
+
141
+
142
+ @wp.kernel
143
+ def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
144
+ x = wp.tile_load(input, shape=(TILE_M, TILE_N))
145
+ y = wp.tile_transpose(x)
146
+
147
+ z = wp.tile_zeros(dtype=float, shape=(TILE_N, TILE_N))
148
+ wp.tile_matmul(y, x, z)
149
+
150
+ wp.tile_store(output, z)
151
+
152
+
153
+ def test_tile_transpose_matmul(test, device):
154
+ rng = np.random.default_rng(42)
155
+ input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
156
+ output = wp.zeros((TILE_N, TILE_N), dtype=float, device=device)
157
+
158
+ wp.launch_tiled(
159
+ test_tile_transpose_matmul_kernel, dim=[1], inputs=[input, output], block_dim=TILE_DIM, device=device
160
+ )
161
+
162
+ assert_np_equal(output.numpy(), input.numpy().T @ input.numpy(), 1e-6)
163
+
164
+
165
+ class TestTileMatmul(unittest.TestCase):
166
+ pass
167
+
168
+
169
+ devices = get_test_devices()
170
+
171
+ add_function_test(TestTileMatmul, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
172
+ add_function_test(TestTileMatmul, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
173
+ add_function_test(TestTileMatmul, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
174
+ add_function_test(TestTileMatmul, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
175
+ add_function_test(TestTileMatmul, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
176
+
177
+ if __name__ == "__main__":
178
+ wp.clear_kernel_cache()
179
+ unittest.main(verbosity=2, failfast=True)
@@ -43,7 +43,7 @@ def create_array(rng, dim_in, dim_hid, dtype=float):
43
43
  def test_multi_layer_nn(test, device):
44
44
  import torch as tc
45
45
 
46
- if device.is_cuda and not wp.context.runtime.core.is_mathdx_enabled():
46
+ if device.is_cuda and not wp.context.runtime.core.wp_is_mathdx_enabled():
47
47
  test.skipTest("Skipping test on CUDA device without MathDx (tolerance)")
48
48
 
49
49
  NUM_FREQ = wp.constant(8)
@@ -73,6 +73,46 @@ def test_tile_reduce_sum(test, device):
73
73
  assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.5, tol=1.0e-4)
74
74
 
75
75
 
76
+ @wp.kernel
77
+ def tile_sum_to_shared_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
78
+ i, lane = wp.tid()
79
+
80
+ a = wp.tile_load(input[i], shape=TILE_DIM)
81
+ s = wp.tile_sum(a)
82
+ v = s[0] # force shared storage for s
83
+ wp.tile_store(output, s * 0.5, offset=i)
84
+
85
+
86
+ def test_tile_sum_to_shared(test, device):
87
+ batch_count = 1
88
+
89
+ rng = np.random.default_rng(42)
90
+ input = rng.random((batch_count, TILE_DIM), dtype=np.float32)
91
+
92
+ input_wp = wp.array(input, requires_grad=True, device=device, dtype=float)
93
+ output_wp = wp.zeros(batch_count, requires_grad=True, device=device, dtype=float)
94
+
95
+ with wp.Tape() as tape:
96
+ wp.launch_tiled(
97
+ tile_sum_to_shared_kernel,
98
+ dim=[batch_count],
99
+ inputs=[input_wp, output_wp],
100
+ block_dim=TILE_DIM,
101
+ device=device,
102
+ )
103
+
104
+ sum_wp = output_wp.numpy()
105
+ for i in range(batch_count):
106
+ sum_np = np.sum(input[i], axis=0) * 0.5
107
+ assert_np_equal(sum_wp[i], sum_np, tol=0.0001)
108
+
109
+ output_wp.grad.fill_(1.0)
110
+
111
+ tape.backward()
112
+
113
+ assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.5, tol=1.0e-4)
114
+
115
+
76
116
  @wp.kernel
77
117
  def tile_min_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
78
118
  # output tile index
@@ -84,6 +124,13 @@ def tile_min_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float
84
124
  wp.tile_store(output, m, offset=i)
85
125
 
86
126
 
127
+ @wp.kernel
128
+ def tile_min_kernel_edge_case(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
129
+ t = wp.tile_load(x, shape=(3, 3))
130
+ min = wp.tile_min(t)
131
+ wp.tile_store(y, min)
132
+
133
+
87
134
  def test_tile_reduce_min(test, device):
88
135
  batch_count = 56
89
136
 
@@ -105,6 +152,14 @@ def test_tile_reduce_min(test, device):
105
152
  min_np = np.min(input[i])
106
153
  test.assertAlmostEqual(min_wp[i], min_np, places=4)
107
154
 
155
+ # test edge case: tile is multiple warps in size but at least one is empty
156
+ x = wp.array(np.array([[2.0, 2.0, 3.0], [4.0, 1.0, 6.0], [7.0, 3.0, 9.0]]), dtype=float, device=device)
157
+ y = wp.zeros(1, dtype=float, device=device)
158
+
159
+ wp.launch_tiled(tile_min_kernel_edge_case, dim=1, inputs=[x, y], block_dim=64, device=device)
160
+
161
+ assert_np_equal(y.numpy(), np.array([1.0]))
162
+
108
163
 
109
164
  @wp.kernel
110
165
  def tile_argmin_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=int)):
@@ -117,6 +172,13 @@ def tile_argmin_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=in
117
172
  wp.tile_store(output, m, offset=i)
118
173
 
119
174
 
175
+ @wp.kernel
176
+ def tile_argmin_kernel_edge_case(x: wp.array2d(dtype=float), y: wp.array(dtype=int)):
177
+ t = wp.tile_load(x, shape=(3, 3))
178
+ min = wp.tile_argmin(t)
179
+ wp.tile_store(y, min)
180
+
181
+
120
182
  def test_tile_reduce_argmin(test, device):
121
183
  batch_count = 56
122
184
 
@@ -138,6 +200,14 @@ def test_tile_reduce_argmin(test, device):
138
200
  argmin_np = np.argmin(input[i])
139
201
  test.assertAlmostEqual(argmin_wp[i], argmin_np, places=4)
140
202
 
203
+ # test edge case: tile is multiple warps in size but at least one is empty
204
+ x = wp.array(np.array([[2.0, 2.0, 3.0], [4.0, 1.0, 6.0], [7.0, 3.0, 9.0]]), dtype=float, device=device)
205
+ y = wp.zeros(1, dtype=int, device=device)
206
+
207
+ wp.launch_tiled(tile_argmin_kernel_edge_case, dim=1, inputs=[x, y], block_dim=64, device=device)
208
+
209
+ assert_np_equal(y.numpy(), np.array([4]))
210
+
141
211
 
142
212
  @wp.kernel
143
213
  def tile_max_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
@@ -243,7 +313,7 @@ def test_tile_reduce_custom(test, device):
243
313
 
244
314
 
245
315
  def create_tile_scan_inclusive_kernel(tile_dim: int):
246
- @wp.kernel
316
+ @wp.kernel(module="unique")
247
317
  def tile_scan_inclusive_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
248
318
  i = wp.tid()
249
319
  t = wp.tile_load(input[i], shape=tile_dim)
@@ -279,7 +349,7 @@ def test_tile_scan_inclusive(test, device):
279
349
 
280
350
 
281
351
  def create_tile_scan_exclusive_kernel(tile_dim: int):
282
- @wp.kernel
352
+ @wp.kernel(module="unique")
283
353
  def tile_scan_exclusive_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
284
354
  i = wp.tid()
285
355
  t = wp.tile_load(input[i], shape=tile_dim)
@@ -398,7 +468,7 @@ def test_tile_reduce_grouped_sum(test, device):
398
468
 
399
469
  with wp.Tape() as tape:
400
470
  wp.launch_tiled(
401
- tile_sum_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
471
+ tile_grouped_sum_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
402
472
  )
403
473
 
404
474
  sum_wp = output_wp.numpy()
@@ -498,17 +568,17 @@ def test_untile_vector_kernel(input: wp.array(dtype=wp.vec3), output: wp.array(d
498
568
 
499
569
 
500
570
  def test_tile_untile_vector(test, device):
501
- input = wp.full(16, wp.vec3(1.0, 2.0, 3.0), requires_grad=True, device=device)
571
+ input = wp.full(TILE_DIM, wp.vec3(1.0, 2.0, 3.0), requires_grad=True, device=device)
502
572
  output = wp.zeros_like(input, device=device)
503
573
 
504
574
  with wp.Tape() as tape:
505
- wp.launch(test_untile_vector_kernel, dim=16, inputs=[input, output], block_dim=16, device=device)
575
+ wp.launch(test_untile_vector_kernel, dim=TILE_DIM, inputs=[input, output], block_dim=TILE_DIM, device=device)
506
576
 
507
577
  output.grad = wp.ones_like(output, device=device)
508
578
  tape.backward()
509
579
 
510
580
  assert_np_equal(output.numpy(), input.numpy())
511
- assert_np_equal(input.grad.numpy(), np.ones((16, 3)))
581
+ assert_np_equal(input.grad.numpy(), np.ones((TILE_DIM, 3)))
512
582
 
513
583
 
514
584
  @wp.kernel
@@ -562,7 +632,7 @@ def test_tile_arange(test, device):
562
632
  assert_np_equal(output.numpy()[4], np.arange(17, 0, -1))
563
633
 
564
634
 
565
- @wp.kernel
635
+ @wp.kernel(module="unique")
566
636
  def tile_strided_loop_kernel(arr: wp.array(dtype=float), max_val: wp.array(dtype=float)):
567
637
  tid, lane = wp.tid()
568
638
 
@@ -618,7 +688,7 @@ def test_tile_reduce_matrix_kernel(y: wp.array(dtype=wp.mat33)):
618
688
  t = wp.tile(m, preserve_type=True)
619
689
  sum = wp.tile_reduce(wp.add, t)
620
690
 
621
- wp.tile_store(y, sum)
691
+ wp.tile_atomic_add(y, sum)
622
692
 
623
693
 
624
694
  def test_tile_reduce_matrix(test, device):
@@ -629,8 +699,25 @@ def test_tile_reduce_matrix(test, device):
629
699
  assert_np_equal(y.numpy().squeeze(), np.eye(3, dtype=np.float32) * 2016.0)
630
700
 
631
701
 
702
+ @wp.kernel
703
+ def test_tile_reduce_vector_kernel(out: wp.array(dtype=wp.vec3)):
704
+ v = wp.vec3f(1.0)
705
+ v_tile = wp.tile(v, preserve_type=True)
706
+
707
+ sum = wp.tile_reduce(wp.add, v_tile)
708
+
709
+ wp.tile_atomic_add(out, sum)
710
+
711
+
712
+ def test_tile_reduce_vector(test, device):
713
+ out = wp.zeros(1, dtype=wp.vec3, device=device)
714
+
715
+ wp.launch(kernel=test_tile_reduce_vector_kernel, dim=8, inputs=[], outputs=[out], block_dim=TILE_DIM, device=device)
716
+
717
+ assert_np_equal(out.numpy(), np.array([[8.0, 8.0, 8.0]]))
718
+
719
+
632
720
  devices = get_test_devices()
633
- cuda_devices = get_cuda_test_devices()
634
721
 
635
722
 
636
723
  class TestTileReduce(unittest.TestCase):
@@ -638,13 +725,14 @@ class TestTileReduce(unittest.TestCase):
638
725
 
639
726
 
640
727
  add_function_test(TestTileReduce, "test_tile_reduce_sum", test_tile_reduce_sum, devices=devices)
728
+ add_function_test(TestTileReduce, "test_tile_sum_to_shared", test_tile_sum_to_shared, devices=devices)
641
729
  add_function_test(TestTileReduce, "test_tile_reduce_min", test_tile_reduce_min, devices=devices)
642
730
  add_function_test(TestTileReduce, "test_tile_reduce_max", test_tile_reduce_max, devices=devices)
643
731
  add_function_test(TestTileReduce, "test_tile_reduce_argmin", test_tile_reduce_argmin, devices=devices)
644
732
  add_function_test(TestTileReduce, "test_tile_reduce_argmax", test_tile_reduce_argmax, devices=devices)
645
733
  add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_custom, devices=devices)
646
734
  add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
647
- add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_sum, devices=devices)
735
+ add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_grouped_sum, devices=devices)
648
736
  add_function_test(TestTileReduce, "test_tile_reduce_simt", test_tile_reduce_simt, devices=devices)
649
737
  add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devices)
650
738
  add_function_test(TestTileReduce, "test_tile_arange", test_tile_arange, devices=devices)
@@ -653,7 +741,8 @@ add_function_test(TestTileReduce, "test_tile_untile_vector", test_tile_untile_ve
653
741
  add_function_test(TestTileReduce, "test_tile_strided_loop", test_tile_strided_loop, devices=devices)
654
742
  add_function_test(TestTileReduce, "test_tile_scan_inclusive", test_tile_scan_inclusive, devices=devices)
655
743
  add_function_test(TestTileReduce, "test_tile_scan_exclusive", test_tile_scan_exclusive, devices=devices)
656
- add_function_test(TestTileReduce, "test_tile_reduce_matrix", test_tile_reduce_matrix, devices=cuda_devices)
744
+ add_function_test(TestTileReduce, "test_tile_reduce_matrix", test_tile_reduce_matrix, devices=devices)
745
+ add_function_test(TestTileReduce, "test_tile_reduce_vector", test_tile_reduce_vector, devices=devices)
657
746
 
658
747
  if __name__ == "__main__":
659
748
  wp.clear_kernel_cache()
@@ -28,7 +28,7 @@ def test_tile_shared_mem_size(test, device):
28
28
 
29
29
  BLOCK_DIM = 256
30
30
 
31
- @wp.kernel
31
+ @wp.kernel(module="unique")
32
32
  def compute(out: wp.array2d(dtype=float)):
33
33
  a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
34
34
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
@@ -64,7 +64,7 @@ def test_tile_shared_mem_large(test, device):
64
64
  BLOCK_DIM = 256
65
65
 
66
66
  # we disable backward kernel gen since 128k is not supported on most architectures
67
- @wp.kernel(enable_backward=False)
67
+ @wp.kernel(enable_backward=False, module="unique")
68
68
  def compute(out: wp.array2d(dtype=float)):
69
69
  a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
70
70
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
@@ -100,7 +100,7 @@ def test_tile_shared_mem_graph(test, device):
100
100
 
101
101
  BLOCK_DIM = 256
102
102
 
103
- @wp.kernel
103
+ @wp.kernel(module="unique")
104
104
  def compute(out: wp.array2d(dtype=float)):
105
105
  a = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared")
106
106
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 2.0
@@ -110,13 +110,13 @@ def test_tile_shared_mem_graph(test, device):
110
110
 
111
111
  out = wp.empty((DIM_M, DIM_N), dtype=float, device=device)
112
112
 
113
- wp.load_module(device=device)
113
+ # preload the unique module
114
+ wp.load_module(compute.module, device=device, block_dim=BLOCK_DIM)
114
115
 
115
- wp.capture_begin(device, force_module_load=False)
116
- wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)
117
- graph = wp.capture_end(device)
116
+ with wp.ScopedCapture(device, force_module_load=False) as capture:
117
+ wp.launch_tiled(compute, dim=[1], inputs=[out], block_dim=BLOCK_DIM, device=device)
118
118
 
119
- wp.capture_launch(graph)
119
+ wp.capture_launch(capture.graph)
120
120
 
121
121
  # check output
122
122
  assert_np_equal(out.numpy(), np.ones((DIM_M, DIM_N)) * 3.0)
@@ -157,7 +157,7 @@ def test_tile_shared_mem_func(test, device):
157
157
 
158
158
  return a + b
159
159
 
160
- @wp.kernel
160
+ @wp.kernel(module="unique")
161
161
  def compute(out: wp.array2d(dtype=float)):
162
162
  s = add_tile_small()
163
163
  b = add_tile_big()
@@ -197,7 +197,7 @@ def test_tile_shared_non_aligned(test, device):
197
197
  b = wp.tile_ones(shape=(DIM_M, DIM_N), dtype=float, storage="shared") * 3.0
198
198
  return a + b
199
199
 
200
- @wp.kernel
200
+ @wp.kernel(module="unique")
201
201
  def compute(out: wp.array2d(dtype=float)):
202
202
  # This test the logic in the stack allocator, which should increment and
203
203
  # decrement the stack pointer each time foo() is called
@@ -225,9 +225,9 @@ def test_tile_shared_non_aligned(test, device):
225
225
 
226
226
 
227
227
  def test_tile_shared_vec_accumulation(test, device):
228
- BLOCK_DIM = 64
228
+ BLOCK_DIM = 256
229
229
 
230
- @wp.kernel
230
+ @wp.kernel(module="unique")
231
231
  def compute(indices: wp.array(dtype=int), vecs: wp.array(dtype=wp.vec3), output: wp.array2d(dtype=float)):
232
232
  i, j = wp.tid()
233
233
 
@@ -286,9 +286,9 @@ def test_tile_shared_vec_accumulation(test, device):
286
286
 
287
287
 
288
288
  def test_tile_shared_simple_reduction_add(test, device):
289
- BLOCK_DIM = 64
289
+ BLOCK_DIM = 256
290
290
 
291
- @wp.kernel
291
+ @wp.kernel(module="unique")
292
292
  def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
293
293
  i, j = wp.tid()
294
294
 
@@ -313,9 +313,9 @@ def test_tile_shared_simple_reduction_add(test, device):
313
313
 
314
314
 
315
315
  def test_tile_shared_simple_reduction_sub(test, device):
316
- BLOCK_DIM = 64
316
+ BLOCK_DIM = 256
317
317
 
318
- @wp.kernel
318
+ @wp.kernel(module="unique")
319
319
  def compute(x: wp.array(dtype=float), y: wp.array(dtype=float)):
320
320
  i, j = wp.tid()
321
321
 
@@ -44,63 +44,67 @@ def create_sort_kernel(KEY_TYPE, MAX_SORT_LENGTH):
44
44
 
45
45
 
46
46
  def test_tile_sort(test, device):
47
- for dtype in [int, float]: # Loop over int and float keys
47
+ # Forward-declare kernels for more efficient compilation
48
+ kernels = {}
49
+ for dtype in [int, float]:
50
+ for i in range(0, 11):
51
+ length = 2**i + 1
52
+ kernels[(dtype, length)] = create_sort_kernel(dtype, length)
53
+
54
+ for (dtype, length), kernel in kernels.items():
48
55
  for j in range(5, 10):
49
56
  TILE_DIM = 2**j
50
- for i in range(0, 11): # Start from 1 to avoid zero-length cases
51
- length = 2**i + 1
52
-
53
- rng = np.random.default_rng(42) # Create a random generator instance
54
-
55
- if dtype == int:
56
- np_keys = rng.choice(1000000000, size=length, replace=False)
57
- else: # dtype == float
58
- np_keys = rng.uniform(0, 1000000000, size=length)
59
-
60
- np_values = np.arange(length)
61
-
62
- # Generate random keys and iota indexer
63
- input_keys = wp.array(np_keys, dtype=dtype, device=device)
64
- input_values = wp.array(np_values, dtype=int, device=device)
65
- output_keys = wp.zeros_like(input_keys, device=device)
66
- output_values = wp.zeros_like(input_values, device=device)
67
-
68
- # Execute sorting kernel
69
- kernel = create_sort_kernel(dtype, length)
70
- wp.launch_tiled(
71
- kernel,
72
- dim=1,
73
- inputs=[input_keys, input_values, output_keys, output_values],
74
- block_dim=TILE_DIM,
75
- device=device,
76
- )
77
- wp.synchronize()
78
-
79
- # Sort using NumPy for validation
80
- sorted_indices = np.argsort(np_keys)
81
- np_sorted_keys = np_keys[sorted_indices]
82
- np_sorted_values = np_values[sorted_indices]
83
-
84
- if dtype == int:
85
- keys_match = np.array_equal(output_keys.numpy(), np_sorted_keys)
86
- else: # dtype == float
87
- keys_match = np.allclose(output_keys.numpy(), np_sorted_keys, atol=1e-6) # Use tolerance for floats
88
-
89
- values_match = np.array_equal(output_values.numpy(), np_sorted_values)
90
-
91
- if not keys_match or not values_match:
92
- print(f"Test failed for dtype={dtype}, TILE_DIM={TILE_DIM}, length={length}")
93
- print("")
94
- print(output_keys.numpy())
95
- print(np_sorted_keys)
96
- print("")
97
- print(output_values.numpy())
98
- print(np_sorted_values)
99
- print("")
100
-
101
- # Validate results
102
- assert keys_match, f"Key sorting mismatch for dtype={dtype}!"
103
- assert values_match, f"Value sorting mismatch for dtype={dtype}!"
57
+
58
+ rng = np.random.default_rng(42) # Create a random generator instance
59
+
60
+ if dtype == int:
61
+ np_keys = rng.choice(1000000000, size=length, replace=False)
62
+ else: # dtype == float
63
+ np_keys = rng.uniform(0, 1000000000, size=length).astype(dtype)
64
+
65
+ np_values = np.arange(length)
66
+
67
+ # Generate random keys and iota indexer
68
+ input_keys = wp.array(np_keys, dtype=dtype, device=device)
69
+ input_values = wp.array(np_values, dtype=int, device=device)
70
+ output_keys = wp.zeros_like(input_keys, device=device)
71
+ output_values = wp.zeros_like(input_values, device=device)
72
+
73
+ # Execute sorting kernel
74
+ wp.launch_tiled(
75
+ kernel,
76
+ dim=1,
77
+ inputs=[input_keys, input_values, output_keys, output_values],
78
+ block_dim=TILE_DIM,
79
+ device=device,
80
+ )
81
+ wp.synchronize()
82
+
83
+ # Sort using NumPy for validation
84
+ sorted_indices = np.argsort(np_keys)
85
+ np_sorted_keys = np_keys[sorted_indices]
86
+ np_sorted_values = np_values[sorted_indices]
87
+
88
+ if dtype == int:
89
+ keys_match = np.array_equal(output_keys.numpy(), np_sorted_keys)
90
+ else: # dtype == float
91
+ keys_match = np.allclose(output_keys.numpy(), np_sorted_keys, atol=1e-6) # Use tolerance for floats
92
+
93
+ values_match = np.array_equal(output_values.numpy(), np_sorted_values)
94
+
95
+ if not keys_match or not values_match:
96
+ print(f"Test failed for dtype={dtype}, TILE_DIM={TILE_DIM}, length={length}")
97
+ print("")
98
+ print(output_keys.numpy())
99
+ print(np_sorted_keys)
100
+ print("")
101
+ print(output_values.numpy())
102
+ print(np_sorted_values)
103
+ print("")
104
+
105
+ # Validate results
106
+ test.assertTrue(keys_match, f"Key sorting mismatch for dtype={dtype}!")
107
+ test.assertTrue(values_match, f"Value sorting mismatch for dtype={dtype}!")
104
108
 
105
109
 
106
110
  devices = get_test_devices()