warp-lang 1.8.0__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (153) hide show
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
@@ -109,12 +109,29 @@ def test_tile_copy_2d(test, device):
109
109
 
110
110
 
111
111
  @wp.func
112
- def unary_func(x: float):
112
+ def unary_func(x: wp.float32):
113
113
  return wp.sin(x)
114
114
 
115
115
 
116
+ @wp.func
117
+ def unary_func(x: wp.float64):
118
+ return wp.sin(x)
119
+
120
+
121
+ @wp.kernel
122
+ def tile_unary_map_user_func(input: wp.array2d(dtype=Any), output: wp.array2d(dtype=Any)):
123
+ # tile index
124
+ i, j = wp.tid()
125
+
126
+ a = wp.tile_load(input, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
127
+
128
+ sa = wp.tile_map(unary_func, a)
129
+
130
+ wp.tile_store(output, sa, offset=(i * TILE_M, j * TILE_N))
131
+
132
+
116
133
  @wp.kernel
117
- def tile_unary_map(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
134
+ def tile_unary_map_builtin_func(input: wp.array2d(dtype=Any), output: wp.array2d(dtype=Any)):
118
135
  # tile index
119
136
  i, j = wp.tid()
120
137
 
@@ -131,185 +148,235 @@ def test_tile_unary_map(test, device):
131
148
  M = TILE_M * 7
132
149
  N = TILE_N * 5
133
150
 
134
- A = rng.random((M, N), dtype=np.float32)
135
- B = np.sin(A)
151
+ def run(kernel, dtype):
152
+ A = rng.random((M, N), dtype=dtype)
153
+ B = np.sin(A)
136
154
 
137
- A_grad = np.cos(A)
155
+ A_grad = np.cos(A)
138
156
 
139
- A_wp = wp.array(A, requires_grad=True, device=device)
140
- B_wp = wp.zeros_like(A_wp, requires_grad=True, device=device)
157
+ A_wp = wp.array(A, requires_grad=True, device=device)
158
+ B_wp = wp.zeros_like(A_wp, requires_grad=True, device=device)
141
159
 
142
- with wp.Tape() as tape:
143
- wp.launch_tiled(
144
- tile_unary_map,
145
- dim=[int(M / TILE_M), int(N / TILE_N)],
146
- inputs=[A_wp, B_wp],
147
- block_dim=TILE_DIM,
148
- device=device,
149
- )
160
+ with wp.Tape() as tape:
161
+ wp.launch_tiled(
162
+ kernel,
163
+ dim=[int(M / TILE_M), int(N / TILE_N)],
164
+ inputs=[A_wp, B_wp],
165
+ block_dim=TILE_DIM,
166
+ device=device,
167
+ )
150
168
 
151
- # verify forward pass
152
- assert_np_equal(B_wp.numpy(), B, tol=1.0e-4)
169
+ tol = 1.0e-6 if dtype == np.float64 else 1.0e-4
153
170
 
154
- # verify backward pass
155
- B_wp.grad = wp.ones_like(B_wp, device=device)
156
- tape.backward()
171
+ # verify forward pass
172
+ assert_np_equal(B_wp.numpy(), B, tol=tol)
157
173
 
158
- assert_np_equal(A_wp.grad.numpy(), A_grad, tol=1.0e-6)
174
+ # verify backward pass
175
+ B_wp.grad = wp.ones_like(B_wp, device=device)
176
+ tape.backward()
177
+
178
+ assert_np_equal(A_wp.grad.numpy(), A_grad, tol=tol)
179
+
180
+ dtypes = [np.float32, np.float64]
181
+
182
+ for dtype in dtypes:
183
+ run(tile_unary_map_user_func, dtype)
184
+ run(tile_unary_map_builtin_func, dtype)
159
185
 
160
186
 
161
187
  @wp.func
162
- def binary_func(x: float, y: float):
163
- return wp.sin(x) + y
188
+ def unary_func_mixed_types(x: int) -> float:
189
+ return wp.sin(float(x))
164
190
 
165
191
 
166
192
  @wp.kernel
167
- def tile_binary_map(
168
- input_a: wp.array2d(dtype=float), input_b: wp.array2d(dtype=float), output: wp.array2d(dtype=float)
169
- ):
193
+ def tile_unary_map_mixed_types(input: wp.array2d(dtype=int), output: wp.array2d(dtype=float)):
170
194
  # tile index
171
195
  i, j = wp.tid()
172
196
 
173
- a = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
174
- b = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
197
+ a = wp.tile_load(input, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
175
198
 
176
- sa = wp.tile_map(binary_func, a, b)
199
+ sa = wp.tile_map(unary_func_mixed_types, a)
177
200
 
178
201
  wp.tile_store(output, sa, offset=(i * TILE_M, j * TILE_N))
179
202
 
180
203
 
181
- def test_tile_binary_map(test, device):
204
+ def test_tile_unary_map_mixed_types(test, device):
182
205
  rng = np.random.default_rng(42)
183
206
 
184
207
  M = TILE_M * 7
185
208
  N = TILE_N * 5
186
209
 
187
- A = rng.random((M, N), dtype=np.float32)
188
- B = rng.random((M, N), dtype=np.float32)
189
- C = np.sin(A) + B
210
+ A = rng.integers(0, 100, size=(M, N), dtype=np.int32)
211
+ B = np.sin(A.astype(np.float32))
190
212
 
191
- A_grad = np.cos(A)
192
- B_grad = np.ones_like(B)
213
+ A_grad = np.cos(A.astype(np.float32))
193
214
 
194
215
  A_wp = wp.array(A, requires_grad=True, device=device)
195
- B_wp = wp.array(B, requires_grad=True, device=device)
196
- C_wp = wp.zeros_like(A_wp, requires_grad=True, device=device)
216
+ B_wp = wp.zeros((M, N), dtype=float, requires_grad=True, device=device)
197
217
 
198
218
  with wp.Tape() as tape:
199
219
  wp.launch_tiled(
200
- tile_binary_map,
220
+ tile_unary_map_mixed_types,
201
221
  dim=[int(M / TILE_M), int(N / TILE_N)],
202
- inputs=[A_wp, B_wp, C_wp],
222
+ inputs=[A_wp, B_wp],
203
223
  block_dim=TILE_DIM,
204
224
  device=device,
205
225
  )
206
226
 
207
227
  # verify forward pass
208
- assert_np_equal(C_wp.numpy(), C, tol=1.0e-6)
228
+ assert_np_equal(B_wp.numpy(), B, tol=1.0e-4)
209
229
 
210
230
  # verify backward pass
211
- C_wp.grad = wp.ones_like(C_wp, device=device)
231
+ B_wp.grad = wp.ones_like(B_wp, device=device)
212
232
  tape.backward()
213
233
 
214
- assert_np_equal(A_wp.grad.numpy(), A_grad, tol=1.0e-6)
215
- assert_np_equal(B_wp.grad.numpy(), B_grad)
216
-
217
-
218
- def test_tile_grouped_gemm(test, device):
219
- @wp.kernel
220
- def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
221
- # output tile index
222
- i = wp.tid()
234
+ # The a gradients are now stored as ints and can't capture the correct values
235
+ # assert_np_equal(A_wp.grad.numpy(), A_grad, tol=1.0e-6)
223
236
 
224
- a = wp.tile_load(A[i], shape=(TILE_M, TILE_K))
225
- b = wp.tile_load(B[i], shape=(TILE_K, TILE_N))
226
237
 
227
- sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
228
-
229
- wp.tile_matmul(a, b, sum)
238
+ @wp.func
239
+ def binary_func(x: wp.float32, y: wp.float32):
240
+ return x + y
230
241
 
231
- wp.tile_store(C[i], sum)
232
242
 
233
- batch_count = 56
243
+ @wp.func
244
+ def binary_func(x: wp.float64, y: wp.float64):
245
+ return x + y
234
246
 
235
- M = TILE_M
236
- N = TILE_N
237
- K = TILE_K
238
247
 
239
- rng = np.random.default_rng(42)
240
- A = rng.random((batch_count, M, K), dtype=np.float32)
241
- B = rng.random((batch_count, K, N), dtype=np.float32)
242
- C = A @ B
248
+ @wp.kernel
249
+ def tile_binary_map_user_func(
250
+ input_a: wp.array2d(dtype=Any), input_b: wp.array2d(dtype=Any), output: wp.array2d(dtype=Any)
251
+ ):
252
+ # tile index
253
+ i, j = wp.tid()
243
254
 
244
- A_wp = wp.array(A, requires_grad=True, device=device)
245
- B_wp = wp.array(B, requires_grad=True, device=device)
246
- C_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)
255
+ a = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
256
+ b = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
247
257
 
248
- with wp.Tape() as tape:
249
- wp.launch_tiled(
250
- tile_grouped_gemm, dim=[batch_count], inputs=[A_wp, B_wp, C_wp], block_dim=TILE_DIM, device=device
251
- )
258
+ sa = wp.tile_map(binary_func, a, b)
252
259
 
253
- # TODO: 32 mismatched elements
254
- assert_np_equal(C_wp.numpy(), C, 1e-6)
260
+ wp.tile_store(output, sa, offset=(i * TILE_M, j * TILE_N))
255
261
 
256
262
 
257
- def test_tile_gemm(dtype):
258
- def test(test, device):
259
- @wp.kernel
260
- def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
261
- # output tile index
262
- i, j = wp.tid()
263
+ @wp.kernel
264
+ def tile_binary_map_builtin_func(
265
+ input_a: wp.array2d(dtype=Any), input_b: wp.array2d(dtype=Any), output: wp.array2d(dtype=Any)
266
+ ):
267
+ # tile index
268
+ i, j = wp.tid()
263
269
 
264
- sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)
270
+ a = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
271
+ b = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
265
272
 
266
- M = A.shape[0]
267
- N = B.shape[1]
268
- K = A.shape[1]
273
+ sa = wp.tile_map(wp.add, a, b)
269
274
 
270
- count = int(K / TILE_K)
275
+ wp.tile_store(output, sa, offset=(i * TILE_M, j * TILE_N))
271
276
 
272
- for k in range(0, count):
273
- a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
274
- b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
275
277
 
276
- # sum += a*b
277
- wp.tile_matmul(a, b, sum)
278
+ def test_tile_binary_map(test, device):
279
+ rng = np.random.default_rng(42)
278
280
 
279
- wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
281
+ M = TILE_M * 7
282
+ N = TILE_N * 5
280
283
 
281
- M = TILE_M * 7
282
- K = TILE_K * 6
283
- N = TILE_N * 5
284
+ def run(kernel, dtype):
285
+ A = rng.random((M, N), dtype=dtype)
286
+ B = rng.random((M, N), dtype=dtype)
287
+ C = A + B
284
288
 
285
- rng = np.random.default_rng(42)
286
- A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
287
- B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
288
- C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
289
+ A_grad = np.ones_like(A)
290
+ B_grad = np.ones_like(B)
289
291
 
290
292
  A_wp = wp.array(A, requires_grad=True, device=device)
291
293
  B_wp = wp.array(B, requires_grad=True, device=device)
292
- C_wp = wp.array(C, requires_grad=True, device=device)
294
+ C_wp = wp.zeros_like(A_wp, requires_grad=True, device=device)
293
295
 
294
296
  with wp.Tape() as tape:
295
297
  wp.launch_tiled(
296
- tile_gemm,
297
- dim=(int(M / TILE_M), int(N / TILE_N)),
298
+ kernel,
299
+ dim=[int(M / TILE_M), int(N / TILE_N)],
298
300
  inputs=[A_wp, B_wp, C_wp],
299
301
  block_dim=TILE_DIM,
300
302
  device=device,
301
303
  )
302
304
 
303
- assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)
305
+ tol = 1.0e-6 if dtype == np.float64 else 1.0e-4
304
306
 
305
- adj_C = np.ones_like(C)
307
+ # verify forward pass
308
+ assert_np_equal(C_wp.numpy(), C, tol=tol)
306
309
 
307
- tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
310
+ # verify backward pass
311
+ C_wp.grad = wp.ones_like(C_wp, device=device)
312
+ tape.backward()
313
+
314
+ assert_np_equal(A_wp.grad.numpy(), A_grad, tol=tol)
315
+ assert_np_equal(B_wp.grad.numpy(), B_grad, tol=tol)
316
+
317
+ dtypes = [np.float32, np.float64]
308
318
 
309
- assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
310
- assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
319
+ for dtype in dtypes:
320
+ run(tile_binary_map_builtin_func, dtype)
321
+ run(tile_binary_map_user_func, dtype)
311
322
 
312
- return test
323
+
324
+ @wp.func
325
+ def binary_func_mixed_types(x: int, y: float) -> float:
326
+ return wp.sin(float(x)) + y
327
+
328
+
329
+ @wp.kernel
330
+ def tile_binary_map_mixed_types(
331
+ input_a: wp.array2d(dtype=int), input_b: wp.array2d(dtype=float), output: wp.array2d(dtype=float)
332
+ ):
333
+ # tile index
334
+ i, j = wp.tid()
335
+
336
+ a = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
337
+ b = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
338
+
339
+ sa = wp.tile_map(binary_func_mixed_types, a, b)
340
+
341
+ wp.tile_store(output, sa, offset=(i * TILE_M, j * TILE_N))
342
+
343
+
344
+ def test_tile_binary_map_mixed_types(test, device):
345
+ rng = np.random.default_rng(42)
346
+
347
+ M = TILE_M * 7
348
+ N = TILE_N * 5
349
+
350
+ A = rng.integers(0, 100, size=(M, N), dtype=np.int32)
351
+ B = rng.random((M, N), dtype=np.float32)
352
+ C = np.sin(A.astype(np.float32)) + B
353
+
354
+ A_grad = np.cos(A.astype(np.float32))
355
+ B_grad = np.ones_like(B)
356
+
357
+ A_wp = wp.array(A, requires_grad=True, device=device)
358
+ B_wp = wp.array(B, requires_grad=True, device=device)
359
+ C_wp = wp.zeros_like(B_wp, requires_grad=True, device=device)
360
+
361
+ with wp.Tape() as tape:
362
+ wp.launch_tiled(
363
+ tile_binary_map_mixed_types,
364
+ dim=[int(M / TILE_M), int(N / TILE_N)],
365
+ inputs=[A_wp, B_wp, C_wp],
366
+ block_dim=TILE_DIM,
367
+ device=device,
368
+ )
369
+
370
+ # verify forward pass
371
+ assert_np_equal(C_wp.numpy(), C, tol=1.0e-6)
372
+
373
+ # verify backward pass
374
+ C_wp.grad = wp.ones_like(C_wp, device=device)
375
+ tape.backward()
376
+
377
+ # The a gradients are now stored as ints and can't capture the correct values
378
+ # assert_np_equal(A_wp.grad.numpy(), A_grad, tol=1.0e-6)
379
+ assert_np_equal(B_wp.grad.numpy(), B_grad)
313
380
 
314
381
 
315
382
  @wp.kernel
@@ -368,6 +435,12 @@ def test_tile_tile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dtyp
368
435
  wp.tile_store(y, t)
369
436
 
370
437
 
438
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
439
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
440
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.quat), "y": wp.array(dtype=wp.quat)})
441
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
442
+
443
+
371
444
  @wp.kernel
372
445
  def test_tile_tile_scalar_expansion_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
373
446
  a = x[0]
@@ -494,6 +567,12 @@ def test_tile_untile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dt
494
567
  y[i] = b
495
568
 
496
569
 
570
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
571
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
572
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.quat), "y": wp.array(dtype=wp.quat)})
573
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
574
+
575
+
497
576
  @wp.kernel
498
577
  def test_tile_untile_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
499
578
  i = wp.tid()
@@ -503,6 +582,11 @@ def test_tile_untile_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
503
582
  y[i] = b
504
583
 
505
584
 
585
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
586
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
587
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
588
+
589
+
506
590
  def test_tile_untile(test, device):
507
591
  def test_func_preserve_type(type: Any):
508
592
  x = wp.ones(TILE_DIM, dtype=type, requires_grad=True, device=device)
@@ -644,7 +728,7 @@ def test_tile_sum_launch(test, device):
644
728
  assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.5)
645
729
 
646
730
 
647
- @wp.kernel
731
+ @wp.kernel(module="unique")
648
732
  def test_tile_extract_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
649
733
  i, j, x, y = wp.tid()
650
734
 
@@ -680,7 +764,7 @@ def test_tile_extract(test, device):
680
764
  assert_np_equal(a.grad.numpy(), expected_grad)
681
765
 
682
766
 
683
- @wp.kernel
767
+ @wp.kernel(module="unique")
684
768
  def test_tile_extract_repeated_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
685
769
  i, j, x, y = wp.tid()
686
770
 
@@ -744,7 +828,7 @@ def test_tile_assign(test, device):
744
828
 
745
829
  tape = wp.Tape()
746
830
  with tape:
747
- wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=64, device=device)
831
+ wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
748
832
 
749
833
  y.grad = wp.ones_like(y)
750
834
  tape.backward()
@@ -766,31 +850,11 @@ def test_tile_transpose(test, device):
766
850
  input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
767
851
  output = wp.zeros_like(input.transpose(), device=device)
768
852
 
769
- wp.launch_tiled(test_tile_transpose_kernel, dim=[1], inputs=[input, output], block_dim=32, device=device)
853
+ wp.launch_tiled(test_tile_transpose_kernel, dim=[1], inputs=[input, output], block_dim=TILE_DIM, device=device)
770
854
 
771
855
  assert_np_equal(output.numpy(), input.numpy().T)
772
856
 
773
857
 
774
- def test_tile_transpose_matmul(test, device):
775
- @wp.kernel
776
- def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
777
- x = wp.tile_load(input, shape=(TILE_M, TILE_N))
778
- y = wp.tile_transpose(x)
779
-
780
- z = wp.tile_zeros(dtype=float, shape=(TILE_N, TILE_N))
781
- wp.tile_matmul(y, x, z)
782
-
783
- wp.tile_store(output, z)
784
-
785
- rng = np.random.default_rng(42)
786
- input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
787
- output = wp.zeros((TILE_N, TILE_N), dtype=float, device=device)
788
-
789
- wp.launch_tiled(test_tile_transpose_matmul_kernel, dim=[1], inputs=[input, output], block_dim=32, device=device)
790
-
791
- assert_np_equal(output.numpy(), input.numpy().T @ input.numpy())
792
-
793
-
794
858
  @wp.kernel
795
859
  def test_tile_broadcast_add_1d_kernel(
796
860
  input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
@@ -812,7 +876,7 @@ def test_tile_broadcast_add_1d(test, device):
812
876
  b = wp.array(np.ones(1, dtype=np.float32), device=device)
813
877
  out = wp.zeros((N,), dtype=float, device=device)
814
878
 
815
- wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
879
+ wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
816
880
 
817
881
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
818
882
 
@@ -839,7 +903,7 @@ def test_tile_broadcast_add_2d(test, device):
839
903
  b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
840
904
  out = wp.zeros((M, N), dtype=float, device=device)
841
905
 
842
- wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
906
+ wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
843
907
 
844
908
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
845
909
 
@@ -867,7 +931,7 @@ def test_tile_broadcast_add_3d(test, device):
867
931
  b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
868
932
  out = wp.zeros((M, N, O), dtype=float, device=device)
869
933
 
870
- wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
934
+ wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
871
935
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
872
936
 
873
937
 
@@ -894,7 +958,7 @@ def test_tile_broadcast_add_4d(test, device):
894
958
  b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
895
959
  out = wp.zeros((M, N, O, P), dtype=float, device=device)
896
960
 
897
- wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
961
+ wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
898
962
 
899
963
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
900
964
 
@@ -915,7 +979,7 @@ def test_tile_broadcast_grad(test, device):
915
979
  b = wp.array(np.ones((5, 5), dtype=np.float32), requires_grad=True, device=device)
916
980
 
917
981
  with wp.Tape() as tape:
918
- wp.launch_tiled(test_tile_broadcast_grad_kernel, dim=[1], inputs=[a, b], block_dim=32, device=device)
982
+ wp.launch_tiled(test_tile_broadcast_grad_kernel, dim=[1], inputs=[a, b], block_dim=TILE_DIM, device=device)
919
983
 
920
984
  b.grad = wp.ones_like(b, device=device)
921
985
  tape.backward()
@@ -1049,14 +1113,7 @@ def tile_len_kernel(
1049
1113
  def test_tile_len(test, device):
1050
1114
  a = wp.zeros((TILE_M, TILE_N), dtype=float, device=device)
1051
1115
  out = wp.empty(1, dtype=int, device=device)
1052
- wp.launch_tiled(
1053
- tile_len_kernel,
1054
- dim=(1,),
1055
- inputs=(a,),
1056
- outputs=(out,),
1057
- block_dim=32,
1058
- device=device,
1059
- )
1116
+ wp.launch_tiled(tile_len_kernel, dim=(1,), inputs=(a,), outputs=(out,), block_dim=TILE_DIM, device=device)
1060
1117
 
1061
1118
  test.assertEqual(out.numpy()[0], TILE_M)
1062
1119
 
@@ -1192,13 +1249,10 @@ class TestTile(unittest.TestCase):
1192
1249
  add_function_test(TestTile, "test_tile_copy_1d", test_tile_copy_1d, devices=devices)
1193
1250
  add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devices)
1194
1251
  add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
1252
+ add_function_test(TestTile, "test_tile_unary_map_mixed_types", test_tile_unary_map_mixed_types, devices=devices)
1195
1253
  add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
1196
- add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
1197
- add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
1198
- add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
1199
- add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
1254
+ add_function_test(TestTile, "test_tile_binary_map_mixed_types", test_tile_binary_map_mixed_types, devices=devices)
1200
1255
  add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
1201
- add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
1202
1256
  add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
1203
1257
  add_function_test(TestTile, "test_tile_tile", test_tile_tile, devices=get_cuda_test_devices())
1204
1258
  add_function_test(TestTile, "test_tile_untile", test_tile_untile, devices=devices)
@@ -1215,10 +1269,10 @@ add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad
1215
1269
  add_function_test(TestTile, "test_tile_squeeze", test_tile_squeeze, devices=devices)
1216
1270
  add_function_test(TestTile, "test_tile_reshape", test_tile_reshape, devices=devices)
1217
1271
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
1218
- add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
1219
- add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
1220
- add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)
1221
- add_function_test(TestTile, "test_tile_func_return", test_tile_func_return, devices=devices)
1272
+ # add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
1273
+ # add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
1274
+ # add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)
1275
+ # add_function_test(TestTile, "test_tile_func_return", test_tile_func_return, devices=devices)
1222
1276
 
1223
1277
 
1224
1278
  if __name__ == "__main__":