warp-lang 1.8.0-py3-none-manylinux_2_34_aarch64.whl → 1.8.1-py3-none-manylinux_2_34_aarch64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.


Files changed (59)
  1. warp/bin/warp-clang.so +0 -0
  2. warp/bin/warp.so +0 -0
  3. warp/build_dll.py +5 -0
  4. warp/codegen.py +15 -3
  5. warp/config.py +1 -1
  6. warp/context.py +122 -24
  7. warp/examples/interop/example_jax_callable.py +34 -4
  8. warp/examples/interop/example_jax_kernel.py +27 -1
  9. warp/fem/field/virtual.py +2 -0
  10. warp/fem/integrate.py +78 -47
  11. warp/jax_experimental/ffi.py +201 -53
  12. warp/native/array.h +4 -4
  13. warp/native/builtin.h +8 -4
  14. warp/native/coloring.cpp +5 -1
  15. warp/native/cuda_util.cpp +1 -1
  16. warp/native/intersect.h +2 -2
  17. warp/native/mat.h +3 -3
  18. warp/native/mesh.h +1 -1
  19. warp/native/quat.h +6 -2
  20. warp/native/rand.h +7 -7
  21. warp/native/sparse.cu +1 -1
  22. warp/native/svd.h +23 -8
  23. warp/native/tile.h +20 -1
  24. warp/native/tile_radix_sort.h +5 -1
  25. warp/native/tile_reduce.h +16 -25
  26. warp/native/tuple.h +2 -2
  27. warp/native/vec.h +4 -4
  28. warp/native/warp.cpp +1 -1
  29. warp/native/warp.cu +15 -2
  30. warp/native/warp.h +1 -1
  31. warp/render/render_opengl.py +52 -51
  32. warp/render/render_usd.py +0 -1
  33. warp/sim/collide.py +1 -2
  34. warp/sim/integrator_vbd.py +10 -2
  35. warp/sparse.py +1 -1
  36. warp/tape.py +2 -0
  37. warp/tests/sim/test_cloth.py +89 -6
  38. warp/tests/sim/test_coloring.py +76 -1
  39. warp/tests/test_assert.py +53 -0
  40. warp/tests/test_atomic_cas.py +127 -114
  41. warp/tests/test_mat.py +22 -0
  42. warp/tests/test_quat.py +22 -0
  43. warp/tests/test_sparse.py +32 -0
  44. warp/tests/test_static.py +48 -0
  45. warp/tests/test_tape.py +38 -0
  46. warp/tests/test_vec.py +38 -408
  47. warp/tests/test_vec_constructors.py +325 -0
  48. warp/tests/tile/test_tile.py +31 -143
  49. warp/tests/tile/test_tile_mathdx.py +2 -2
  50. warp/tests/tile/test_tile_matmul.py +179 -0
  51. warp/tests/tile/test_tile_reduce.py +100 -11
  52. warp/tests/tile/test_tile_shared_memory.py +12 -12
  53. warp/tests/tile/test_tile_sort.py +59 -55
  54. warp/tests/unittest_suites.py +10 -0
  55. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/METADATA +4 -4
  56. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/RECORD +59 -57
  57. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  58. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  59. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/tests/test_vec_constructors.py (new file):

@@ -0,0 +1,325 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import unittest
+
+ import numpy as np
+
+ import warp as wp
+ from warp.tests.unittest_utils import *
+
+ np_float_types = [np.float16, np.float32, np.float64]
+
+ kernel_cache = {}
+
+
+ def getkernel(func, suffix=""):
+     key = func.__name__ + "_" + suffix
+     if key not in kernel_cache:
+         kernel_cache[key] = wp.Kernel(func=func, key=key)
+     return kernel_cache[key]
+
+
+ def test_anon_constructor_error_length_mismatch(test, device):
+     @wp.kernel
+     def kernel():
+         wp.vector(wp.vector(length=2, dtype=float), length=3, dtype=float)
+
+     with test.assertRaisesRegex(
+         RuntimeError,
+         r"incompatible vector of length 3 given when copy constructing a vector of length 2$",
+     ):
+         wp.launch(kernel, dim=1, inputs=[], device=device)
+
+
+ def test_anon_constructor_error_numeric_arg_missing(test, device):
+     @wp.kernel
+     def kernel():
+         wp.vector(1.0, 2.0, length=12345)
+
+     with test.assertRaisesRegex(
+         RuntimeError,
+         r"incompatible number of values given \(2\) when constructing a vector of length 12345$",
+     ):
+         wp.launch(kernel, dim=1, inputs=[], device=device)
+
+
+ def test_anon_constructor_error_length_arg_missing(test, device):
+     @wp.kernel
+     def kernel():
+         wp.vector()
+
+     with test.assertRaisesRegex(
+         RuntimeError,
+         r"the `length` argument must be specified when zero-initializing a vector$",
+     ):
+         wp.launch(kernel, dim=1, inputs=[], device=device)
+
+
+ def test_anon_constructor_error_numeric_args_mismatch(test, device):
+     @wp.kernel
+     def kernel():
+         wp.vector(1.0, 2)
+
+     with test.assertRaisesRegex(
+         RuntimeError,
+         r"all values given when constructing a vector must have the same type$",
+     ):
+         wp.launch(kernel, dim=1, inputs=[], device=device)
+
+
+ def test_tpl_constructor_error_incompatible_sizes(test, device):
+     @wp.kernel
+     def kernel():
+         wp.vec3(wp.vec2(1.0, 2.0))
+
+     with test.assertRaisesRegex(
+         RuntimeError, "incompatible vector of length 3 given when copy constructing a vector of length 2"
+     ):
+         wp.launch(kernel, dim=1, inputs=[], device=device)
+
+
+ def test_tpl_constructor_error_numeric_args_mismatch(test, device):
+     @wp.kernel
+     def kernel():
+         wp.vec2(1.0, 2)
+
+     with test.assertRaisesRegex(
+         RuntimeError,
+         r"all values given when constructing a vector must have the same type$",
+     ):
+         wp.launch(kernel, dim=1, inputs=[], device=device)
+
+
+ def test_casting_constructors(test, device, dtype, register_kernels=False):
+     np_type = np.dtype(dtype)
+     wp_type = wp.types.np_dtype_to_warp_type[np_type]
+     vec3 = wp.types.vector(length=3, dtype=wp_type)
+
+     np16 = np.dtype(np.float16)
+     wp16 = wp.types.np_dtype_to_warp_type[np16]
+
+     np32 = np.dtype(np.float32)
+     wp32 = wp.types.np_dtype_to_warp_type[np32]
+
+     np64 = np.dtype(np.float64)
+     wp64 = wp.types.np_dtype_to_warp_type[np64]
+
+     def cast_float16(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp16, ndim=2)):
+         tid = wp.tid()
+
+         v1 = vec3(a[tid, 0], a[tid, 1], a[tid, 2])
+         v2 = wp.vector(v1, dtype=wp16)
+
+         b[tid, 0] = v2[0]
+         b[tid, 1] = v2[1]
+         b[tid, 2] = v2[2]
+
+     def cast_float32(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp32, ndim=2)):
+         tid = wp.tid()
+
+         v1 = vec3(a[tid, 0], a[tid, 1], a[tid, 2])
+         v2 = wp.vector(v1, dtype=wp32)
+
+         b[tid, 0] = v2[0]
+         b[tid, 1] = v2[1]
+         b[tid, 2] = v2[2]
+
+     def cast_float64(a: wp.array(dtype=wp_type, ndim=2), b: wp.array(dtype=wp64, ndim=2)):
+         tid = wp.tid()
+
+         v1 = vec3(a[tid, 0], a[tid, 1], a[tid, 2])
+         v2 = wp.vector(v1, dtype=wp64)
+
+         b[tid, 0] = v2[0]
+         b[tid, 1] = v2[1]
+         b[tid, 2] = v2[2]
+
+     kernel_16 = getkernel(cast_float16, suffix=dtype.__name__)
+     kernel_32 = getkernel(cast_float32, suffix=dtype.__name__)
+     kernel_64 = getkernel(cast_float64, suffix=dtype.__name__)
+
+     if register_kernels:
+         return
+
+     # check casting to float16
+     a = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
+     b = wp.array(np.zeros((1, 3), dtype=np16), dtype=wp16, requires_grad=True, device=device)
+     b_result = np.ones((1, 3), dtype=np16)
+     b_grad = wp.array(np.ones((1, 3), dtype=np16), dtype=wp16, device=device)
+     a_grad = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, device=device)
+
+     tape = wp.Tape()
+     with tape:
+         wp.launch(kernel=kernel_16, dim=1, inputs=[a, b], device=device)
+
+     tape.backward(grads={b: b_grad})
+     out = tape.gradients[a].numpy()
+
+     assert_np_equal(b.numpy(), b_result)
+     assert_np_equal(out, a_grad.numpy())
+
+     # check casting to float32
+     a = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
+     b = wp.array(np.zeros((1, 3), dtype=np32), dtype=wp32, requires_grad=True, device=device)
+     b_result = np.ones((1, 3), dtype=np32)
+     b_grad = wp.array(np.ones((1, 3), dtype=np32), dtype=wp32, device=device)
+     a_grad = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, device=device)
+
+     tape = wp.Tape()
+     with tape:
+         wp.launch(kernel=kernel_32, dim=1, inputs=[a, b], device=device)
+
+     tape.backward(grads={b: b_grad})
+     out = tape.gradients[a].numpy()
+
+     assert_np_equal(b.numpy(), b_result)
+     assert_np_equal(out, a_grad.numpy())
+
+     # check casting to float64
+     a = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, requires_grad=True, device=device)
+     b = wp.array(np.zeros((1, 3), dtype=np64), dtype=wp64, requires_grad=True, device=device)
+     b_result = np.ones((1, 3), dtype=np64)
+     b_grad = wp.array(np.ones((1, 3), dtype=np64), dtype=wp64, device=device)
+     a_grad = wp.array(np.ones((1, 3), dtype=np_type), dtype=wp_type, device=device)
+
+     tape = wp.Tape()
+     with tape:
+         wp.launch(kernel=kernel_64, dim=1, inputs=[a, b], device=device)
+
+     tape.backward(grads={b: b_grad})
+     out = tape.gradients[a].numpy()
+
+     assert_np_equal(b.numpy(), b_result)
+     assert_np_equal(out, a_grad.numpy())
+
+
+ @wp.kernel
+ def test_vector_constructors_value_func():
+     a = wp.vec2()
+     b = wp.vector(a, dtype=wp.float16)
+     c = wp.vector(a)
+     d = wp.vector(a, length=2)
+     e = wp.vector(1.0, 2.0, 3.0, dtype=float)
+
+
+ # Test vector constructors using an explicit type (float16).
+ # Note that these tests deliberately do not use generics / closure
+ # args to create kernels dynamically (unlike the rest of this file),
+ # as those use different code paths to resolve arg types, which
+ # has led to regressions.
+ @wp.kernel
+ def test_vector_constructors_explicit_precision():
+     # construction for custom vector types
+     ones = wp.vector(wp.float16(1.0), length=2)
+     zeros = wp.vector(length=2, dtype=wp.float16)
+     custom = wp.vector(wp.float16(0.0), wp.float16(1.0))
+
+     for i in range(2):
+         wp.expect_eq(ones[i], wp.float16(1.0))
+         wp.expect_eq(zeros[i], wp.float16(0.0))
+         wp.expect_eq(custom[i], wp.float16(i))
+
+
+ # Same as above but with a default (float/int) type,
+ # which tests some different code paths that
+ # need to ensure types are correctly canonicalized
+ # during codegen.
+ @wp.kernel
+ def test_vector_constructors_default_precision():
+     # construction for custom vector types
+     ones = wp.vector(1.0, length=2)
+     zeros = wp.vector(length=2, dtype=float)
+     custom = wp.vector(0.0, 1.0)
+
+     for i in range(2):
+         wp.expect_eq(ones[i], 1.0)
+         wp.expect_eq(zeros[i], 0.0)
+         wp.expect_eq(custom[i], float(i))
+
+
+ CONSTANT_LENGTH = wp.constant(10)
+
+
+ # tests that we can use global constants in the `length` keyword argument
+ # of the vector constructor
+ @wp.kernel
+ def test_vector_constructors_constant_length():
+     v = wp.vector(length=(CONSTANT_LENGTH), dtype=float)
+
+     for i in range(CONSTANT_LENGTH):
+         v[i] = float(i)
+
+
+ devices = get_test_devices()
+
+
+ class TestVecConstructors(unittest.TestCase):
+     pass
+
+
+ add_function_test(
+     TestVecConstructors,
+     "test_anon_constructor_error_length_mismatch",
+     test_anon_constructor_error_length_mismatch,
+     devices=devices,
+ )
+ add_function_test(
+     TestVecConstructors,
+     "test_anon_constructor_error_numeric_arg_missing",
+     test_anon_constructor_error_numeric_arg_missing,
+     devices=devices,
+ )
+ add_function_test(
+     TestVecConstructors,
+     "test_anon_constructor_error_length_arg_missing",
+     test_anon_constructor_error_length_arg_missing,
+     devices=devices,
+ )
+ add_function_test(
+     TestVecConstructors,
+     "test_anon_constructor_error_numeric_args_mismatch",
+     test_anon_constructor_error_numeric_args_mismatch,
+     devices=devices,
+ )
+ add_function_test(
+     TestVecConstructors,
+     "test_tpl_constructor_error_incompatible_sizes",
+     test_tpl_constructor_error_incompatible_sizes,
+     devices=devices,
+ )
+ add_function_test(
+     TestVecConstructors,
+     "test_tpl_constructor_error_numeric_args_mismatch",
+     test_tpl_constructor_error_numeric_args_mismatch,
+     devices=devices,
+ )
+ add_kernel_test(TestVecConstructors, test_vector_constructors_value_func, dim=1, devices=devices)
+ add_kernel_test(TestVecConstructors, test_vector_constructors_explicit_precision, dim=1, devices=devices)
+ add_kernel_test(TestVecConstructors, test_vector_constructors_default_precision, dim=1, devices=devices)
+ add_kernel_test(TestVecConstructors, test_vector_constructors_constant_length, dim=1, devices=devices)
+
+ for dtype in np_float_types:
+     add_function_test_register_kernel(
+         TestVecConstructors,
+         f"test_casting_constructors_{dtype.__name__}",
+         test_casting_constructors,
+         devices=devices,
+         dtype=dtype,
+     )
+
+ if __name__ == "__main__":
+     wp.clear_kernel_cache()
+     unittest.main(verbosity=2, failfast=True)
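
For readers skimming this diff: the new file above (warp/tests/test_vec_constructors.py, the only wholly added file in the list, +325 -0) pins down the exact error messages Warp raises for malformed vector-constructor calls. A minimal sketch of the user-facing behavior these tests lock in, using only the wp.vector / wp.launch API exercised above (the kernel name is illustrative):

import warp as wp

@wp.kernel
def bad_vec_kernel():
    # copy-constructing a length-2 vector into a length-3 vector is rejected
    wp.vector(wp.vector(length=2, dtype=float), length=3, dtype=float)

try:
    # per the tests above, the error surfaces as a RuntimeError when the
    # kernel is compiled at launch time
    wp.launch(bad_vec_kernel, dim=1, inputs=[])
except RuntimeError as exc:
    print(exc)  # "incompatible vector of length 3 given when copy constructing a vector of length 2"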
warp/tests/tile/test_tile.py:

@@ -215,103 +215,6 @@ def test_tile_binary_map(test, device):
      assert_np_equal(B_wp.grad.numpy(), B_grad)


- def test_tile_grouped_gemm(test, device):
-     @wp.kernel
-     def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
-         # output tile index
-         i = wp.tid()
-
-         a = wp.tile_load(A[i], shape=(TILE_M, TILE_K))
-         b = wp.tile_load(B[i], shape=(TILE_K, TILE_N))
-
-         sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
-
-         wp.tile_matmul(a, b, sum)
-
-         wp.tile_store(C[i], sum)
-
-     batch_count = 56
-
-     M = TILE_M
-     N = TILE_N
-     K = TILE_K
-
-     rng = np.random.default_rng(42)
-     A = rng.random((batch_count, M, K), dtype=np.float32)
-     B = rng.random((batch_count, K, N), dtype=np.float32)
-     C = A @ B
-
-     A_wp = wp.array(A, requires_grad=True, device=device)
-     B_wp = wp.array(B, requires_grad=True, device=device)
-     C_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)
-
-     with wp.Tape() as tape:
-         wp.launch_tiled(
-             tile_grouped_gemm, dim=[batch_count], inputs=[A_wp, B_wp, C_wp], block_dim=TILE_DIM, device=device
-         )
-
-     # TODO: 32 mismatched elements
-     assert_np_equal(C_wp.numpy(), C, 1e-6)
-
-
- def test_tile_gemm(dtype):
-     def test(test, device):
-         @wp.kernel
-         def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
-             # output tile index
-             i, j = wp.tid()
-
-             sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)
-
-             M = A.shape[0]
-             N = B.shape[1]
-             K = A.shape[1]
-
-             count = int(K / TILE_K)
-
-             for k in range(0, count):
-                 a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
-                 b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
-
-                 # sum += a*b
-                 wp.tile_matmul(a, b, sum)
-
-             wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
-
-         M = TILE_M * 7
-         K = TILE_K * 6
-         N = TILE_N * 5
-
-         rng = np.random.default_rng(42)
-         A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
-         B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
-         C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
-
-         A_wp = wp.array(A, requires_grad=True, device=device)
-         B_wp = wp.array(B, requires_grad=True, device=device)
-         C_wp = wp.array(C, requires_grad=True, device=device)
-
-         with wp.Tape() as tape:
-             wp.launch_tiled(
-                 tile_gemm,
-                 dim=(int(M / TILE_M), int(N / TILE_N)),
-                 inputs=[A_wp, B_wp, C_wp],
-                 block_dim=TILE_DIM,
-                 device=device,
-             )
-
-         assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)
-
-         adj_C = np.ones_like(C)
-
-         tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
-
-         assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
-         assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
-
-     return test
-
-
  @wp.kernel
  def tile_operators(input: wp.array3d(dtype=float), output: wp.array3d(dtype=float)):
      # output tile index
@@ -368,6 +271,12 @@ def test_tile_tile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
      wp.tile_store(y, t)


+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.quat), "y": wp.array(dtype=wp.quat)})
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
+
+
  @wp.kernel
  def test_tile_tile_scalar_expansion_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
      a = x[0]
@@ -494,6 +403,12 @@ def test_tile_untile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
      y[i] = b


+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.quat), "y": wp.array(dtype=wp.quat)})
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
+
+
  @wp.kernel
  def test_tile_untile_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
      i = wp.tid()
@@ -503,6 +418,11 @@ def test_tile_untile_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
      y[i] = b


+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
+
+
  def test_tile_untile(test, device):
      def test_func_preserve_type(type: Any):
          x = wp.ones(TILE_DIM, dtype=type, requires_grad=True, device=device)
@@ -644,7 +564,7 @@ def test_tile_sum_launch(test, device):
      assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.5)


- @wp.kernel
+ @wp.kernel(module="unique")
  def test_tile_extract_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
      i, j, x, y = wp.tid()

@@ -680,7 +600,7 @@ def test_tile_extract(test, device):
      assert_np_equal(a.grad.numpy(), expected_grad)


- @wp.kernel
+ @wp.kernel(module="unique")
  def test_tile_extract_repeated_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
      i, j, x, y = wp.tid()

@@ -744,7 +664,7 @@ def test_tile_assign(test, device):

      tape = wp.Tape()
      with tape:
-         wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=64, device=device)
+         wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)

      y.grad = wp.ones_like(y)
      tape.backward()
@@ -766,31 +686,11 @@ def test_tile_transpose(test, device):
      input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
      output = wp.zeros_like(input.transpose(), device=device)

-     wp.launch_tiled(test_tile_transpose_kernel, dim=[1], inputs=[input, output], block_dim=32, device=device)
+     wp.launch_tiled(test_tile_transpose_kernel, dim=[1], inputs=[input, output], block_dim=TILE_DIM, device=device)

      assert_np_equal(output.numpy(), input.numpy().T)


- def test_tile_transpose_matmul(test, device):
-     @wp.kernel
-     def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
-         x = wp.tile_load(input, shape=(TILE_M, TILE_N))
-         y = wp.tile_transpose(x)
-
-         z = wp.tile_zeros(dtype=float, shape=(TILE_N, TILE_N))
-         wp.tile_matmul(y, x, z)
-
-         wp.tile_store(output, z)
-
-     rng = np.random.default_rng(42)
-     input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
-     output = wp.zeros((TILE_N, TILE_N), dtype=float, device=device)
-
-     wp.launch_tiled(test_tile_transpose_matmul_kernel, dim=[1], inputs=[input, output], block_dim=32, device=device)
-
-     assert_np_equal(output.numpy(), input.numpy().T @ input.numpy())
-
-
  @wp.kernel
  def test_tile_broadcast_add_1d_kernel(
      input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
@@ -812,7 +712,7 @@ def test_tile_broadcast_add_1d(test, device):
      b = wp.array(np.ones(1, dtype=np.float32), device=device)
      out = wp.zeros((N,), dtype=float, device=device)

-     wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+     wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)

      assert_np_equal(out.numpy(), a.numpy() + b.numpy())

@@ -839,7 +739,7 @@ def test_tile_broadcast_add_2d(test, device):
      b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
      out = wp.zeros((M, N), dtype=float, device=device)

-     wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+     wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)

      assert_np_equal(out.numpy(), a.numpy() + b.numpy())

@@ -867,7 +767,7 @@ def test_tile_broadcast_add_3d(test, device):
      b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
      out = wp.zeros((M, N, O), dtype=float, device=device)

-     wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+     wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
      assert_np_equal(out.numpy(), a.numpy() + b.numpy())


@@ -894,7 +794,7 @@ def test_tile_broadcast_add_4d(test, device):
      b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
      out = wp.zeros((M, N, O, P), dtype=float, device=device)

-     wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
+     wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)

      assert_np_equal(out.numpy(), a.numpy() + b.numpy())

@@ -915,7 +815,7 @@ def test_tile_broadcast_grad(test, device):
      b = wp.array(np.ones((5, 5), dtype=np.float32), requires_grad=True, device=device)

      with wp.Tape() as tape:
-         wp.launch_tiled(test_tile_broadcast_grad_kernel, dim=[1], inputs=[a, b], block_dim=32, device=device)
+         wp.launch_tiled(test_tile_broadcast_grad_kernel, dim=[1], inputs=[a, b], block_dim=TILE_DIM, device=device)

      b.grad = wp.ones_like(b, device=device)
      tape.backward()
@@ -1049,14 +949,7 @@ def tile_len_kernel(
  def test_tile_len(test, device):
      a = wp.zeros((TILE_M, TILE_N), dtype=float, device=device)
      out = wp.empty(1, dtype=int, device=device)
-     wp.launch_tiled(
-         tile_len_kernel,
-         dim=(1,),
-         inputs=(a,),
-         outputs=(out,),
-         block_dim=32,
-         device=device,
-     )
+     wp.launch_tiled(tile_len_kernel, dim=(1,), inputs=(a,), outputs=(out,), block_dim=TILE_DIM, device=device)

      test.assertEqual(out.numpy()[0], TILE_M)

@@ -1193,12 +1086,7 @@ add_function_test(TestTile, "test_tile_copy_1d", test_tile_copy_1d, devices=devices)
  add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devices)
  add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
  add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
- add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
- add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
- add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
- add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
  add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
- add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
  add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
  add_function_test(TestTile, "test_tile_tile", test_tile_tile, devices=get_cuda_test_devices())
  add_function_test(TestTile, "test_tile_untile", test_tile_untile, devices=devices)
@@ -1215,10 +1103,10 @@ add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
  add_function_test(TestTile, "test_tile_squeeze", test_tile_squeeze, devices=devices)
  add_function_test(TestTile, "test_tile_reshape", test_tile_reshape, devices=devices)
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
- add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
- add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
- add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)
- add_function_test(TestTile, "test_tile_func_return", test_tile_func_return, devices=devices)
+ # add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
+ # add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
+ # add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)
+ # add_function_test(TestTile, "test_tile_func_return", test_tile_func_return, devices=devices)


  if __name__ == "__main__":
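
One pattern worth noting before the next file: several test kernels above, and both kernels in the warp/tests/tile/test_tile_mathdx.py hunks below, switch from @wp.kernel to @wp.kernel(module="unique"). As I understand Warp's module system, this gives the kernel its own dedicated module so it compiles in isolation rather than as part of the defining Python module; a minimal sketch of the decorator usage (the kernel body here is illustrative, not taken from this diff):

import warp as wp

# module="unique" places the kernel in its own module, isolating its
# compilation from other kernels defined in the same .py file
@wp.kernel(module="unique")
def scale_kernel(x: wp.array(dtype=float), s: float):
    i = wp.tid()
    x[i] = x[i] * s

data = wp.array([1.0, 2.0, 3.0], dtype=float)
wp.launch(scale_kernel, dim=data.shape[0], inputs=[data, 2.0])
print(data.numpy())  # [2. 4. 6.]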
warp/tests/tile/test_tile_mathdx.py:

@@ -450,7 +450,7 @@ def test_tile_math_back_substitution_multiple_rhs(test, device):
  def test_tile_math_block_cholesky(test, device):
      BLOCK_SIZE = wp.constant(TILE_M // 2)

-     @wp.kernel
+     @wp.kernel(module="unique")
      def block_cholesky_kernel(
          A: wp.array2d(dtype=float),
          L: wp.array2d(dtype=float),
@@ -496,7 +496,7 @@ def test_tile_math_block_cholesky(test, device):

          wp.tile_store(L, sol_tile, offset=(i, k))

-     @wp.kernel
+     @wp.kernel(module="unique")
      def block_cholesky_solve_kernel(
          L: wp.array2d(dtype=float),
          b: wp.array2d(dtype=float),