warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (192) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import unittest
17
+ from typing import Any
17
18
 
18
19
  import numpy as np
19
20
 
@@ -214,150 +215,265 @@ def test_tile_binary_map(test, device):
214
215
  assert_np_equal(B_wp.grad.numpy(), B_grad)
215
216
 
216
217
 
217
- def test_tile_grouped_gemm(test, device):
218
- @wp.kernel
219
- def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
220
- # output tile index
221
- i = wp.tid()
218
+ @wp.kernel
219
+ def tile_operators(input: wp.array3d(dtype=float), output: wp.array3d(dtype=float)):
220
+ # output tile index
221
+ i = wp.tid()
222
222
 
223
- a = wp.tile_load(A[i], shape=(TILE_M, TILE_K))
224
- b = wp.tile_load(B[i], shape=(TILE_K, TILE_N))
223
+ a = wp.tile_load(input[i], shape=(TILE_M, TILE_N))
225
224
 
226
- sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
225
+ # neg
226
+ b = -a
227
227
 
228
- wp.tile_matmul(a, b, sum)
228
+ # right scalar multiply
229
+ c = b * 0.5
229
230
 
230
- wp.tile_store(C[i], sum)
231
+ # left scalar multiply
232
+ d = 0.5 * c
231
233
 
234
+ # add tiles
235
+ e = a + d
236
+
237
+ wp.tile_store(output[i], e)
238
+
239
+
240
+ def test_tile_operators(test, device):
232
241
  batch_count = 56
233
242
 
234
243
  M = TILE_M
235
244
  N = TILE_N
236
- K = TILE_K
237
245
 
238
246
  rng = np.random.default_rng(42)
239
- A = rng.random((batch_count, M, K), dtype=np.float32)
240
- B = rng.random((batch_count, K, N), dtype=np.float32)
241
- C = A @ B
247
+ input = rng.random((batch_count, M, N), dtype=np.float32)
248
+ output = input * 0.75
242
249
 
243
- A_wp = wp.array(A, requires_grad=True, device=device)
244
- B_wp = wp.array(B, requires_grad=True, device=device)
245
- C_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)
250
+ input_wp = wp.array(input, requires_grad=True, device=device)
251
+ output_wp = wp.zeros_like(input_wp, requires_grad=True, device=device)
246
252
 
247
253
  with wp.Tape() as tape:
248
254
  wp.launch_tiled(
249
- tile_grouped_gemm, dim=[batch_count], inputs=[A_wp, B_wp, C_wp], block_dim=TILE_DIM, device=device
255
+ tile_operators, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
250
256
  )
251
257
 
252
- # TODO: 32 mismatched elements
253
- assert_np_equal(C_wp.numpy(), C, 1e-6)
258
+ assert_np_equal(output_wp.numpy(), output)
254
259
 
260
+ output_wp.grad.fill_(1.0)
261
+
262
+ tape.backward()
255
263
 
256
- def test_tile_gemm(dtype):
257
- def test(test, device):
258
- @wp.kernel
259
- def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
260
- # output tile index
261
- i, j = wp.tid()
264
+ assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.75)
262
265
 
263
- sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)
264
266
 
265
- M = A.shape[0]
266
- N = B.shape[1]
267
- K = A.shape[1]
267
+ @wp.kernel
268
+ def test_tile_tile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
269
+ a = x[0]
270
+ t = wp.tile(a, preserve_type=True)
271
+ wp.tile_store(y, t)
268
272
 
269
- count = int(K / TILE_K)
270
273
 
271
- for k in range(0, count):
272
- a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
273
- b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
274
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
275
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
276
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.quat), "y": wp.array(dtype=wp.quat)})
277
+ wp.overload(test_tile_tile_preserve_type_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
274
278
 
275
- # sum += a*b
276
- wp.tile_matmul(a, b, sum)
277
279
 
278
- wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
280
+ @wp.kernel
281
+ def test_tile_tile_scalar_expansion_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
282
+ a = x[0]
283
+ t = wp.tile(a)
284
+ wp.tile_store(y, t)
279
285
 
280
- M = TILE_M * 7
281
- K = TILE_K * 6
282
- N = TILE_N * 5
283
286
 
284
- rng = np.random.default_rng(42)
285
- A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
286
- B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
287
- C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
287
+ @wp.kernel
288
+ def test_tile_tile_vec_expansion_kernel(x: wp.array(dtype=wp.vec3), y: wp.array2d(dtype=float)):
289
+ a = x[0]
290
+ t = wp.tile(a)
291
+ wp.tile_store(y, t)
288
292
 
289
- A_wp = wp.array(A, requires_grad=True, device=device)
290
- B_wp = wp.array(B, requires_grad=True, device=device)
291
- C_wp = wp.array(C, requires_grad=True, device=device)
292
293
 
293
- with wp.Tape() as tape:
294
- wp.launch_tiled(
295
- tile_gemm,
296
- dim=(int(M / TILE_M), int(N / TILE_N)),
297
- inputs=[A_wp, B_wp, C_wp],
294
+ @wp.kernel
295
+ def test_tile_tile_mat_expansion_kernel(x: wp.array(dtype=wp.mat33), y: wp.array3d(dtype=float)):
296
+ a = x[0]
297
+ t = wp.tile(a)
298
+ wp.tile_store(y, t)
299
+
300
+
301
+ def test_tile_tile(test, device):
302
+ # preserve type
303
+ def test_func_preserve_type(type: Any):
304
+ x = wp.ones(1, dtype=type, requires_grad=True, device=device)
305
+ y = wp.zeros((TILE_DIM), dtype=type, requires_grad=True, device=device)
306
+
307
+ tape = wp.Tape()
308
+ with tape:
309
+ wp.launch(
310
+ test_tile_tile_preserve_type_kernel,
311
+ dim=[TILE_DIM],
312
+ inputs=[x],
313
+ outputs=[y],
298
314
  block_dim=TILE_DIM,
299
315
  device=device,
300
316
  )
301
317
 
302
- assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)
318
+ y.grad = wp.ones_like(y)
319
+
320
+ tape.backward()
303
321
 
304
- adj_C = np.ones_like(C)
322
+ assert_np_equal(y.numpy(), wp.full((TILE_DIM), type(1.0), dtype=type, device="cpu").numpy())
323
+ assert_np_equal(x.grad.numpy(), wp.full((1,), type(TILE_DIM), dtype=type, device="cpu").numpy())
305
324
 
306
- tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
325
+ test_func_preserve_type(float)
326
+ test_func_preserve_type(wp.vec3)
327
+ test_func_preserve_type(wp.quat)
328
+ test_func_preserve_type(wp.mat33)
307
329
 
308
- assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
309
- assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
330
+ # scalar expansion
331
+ x = wp.ones(1, dtype=float, requires_grad=True, device=device)
332
+ y = wp.zeros((TILE_DIM), dtype=float, requires_grad=True, device=device)
333
+
334
+ tape = wp.Tape()
335
+ with tape:
336
+ wp.launch(
337
+ test_tile_tile_scalar_expansion_kernel,
338
+ dim=[TILE_DIM],
339
+ inputs=[x],
340
+ outputs=[y],
341
+ block_dim=TILE_DIM,
342
+ device=device,
343
+ )
344
+
345
+ y.grad = wp.ones_like(y)
310
346
 
311
- return test
347
+ tape.backward()
348
+
349
+ assert_np_equal(y.numpy(), wp.full((TILE_DIM), 1.0, dtype=float, device="cpu").numpy())
350
+ assert_np_equal(x.grad.numpy(), wp.full((1,), wp.float32(TILE_DIM), dtype=float, device="cpu").numpy())
351
+
352
+ # vec expansion
353
+ x = wp.ones(1, dtype=wp.vec3, requires_grad=True, device=device)
354
+ y = wp.zeros((3, TILE_DIM), dtype=float, requires_grad=True, device=device)
355
+
356
+ tape = wp.Tape()
357
+ with tape:
358
+ wp.launch(
359
+ test_tile_tile_vec_expansion_kernel,
360
+ dim=[TILE_DIM],
361
+ inputs=[x],
362
+ outputs=[y],
363
+ block_dim=TILE_DIM,
364
+ device=device,
365
+ )
366
+
367
+ y.grad = wp.ones_like(y)
368
+
369
+ tape.backward()
370
+
371
+ assert_np_equal(y.numpy(), wp.full((3, TILE_DIM), 1.0, dtype=float, device="cpu").numpy())
372
+ assert_np_equal(x.grad.numpy(), wp.full((1,), wp.float32(TILE_DIM), dtype=wp.vec3, device="cpu").numpy())
373
+
374
+ # mat expansion
375
+ x = wp.ones(1, dtype=wp.mat33, requires_grad=True, device=device)
376
+ y = wp.zeros((3, 3, TILE_DIM), dtype=float, requires_grad=True, device=device)
377
+
378
+ tape = wp.Tape()
379
+ with tape:
380
+ wp.launch(
381
+ test_tile_tile_mat_expansion_kernel,
382
+ dim=[TILE_DIM],
383
+ inputs=[x],
384
+ outputs=[y],
385
+ block_dim=TILE_DIM,
386
+ device=device,
387
+ )
388
+
389
+ y.grad = wp.ones_like(y)
390
+
391
+ tape.backward()
392
+
393
+ assert_np_equal(y.numpy(), wp.full((3, 3, TILE_DIM), 1.0, dtype=float, device="cpu").numpy())
394
+ assert_np_equal(x.grad.numpy(), wp.full((1,), wp.float32(TILE_DIM), dtype=wp.mat33, device="cpu").numpy())
312
395
 
313
396
 
314
397
  @wp.kernel
315
- def tile_operators(input: wp.array3d(dtype=float), output: wp.array3d(dtype=float)):
316
- # output tile index
398
+ def test_tile_untile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
317
399
  i = wp.tid()
400
+ a = x[i]
401
+ t = wp.tile(a, preserve_type=True)
402
+ b = wp.untile(t)
403
+ y[i] = b
318
404
 
319
- a = wp.tile_load(input[i], shape=(TILE_M, TILE_N))
320
405
 
321
- # neg
322
- b = -a
406
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
407
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
408
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.quat), "y": wp.array(dtype=wp.quat)})
409
+ wp.overload(test_tile_untile_preserve_type_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
323
410
 
324
- # right scalar multiply
325
- c = b * 0.5
326
411
 
327
- # left scalar multiply
328
- d = 0.5 * c
412
+ @wp.kernel
413
+ def test_tile_untile_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
414
+ i = wp.tid()
415
+ a = x[i]
416
+ t = wp.tile(a)
417
+ b = wp.untile(t)
418
+ y[i] = b
329
419
 
330
- # add tiles
331
- e = a + d
332
420
 
333
- wp.tile_store(output[i], e)
421
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=float), "y": wp.array(dtype=float)})
422
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=wp.vec3), "y": wp.array(dtype=wp.vec3)})
423
+ wp.overload(test_tile_untile_kernel, {"x": wp.array(dtype=wp.mat33), "y": wp.array(dtype=wp.mat33)})
334
424
 
335
425
 
336
- def test_tile_operators(test, device):
337
- batch_count = 56
426
+ def test_tile_untile(test, device):
427
+ def test_func_preserve_type(type: Any):
428
+ x = wp.ones(TILE_DIM, dtype=type, requires_grad=True, device=device)
429
+ y = wp.zeros_like(x)
338
430
 
339
- M = TILE_M
340
- N = TILE_N
431
+ tape = wp.Tape()
432
+ with tape:
433
+ wp.launch(
434
+ test_tile_untile_preserve_type_kernel,
435
+ dim=TILE_DIM,
436
+ inputs=[x],
437
+ outputs=[y],
438
+ block_dim=TILE_DIM,
439
+ device=device,
440
+ )
341
441
 
342
- rng = np.random.default_rng(42)
343
- input = rng.random((batch_count, M, N), dtype=np.float32)
344
- output = input * 0.75
442
+ y.grad = wp.ones_like(y)
345
443
 
346
- input_wp = wp.array(input, requires_grad=True, device=device)
347
- output_wp = wp.zeros_like(input_wp, requires_grad=True, device=device)
444
+ tape.backward()
348
445
 
349
- with wp.Tape() as tape:
350
- wp.launch_tiled(
351
- tile_operators, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
352
- )
446
+ assert_np_equal(y.numpy(), x.numpy())
447
+ assert_np_equal(x.grad.numpy(), y.grad.numpy())
353
448
 
354
- assert_np_equal(output_wp.numpy(), output)
449
+ test_func_preserve_type(float)
450
+ test_func_preserve_type(wp.vec3)
451
+ test_func_preserve_type(wp.quat)
452
+ test_func_preserve_type(wp.mat33)
355
453
 
356
- output_wp.grad.fill_(1.0)
454
+ def test_func(type: Any):
455
+ x = wp.ones(TILE_DIM, dtype=type, requires_grad=True, device=device)
456
+ y = wp.zeros_like(x)
357
457
 
358
- tape.backward()
458
+ tape = wp.Tape()
459
+ with tape:
460
+ wp.launch(test_tile_untile_kernel, dim=TILE_DIM, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
359
461
 
360
- assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.75)
462
+ y.grad = wp.ones_like(y)
463
+
464
+ tape.backward()
465
+
466
+ assert_np_equal(y.numpy(), x.numpy())
467
+ assert_np_equal(x.grad.numpy(), y.grad.numpy())
468
+
469
+ test_func(float)
470
+ test_func(wp.vec3)
471
+ test_func(wp.mat33)
472
+
473
+
474
+ @wp.func
475
+ def tile_sum_func(a: wp.tile(dtype=float, shape=(TILE_M, TILE_N))):
476
+ return wp.tile_sum(a) * 0.5
361
477
 
362
478
 
363
479
  @wp.kernel
@@ -366,7 +482,7 @@ def tile_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float
366
482
  i = wp.tid()
367
483
 
368
484
  a = wp.tile_load(input[i], shape=(TILE_M, TILE_N))
369
- s = wp.tile_sum(a) * 0.5
485
+ s = tile_sum_func(a)
370
486
 
371
487
  wp.tile_store(output, s, offset=i)
372
488
 
@@ -448,7 +564,7 @@ def test_tile_sum_launch(test, device):
448
564
  assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.5)
449
565
 
450
566
 
451
- @wp.kernel
567
+ @wp.kernel(module="unique")
452
568
  def test_tile_extract_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
453
569
  i, j, x, y = wp.tid()
454
570
 
@@ -484,7 +600,7 @@ def test_tile_extract(test, device):
484
600
  assert_np_equal(a.grad.numpy(), expected_grad)
485
601
 
486
602
 
487
- @wp.kernel
603
+ @wp.kernel(module="unique")
488
604
  def test_tile_extract_repeated_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
489
605
  i, j, x, y = wp.tid()
490
606
 
@@ -548,7 +664,7 @@ def test_tile_assign(test, device):
548
664
 
549
665
  tape = wp.Tape()
550
666
  with tape:
551
- wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=64, device=device)
667
+ wp.launch(test_tile_assign_kernel, dim=[1, TILE_M], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
552
668
 
553
669
  y.grad = wp.ones_like(y)
554
670
  tape.backward()
@@ -570,31 +686,11 @@ def test_tile_transpose(test, device):
570
686
  input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
571
687
  output = wp.zeros_like(input.transpose(), device=device)
572
688
 
573
- wp.launch_tiled(test_tile_transpose_kernel, dim=[1], inputs=[input, output], block_dim=32, device=device)
689
+ wp.launch_tiled(test_tile_transpose_kernel, dim=[1], inputs=[input, output], block_dim=TILE_DIM, device=device)
574
690
 
575
691
  assert_np_equal(output.numpy(), input.numpy().T)
576
692
 
577
693
 
578
- def test_tile_transpose_matmul(test, device):
579
- @wp.kernel
580
- def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
581
- x = wp.tile_load(input, shape=(TILE_M, TILE_N))
582
- y = wp.tile_transpose(x)
583
-
584
- z = wp.tile_zeros(dtype=float, shape=(TILE_N, TILE_N))
585
- wp.tile_matmul(y, x, z)
586
-
587
- wp.tile_store(output, z)
588
-
589
- rng = np.random.default_rng(42)
590
- input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
591
- output = wp.zeros((TILE_N, TILE_N), dtype=float, device=device)
592
-
593
- wp.launch_tiled(test_tile_transpose_matmul_kernel, dim=[1], inputs=[input, output], block_dim=32, device=device)
594
-
595
- assert_np_equal(output.numpy(), input.numpy().T @ input.numpy())
596
-
597
-
598
694
  @wp.kernel
599
695
  def test_tile_broadcast_add_1d_kernel(
600
696
  input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
@@ -616,7 +712,7 @@ def test_tile_broadcast_add_1d(test, device):
616
712
  b = wp.array(np.ones(1, dtype=np.float32), device=device)
617
713
  out = wp.zeros((N,), dtype=float, device=device)
618
714
 
619
- wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
715
+ wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
620
716
 
621
717
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
622
718
 
@@ -643,7 +739,7 @@ def test_tile_broadcast_add_2d(test, device):
643
739
  b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
644
740
  out = wp.zeros((M, N), dtype=float, device=device)
645
741
 
646
- wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
742
+ wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
647
743
 
648
744
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
649
745
 
@@ -671,7 +767,7 @@ def test_tile_broadcast_add_3d(test, device):
671
767
  b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
672
768
  out = wp.zeros((M, N, O), dtype=float, device=device)
673
769
 
674
- wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
770
+ wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
675
771
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
676
772
 
677
773
 
@@ -698,7 +794,7 @@ def test_tile_broadcast_add_4d(test, device):
698
794
  b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
699
795
  out = wp.zeros((M, N, O, P), dtype=float, device=device)
700
796
 
701
- wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
797
+ wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=TILE_DIM, device=device)
702
798
 
703
799
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
704
800
 
@@ -719,7 +815,7 @@ def test_tile_broadcast_grad(test, device):
719
815
  b = wp.array(np.ones((5, 5), dtype=np.float32), requires_grad=True, device=device)
720
816
 
721
817
  with wp.Tape() as tape:
722
- wp.launch_tiled(test_tile_broadcast_grad_kernel, dim=[1], inputs=[a, b], block_dim=32, device=device)
818
+ wp.launch_tiled(test_tile_broadcast_grad_kernel, dim=[1], inputs=[a, b], block_dim=TILE_DIM, device=device)
723
819
 
724
820
  b.grad = wp.ones_like(b, device=device)
725
821
  tape.backward()
@@ -728,6 +824,116 @@ def test_tile_broadcast_grad(test, device):
728
824
  assert_np_equal(a.grad.numpy(), np.ones(5) * 5.0)
729
825
 
730
826
 
827
+ @wp.kernel
828
+ def test_tile_squeeze_kernel(x: wp.array3d(dtype=float), y: wp.array(dtype=float)):
829
+ a = wp.tile_load(x, shape=(1, TILE_M, 1), offset=(0, 0, 0))
830
+ b = wp.tile_squeeze(a, axis=(2,))
831
+ c = wp.tile_squeeze(b)
832
+
833
+ wp.tile_store(y, c, offset=(0,))
834
+
835
+
836
+ def test_tile_squeeze(test, device):
837
+ x = wp.ones((1, TILE_M, 1), dtype=float, device=device, requires_grad=True)
838
+ y = wp.zeros((TILE_M,), dtype=float, device=device, requires_grad=True)
839
+
840
+ tape = wp.Tape()
841
+ with tape:
842
+ wp.launch_tiled(test_tile_squeeze_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
843
+
844
+ y.grad = wp.ones_like(y)
845
+ tape.backward()
846
+
847
+ assert_np_equal(y.numpy(), np.ones((TILE_M,), dtype=np.float32))
848
+ assert_np_equal(x.grad.numpy(), np.ones((1, TILE_M, 1), dtype=np.float32))
849
+
850
+
851
+ @wp.kernel
852
+ def test_tile_reshape_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
853
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(0, 0))
854
+ b = wp.tile_reshape(a, shape=(wp.static(TILE_M * TILE_N), 1))
855
+ c = wp.tile_reshape(b, shape=(-1, 1))
856
+
857
+ wp.tile_store(y, c, offset=(0, 0))
858
+
859
+
860
+ def test_tile_reshape(test, device):
861
+ x = wp.ones((TILE_M, TILE_N), dtype=float, device=device, requires_grad=True)
862
+ y = wp.zeros((TILE_M * TILE_N, 1), dtype=float, device=device, requires_grad=True)
863
+
864
+ tape = wp.Tape()
865
+ with tape:
866
+ wp.launch_tiled(test_tile_reshape_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
867
+
868
+ y.grad = wp.ones_like(y)
869
+ tape.backward()
870
+
871
+ assert_np_equal(y.numpy(), np.ones((TILE_M * TILE_N, 1), dtype=np.float32))
872
+ assert_np_equal(x.grad.numpy(), np.ones((TILE_M, TILE_N), dtype=np.float32))
873
+
874
+
875
+ @wp.kernel
876
+ def test_tile_astype_kernel(x: wp.array2d(dtype=Any), y: wp.array2d(dtype=wp.float32)):
877
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N))
878
+ b = wp.tile_astype(a, dtype=wp.float32)
879
+ wp.tile_store(y, b)
880
+
881
+
882
+ def test_tile_astype(test, device):
883
+ x_np = np.arange(TILE_M * TILE_N, dtype=np.int32).reshape((TILE_M, TILE_N))
884
+ x = wp.array(x_np, dtype=wp.int32, device=device)
885
+ y = wp.zeros((TILE_M, TILE_N), dtype=wp.float32, device=device)
886
+
887
+ wp.launch_tiled(test_tile_astype_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
888
+
889
+ assert_np_equal(y.numpy(), x_np.astype(np.float32))
890
+
891
+ x_np = np.arange(TILE_M * TILE_N, dtype=np.float64).reshape((TILE_M, TILE_N))
892
+ x = wp.array(x_np, dtype=wp.float64, requires_grad=True, device=device)
893
+ y = wp.zeros((TILE_M, TILE_N), dtype=wp.float32, requires_grad=True, device=device)
894
+
895
+ tape = wp.Tape()
896
+ with tape:
897
+ wp.launch_tiled(test_tile_astype_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
898
+
899
+ y.grad = wp.ones_like(y)
900
+
901
+ tape.backward()
902
+
903
+ assert_np_equal(y.numpy(), x_np.astype(np.float32))
904
+ assert_np_equal(x.grad.numpy(), np.ones_like(x_np))
905
+
906
+
907
+ @wp.func
908
+ def test_tile_func_return_func(tile: Any):
909
+ return tile
910
+
911
+
912
+ @wp.kernel
913
+ def test_tile_func_return_kernel(x: wp.array2d(dtype=wp.float32), y: wp.array2d(dtype=wp.float32)):
914
+ a = wp.tile_load(x, shape=(TILE_M, 1))
915
+ b = wp.tile_broadcast(a, shape=(TILE_M, TILE_K))
916
+ c = test_tile_func_return_func(b)
917
+ wp.tile_store(y, c)
918
+
919
+
920
+ def test_tile_func_return(test, device):
921
+ x = wp.ones(shape=(TILE_M, 1), dtype=wp.float32, requires_grad=True, device=device)
922
+ y = wp.zeros(shape=(TILE_M, TILE_K), dtype=wp.float32, requires_grad=True, device=device)
923
+
924
+ tape = wp.Tape()
925
+ with tape:
926
+ wp.launch_tiled(
927
+ test_tile_func_return_kernel, dim=[1, 1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
928
+ )
929
+
930
+ y.grad = wp.ones_like(y)
931
+ tape.backward()
932
+
933
+ assert_np_equal(y.numpy(), np.ones((TILE_M, TILE_K), dtype=np.float32))
934
+ assert_np_equal(x.grad.numpy(), np.ones((TILE_M, 1), dtype=np.float32) * TILE_K)
935
+
936
+
731
937
  @wp.kernel
732
938
  def tile_len_kernel(
733
939
  a: wp.array(dtype=float, ndim=2),
@@ -743,14 +949,7 @@ def tile_len_kernel(
743
949
  def test_tile_len(test, device):
744
950
  a = wp.zeros((TILE_M, TILE_N), dtype=float, device=device)
745
951
  out = wp.empty(1, dtype=int, device=device)
746
- wp.launch_tiled(
747
- tile_len_kernel,
748
- dim=(1,),
749
- inputs=(a,),
750
- outputs=(out,),
751
- block_dim=32,
752
- device=device,
753
- )
952
+ wp.launch_tiled(tile_len_kernel, dim=(1,), inputs=(a,), outputs=(out,), block_dim=TILE_DIM, device=device)
754
953
 
755
954
  test.assertEqual(out.numpy()[0], TILE_M)
756
955
 
@@ -771,6 +970,111 @@ def test_tile_print(test, device):
771
970
  wp.synchronize()
772
971
 
773
972
 
973
+ @wp.kernel
974
+ def test_tile_add_inplace_kernel(
975
+ input_a: wp.array2d(dtype=float),
976
+ input_b: wp.array2d(dtype=float),
977
+ output_reg: wp.array2d(dtype=float),
978
+ output_shared: wp.array2d(dtype=float),
979
+ ):
980
+ i, j = wp.tid()
981
+
982
+ a_reg = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
983
+ b_reg = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
984
+ a_shared = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
985
+ b_shared = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
986
+
987
+ a_reg += b_reg
988
+ a_reg += b_shared
989
+ a_shared += b_reg
990
+ a_shared += b_shared
991
+
992
+ wp.tile_store(output_reg, a_reg, offset=(i * TILE_M, j * TILE_N))
993
+ wp.tile_store(output_shared, a_shared, offset=(i * TILE_M, j * TILE_N))
994
+
995
+
996
+ @wp.kernel
997
+ def test_tile_sub_inplace_kernel(
998
+ input_a: wp.array2d(dtype=float),
999
+ input_b: wp.array2d(dtype=float),
1000
+ output_reg: wp.array2d(dtype=float),
1001
+ output_shared: wp.array2d(dtype=float),
1002
+ ):
1003
+ i, j = wp.tid()
1004
+
1005
+ a_reg = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1006
+ b_reg = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1007
+ a_shared = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
1008
+ b_shared = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
1009
+
1010
+ a_reg -= b_reg
1011
+ a_reg -= b_shared
1012
+ a_shared -= b_reg
1013
+ a_shared -= b_shared
1014
+
1015
+ wp.tile_store(output_reg, a_reg, offset=(i * TILE_M, j * TILE_N))
1016
+ wp.tile_store(output_shared, a_shared, offset=(i * TILE_M, j * TILE_N))
1017
+
1018
+
1019
+ def test_tile_inplace(test, device):
1020
+ M = TILE_M * 2
1021
+ N = TILE_N * 2
1022
+
1023
+ a = wp.zeros((M, N), requires_grad=True, device=device)
1024
+ b = wp.ones_like(a, requires_grad=True, device=device)
1025
+ c = wp.zeros_like(a, requires_grad=True, device=device)
1026
+ d = wp.zeros_like(a, requires_grad=True, device=device)
1027
+
1028
+ with wp.Tape() as tape:
1029
+ wp.launch_tiled(
1030
+ test_tile_add_inplace_kernel,
1031
+ dim=[int(M / TILE_M), int(N / TILE_N)],
1032
+ inputs=[a, b, c, d],
1033
+ block_dim=TILE_DIM,
1034
+ device=device,
1035
+ )
1036
+
1037
+ assert_np_equal(a.numpy(), np.zeros((M, N)))
1038
+ assert_np_equal(b.numpy(), np.ones((M, N)))
1039
+ assert_np_equal(c.numpy(), 2.0 * np.ones((M, N)))
1040
+ assert_np_equal(d.numpy(), 2.0 * np.ones((M, N)))
1041
+
1042
+ c.grad = wp.ones_like(c, device=device)
1043
+ d.grad = wp.ones_like(d, device=device)
1044
+ tape.backward()
1045
+
1046
+ assert_np_equal(a.grad.numpy(), 2.0 * np.ones((M, N)))
1047
+ assert_np_equal(b.grad.numpy(), 4.0 * np.ones((M, N)))
1048
+
1049
+ tape.zero()
1050
+
1051
+ a.zero_()
1052
+ b.fill_(1.0)
1053
+ c.zero_()
1054
+ d.zero_()
1055
+
1056
+ with wp.Tape() as tape:
1057
+ wp.launch_tiled(
1058
+ test_tile_sub_inplace_kernel,
1059
+ dim=[int(M / TILE_M), int(N / TILE_N)],
1060
+ inputs=[a, b, c, d],
1061
+ block_dim=TILE_DIM,
1062
+ device=device,
1063
+ )
1064
+
1065
+ assert_np_equal(a.numpy(), np.zeros((M, N)))
1066
+ assert_np_equal(b.numpy(), np.ones((M, N)))
1067
+ assert_np_equal(c.numpy(), -2.0 * np.ones((M, N)))
1068
+ assert_np_equal(d.numpy(), -2.0 * np.ones((M, N)))
1069
+
1070
+ c.grad = wp.ones_like(c, device=device)
1071
+ d.grad = wp.ones_like(d, device=device)
1072
+ tape.backward()
1073
+
1074
+ assert_np_equal(a.grad.numpy(), 2.0 * np.ones((M, N)))
1075
+ assert_np_equal(b.grad.numpy(), -4.0 * np.ones((M, N)))
1076
+
1077
+
774
1078
  devices = get_test_devices()
775
1079
 
776
1080
 
@@ -782,13 +1086,10 @@ add_function_test(TestTile, "test_tile_copy_1d", test_tile_copy_1d, devices=devi
782
1086
  add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devices)
783
1087
  add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
784
1088
  add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
785
- add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
786
- add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
787
- add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
788
- add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
789
1089
  add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
790
- add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
791
1090
  add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
1091
+ add_function_test(TestTile, "test_tile_tile", test_tile_tile, devices=get_cuda_test_devices())
1092
+ add_function_test(TestTile, "test_tile_untile", test_tile_untile, devices=devices)
792
1093
  add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, check_output=False)
793
1094
  add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
794
1095
  add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
@@ -799,8 +1100,14 @@ add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_ad
799
1100
  add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
800
1101
  add_function_test(TestTile, "test_tile_broadcast_add_4d", test_tile_broadcast_add_4d, devices=devices)
801
1102
  add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
1103
+ add_function_test(TestTile, "test_tile_squeeze", test_tile_squeeze, devices=devices)
1104
+ add_function_test(TestTile, "test_tile_reshape", test_tile_reshape, devices=devices)
802
1105
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
803
- add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
1106
+ # add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
1107
+ # add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
1108
+ # add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)
1109
+ # add_function_test(TestTile, "test_tile_func_return", test_tile_func_return, devices=devices)
1110
+
804
1111
 
805
1112
  if __name__ == "__main__":
806
1113
  wp.clear_kernel_cache()