warp-lang 1.7.2__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp-clang.dylib +0 -0
  5. warp/bin/libwarp.dylib +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
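warp/tests/tile/test_tile.py +420 -1 (entry 166 in the file list; the hunks below add 420 lines and remove 1)
@@ -14,6 +14,7 @@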
14
14
  # limitations under the License.
15
15
 
16
16
  import unittest
17
+ from typing import Any
17
18
 
18
19
  import numpy as np
19
20
 
@@ -360,13 +361,208 @@ def test_tile_operators(test, device):
360
361
  assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.75)
361
362
 
362
363
 
364
+ @wp.kernel
365
+ def test_tile_tile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
366
+ a = x[0]
367
+ t = wp.tile(a, preserve_type=True)
368
+ wp.tile_store(y, t)
369
+
370
+
371
+ @wp.kernel
372
+ def test_tile_tile_scalar_expansion_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
373
+ a = x[0]
374
+ t = wp.tile(a)
375
+ wp.tile_store(y, t)
376
+
377
+
378
+ @wp.kernel
379
+ def test_tile_tile_vec_expansion_kernel(x: wp.array(dtype=wp.vec3), y: wp.array2d(dtype=float)):
380
+ a = x[0]
381
+ t = wp.tile(a)
382
+ wp.tile_store(y, t)
383
+
384
+
385
+ @wp.kernel
386
+ def test_tile_tile_mat_expansion_kernel(x: wp.array(dtype=wp.mat33), y: wp.array3d(dtype=float)):
387
+ a = x[0]
388
+ t = wp.tile(a)
389
+ wp.tile_store(y, t)
390
+
391
+
392
+ def test_tile_tile(test, device):
393
+ # preserve type
394
+ def test_func_preserve_type(type: Any):
395
+ x = wp.ones(1, dtype=type, requires_grad=True, device=device)
396
+ y = wp.zeros((TILE_DIM), dtype=type, requires_grad=True, device=device)
397
+
398
+ tape = wp.Tape()
399
+ with tape:
400
+ wp.launch(
401
+ test_tile_tile_preserve_type_kernel,
402
+ dim=[TILE_DIM],
403
+ inputs=[x],
404
+ outputs=[y],
405
+ block_dim=TILE_DIM,
406
+ device=device,
407
+ )
408
+
409
+ y.grad = wp.ones_like(y)
410
+
411
+ tape.backward()
412
+
413
+ assert_np_equal(y.numpy(), wp.full((TILE_DIM), type(1.0), dtype=type, device="cpu").numpy())
414
+ assert_np_equal(x.grad.numpy(), wp.full((1,), type(TILE_DIM), dtype=type, device="cpu").numpy())
415
+
416
+ test_func_preserve_type(float)
417
+ test_func_preserve_type(wp.vec3)
418
+ test_func_preserve_type(wp.quat)
419
+ test_func_preserve_type(wp.mat33)
420
+
421
+ # scalar expansion
422
+ x = wp.ones(1, dtype=float, requires_grad=True, device=device)
423
+ y = wp.zeros((TILE_DIM), dtype=float, requires_grad=True, device=device)
424
+
425
+ tape = wp.Tape()
426
+ with tape:
427
+ wp.launch(
428
+ test_tile_tile_scalar_expansion_kernel,
429
+ dim=[TILE_DIM],
430
+ inputs=[x],
431
+ outputs=[y],
432
+ block_dim=TILE_DIM,
433
+ device=device,
434
+ )
435
+
436
+ y.grad = wp.ones_like(y)
437
+
438
+ tape.backward()
439
+
440
+ assert_np_equal(y.numpy(), wp.full((TILE_DIM), 1.0, dtype=float, device="cpu").numpy())
441
+ assert_np_equal(x.grad.numpy(), wp.full((1,), wp.float32(TILE_DIM), dtype=float, device="cpu").numpy())
442
+
443
+ # vec expansion
444
+ x = wp.ones(1, dtype=wp.vec3, requires_grad=True, device=device)
445
+ y = wp.zeros((3, TILE_DIM), dtype=float, requires_grad=True, device=device)
446
+
447
+ tape = wp.Tape()
448
+ with tape:
449
+ wp.launch(
450
+ test_tile_tile_vec_expansion_kernel,
451
+ dim=[TILE_DIM],
452
+ inputs=[x],
453
+ outputs=[y],
454
+ block_dim=TILE_DIM,
455
+ device=device,
456
+ )
457
+
458
+ y.grad = wp.ones_like(y)
459
+
460
+ tape.backward()
461
+
462
+ assert_np_equal(y.numpy(), wp.full((3, TILE_DIM), 1.0, dtype=float, device="cpu").numpy())
463
+ assert_np_equal(x.grad.numpy(), wp.full((1,), wp.float32(TILE_DIM), dtype=wp.vec3, device="cpu").numpy())
464
+
465
+ # mat expansion
466
+ x = wp.ones(1, dtype=wp.mat33, requires_grad=True, device=device)
467
+ y = wp.zeros((3, 3, TILE_DIM), dtype=float, requires_grad=True, device=device)
468
+
469
+ tape = wp.Tape()
470
+ with tape:
471
+ wp.launch(
472
+ test_tile_tile_mat_expansion_kernel,
473
+ dim=[TILE_DIM],
474
+ inputs=[x],
475
+ outputs=[y],
476
+ block_dim=TILE_DIM,
477
+ device=device,
478
+ )
479
+
480
+ y.grad = wp.ones_like(y)
481
+
482
+ tape.backward()
483
+
484
+ assert_np_equal(y.numpy(), wp.full((3, 3, TILE_DIM), 1.0, dtype=float, device="cpu").numpy())
485
+ assert_np_equal(x.grad.numpy(), wp.full((1,), wp.float32(TILE_DIM), dtype=wp.mat33, device="cpu").numpy())
486
+
487
+
488
+ @wp.kernel
489
+ def test_tile_untile_preserve_type_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
490
+ i = wp.tid()
491
+ a = x[i]
492
+ t = wp.tile(a, preserve_type=True)
493
+ b = wp.untile(t)
494
+ y[i] = b
495
+
496
+
497
+ @wp.kernel
498
+ def test_tile_untile_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
499
+ i = wp.tid()
500
+ a = x[i]
501
+ t = wp.tile(a)
502
+ b = wp.untile(t)
503
+ y[i] = b
504
+
505
+
506
+ def test_tile_untile(test, device):
507
+ def test_func_preserve_type(type: Any):
508
+ x = wp.ones(TILE_DIM, dtype=type, requires_grad=True, device=device)
509
+ y = wp.zeros_like(x)
510
+
511
+ tape = wp.Tape()
512
+ with tape:
513
+ wp.launch(
514
+ test_tile_untile_preserve_type_kernel,
515
+ dim=TILE_DIM,
516
+ inputs=[x],
517
+ outputs=[y],
518
+ block_dim=TILE_DIM,
519
+ device=device,
520
+ )
521
+
522
+ y.grad = wp.ones_like(y)
523
+
524
+ tape.backward()
525
+
526
+ assert_np_equal(y.numpy(), x.numpy())
527
+ assert_np_equal(x.grad.numpy(), y.grad.numpy())
528
+
529
+ test_func_preserve_type(float)
530
+ test_func_preserve_type(wp.vec3)
531
+ test_func_preserve_type(wp.quat)
532
+ test_func_preserve_type(wp.mat33)
533
+
534
+ def test_func(type: Any):
535
+ x = wp.ones(TILE_DIM, dtype=type, requires_grad=True, device=device)
536
+ y = wp.zeros_like(x)
537
+
538
+ tape = wp.Tape()
539
+ with tape:
540
+ wp.launch(test_tile_untile_kernel, dim=TILE_DIM, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
541
+
542
+ y.grad = wp.ones_like(y)
543
+
544
+ tape.backward()
545
+
546
+ assert_np_equal(y.numpy(), x.numpy())
547
+ assert_np_equal(x.grad.numpy(), y.grad.numpy())
548
+
549
+ test_func(float)
550
+ test_func(wp.vec3)
551
+ test_func(wp.mat33)
552
+
553
+
554
+ @wp.func
555
+ def tile_sum_func(a: wp.tile(dtype=float, shape=(TILE_M, TILE_N))):
556
+ return wp.tile_sum(a) * 0.5
557
+
558
+
363
559
  @wp.kernel
364
560
  def tile_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
365
561
  # output tile index
366
562
  i = wp.tid()
367
563
 
368
564
  a = wp.tile_load(input[i], shape=(TILE_M, TILE_N))
369
- s = wp.tile_sum(a) * 0.5
565
+ s = tile_sum_func(a)
370
566
 
371
567
  wp.tile_store(output, s, offset=i)
372
568
 
@@ -728,6 +924,116 @@ def test_tile_broadcast_grad(test, device):
728
924
  assert_np_equal(a.grad.numpy(), np.ones(5) * 5.0)
729
925
 
730
926
 
927
+ @wp.kernel
928
+ def test_tile_squeeze_kernel(x: wp.array3d(dtype=float), y: wp.array(dtype=float)):
929
+ a = wp.tile_load(x, shape=(1, TILE_M, 1), offset=(0, 0, 0))
930
+ b = wp.tile_squeeze(a, axis=(2,))
931
+ c = wp.tile_squeeze(b)
932
+
933
+ wp.tile_store(y, c, offset=(0,))
934
+
935
+
936
+ def test_tile_squeeze(test, device):
937
+ x = wp.ones((1, TILE_M, 1), dtype=float, device=device, requires_grad=True)
938
+ y = wp.zeros((TILE_M,), dtype=float, device=device, requires_grad=True)
939
+
940
+ tape = wp.Tape()
941
+ with tape:
942
+ wp.launch_tiled(test_tile_squeeze_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
943
+
944
+ y.grad = wp.ones_like(y)
945
+ tape.backward()
946
+
947
+ assert_np_equal(y.numpy(), np.ones((TILE_M,), dtype=np.float32))
948
+ assert_np_equal(x.grad.numpy(), np.ones((1, TILE_M, 1), dtype=np.float32))
949
+
950
+
951
+ @wp.kernel
952
+ def test_tile_reshape_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
953
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(0, 0))
954
+ b = wp.tile_reshape(a, shape=(wp.static(TILE_M * TILE_N), 1))
955
+ c = wp.tile_reshape(b, shape=(-1, 1))
956
+
957
+ wp.tile_store(y, c, offset=(0, 0))
958
+
959
+
960
+ def test_tile_reshape(test, device):
961
+ x = wp.ones((TILE_M, TILE_N), dtype=float, device=device, requires_grad=True)
962
+ y = wp.zeros((TILE_M * TILE_N, 1), dtype=float, device=device, requires_grad=True)
963
+
964
+ tape = wp.Tape()
965
+ with tape:
966
+ wp.launch_tiled(test_tile_reshape_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
967
+
968
+ y.grad = wp.ones_like(y)
969
+ tape.backward()
970
+
971
+ assert_np_equal(y.numpy(), np.ones((TILE_M * TILE_N, 1), dtype=np.float32))
972
+ assert_np_equal(x.grad.numpy(), np.ones((TILE_M, TILE_N), dtype=np.float32))
973
+
974
+
975
+ @wp.kernel
976
+ def test_tile_astype_kernel(x: wp.array2d(dtype=Any), y: wp.array2d(dtype=wp.float32)):
977
+ a = wp.tile_load(x, shape=(TILE_M, TILE_N))
978
+ b = wp.tile_astype(a, dtype=wp.float32)
979
+ wp.tile_store(y, b)
980
+
981
+
982
+ def test_tile_astype(test, device):
983
+ x_np = np.arange(TILE_M * TILE_N, dtype=np.int32).reshape((TILE_M, TILE_N))
984
+ x = wp.array(x_np, dtype=wp.int32, device=device)
985
+ y = wp.zeros((TILE_M, TILE_N), dtype=wp.float32, device=device)
986
+
987
+ wp.launch_tiled(test_tile_astype_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
988
+
989
+ assert_np_equal(y.numpy(), x_np.astype(np.float32))
990
+
991
+ x_np = np.arange(TILE_M * TILE_N, dtype=np.float64).reshape((TILE_M, TILE_N))
992
+ x = wp.array(x_np, dtype=wp.float64, requires_grad=True, device=device)
993
+ y = wp.zeros((TILE_M, TILE_N), dtype=wp.float32, requires_grad=True, device=device)
994
+
995
+ tape = wp.Tape()
996
+ with tape:
997
+ wp.launch_tiled(test_tile_astype_kernel, dim=1, inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
998
+
999
+ y.grad = wp.ones_like(y)
1000
+
1001
+ tape.backward()
1002
+
1003
+ assert_np_equal(y.numpy(), x_np.astype(np.float32))
1004
+ assert_np_equal(x.grad.numpy(), np.ones_like(x_np))
1005
+
1006
+
1007
+ @wp.func
1008
+ def test_tile_func_return_func(tile: Any):
1009
+ return tile
1010
+
1011
+
1012
+ @wp.kernel
1013
+ def test_tile_func_return_kernel(x: wp.array2d(dtype=wp.float32), y: wp.array2d(dtype=wp.float32)):
1014
+ a = wp.tile_load(x, shape=(TILE_M, 1))
1015
+ b = wp.tile_broadcast(a, shape=(TILE_M, TILE_K))
1016
+ c = test_tile_func_return_func(b)
1017
+ wp.tile_store(y, c)
1018
+
1019
+
1020
+ def test_tile_func_return(test, device):
1021
+ x = wp.ones(shape=(TILE_M, 1), dtype=wp.float32, requires_grad=True, device=device)
1022
+ y = wp.zeros(shape=(TILE_M, TILE_K), dtype=wp.float32, requires_grad=True, device=device)
1023
+
1024
+ tape = wp.Tape()
1025
+ with tape:
1026
+ wp.launch_tiled(
1027
+ test_tile_func_return_kernel, dim=[1, 1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
1028
+ )
1029
+
1030
+ y.grad = wp.ones_like(y)
1031
+ tape.backward()
1032
+
1033
+ assert_np_equal(y.numpy(), np.ones((TILE_M, TILE_K), dtype=np.float32))
1034
+ assert_np_equal(x.grad.numpy(), np.ones((TILE_M, 1), dtype=np.float32) * TILE_K)
1035
+
1036
+
731
1037
  @wp.kernel
732
1038
  def tile_len_kernel(
733
1039
  a: wp.array(dtype=float, ndim=2),
@@ -771,6 +1077,111 @@ def test_tile_print(test, device):
771
1077
  wp.synchronize()
772
1078
 
773
1079
 
1080
+ @wp.kernel
1081
+ def test_tile_add_inplace_kernel(
1082
+ input_a: wp.array2d(dtype=float),
1083
+ input_b: wp.array2d(dtype=float),
1084
+ output_reg: wp.array2d(dtype=float),
1085
+ output_shared: wp.array2d(dtype=float),
1086
+ ):
1087
+ i, j = wp.tid()
1088
+
1089
+ a_reg = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1090
+ b_reg = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1091
+ a_shared = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
1092
+ b_shared = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
1093
+
1094
+ a_reg += b_reg
1095
+ a_reg += b_shared
1096
+ a_shared += b_reg
1097
+ a_shared += b_shared
1098
+
1099
+ wp.tile_store(output_reg, a_reg, offset=(i * TILE_M, j * TILE_N))
1100
+ wp.tile_store(output_shared, a_shared, offset=(i * TILE_M, j * TILE_N))
1101
+
1102
+
1103
+ @wp.kernel
1104
+ def test_tile_sub_inplace_kernel(
1105
+ input_a: wp.array2d(dtype=float),
1106
+ input_b: wp.array2d(dtype=float),
1107
+ output_reg: wp.array2d(dtype=float),
1108
+ output_shared: wp.array2d(dtype=float),
1109
+ ):
1110
+ i, j = wp.tid()
1111
+
1112
+ a_reg = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1113
+ b_reg = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1114
+ a_shared = wp.tile_load(input_a, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
1115
+ b_shared = wp.tile_load(input_b, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="shared")
1116
+
1117
+ a_reg -= b_reg
1118
+ a_reg -= b_shared
1119
+ a_shared -= b_reg
1120
+ a_shared -= b_shared
1121
+
1122
+ wp.tile_store(output_reg, a_reg, offset=(i * TILE_M, j * TILE_N))
1123
+ wp.tile_store(output_shared, a_shared, offset=(i * TILE_M, j * TILE_N))
1124
+
1125
+
1126
+ def test_tile_inplace(test, device):
1127
+ M = TILE_M * 2
1128
+ N = TILE_N * 2
1129
+
1130
+ a = wp.zeros((M, N), requires_grad=True, device=device)
1131
+ b = wp.ones_like(a, requires_grad=True, device=device)
1132
+ c = wp.zeros_like(a, requires_grad=True, device=device)
1133
+ d = wp.zeros_like(a, requires_grad=True, device=device)
1134
+
1135
+ with wp.Tape() as tape:
1136
+ wp.launch_tiled(
1137
+ test_tile_add_inplace_kernel,
1138
+ dim=[int(M / TILE_M), int(N / TILE_N)],
1139
+ inputs=[a, b, c, d],
1140
+ block_dim=TILE_DIM,
1141
+ device=device,
1142
+ )
1143
+
1144
+ assert_np_equal(a.numpy(), np.zeros((M, N)))
1145
+ assert_np_equal(b.numpy(), np.ones((M, N)))
1146
+ assert_np_equal(c.numpy(), 2.0 * np.ones((M, N)))
1147
+ assert_np_equal(d.numpy(), 2.0 * np.ones((M, N)))
1148
+
1149
+ c.grad = wp.ones_like(c, device=device)
1150
+ d.grad = wp.ones_like(d, device=device)
1151
+ tape.backward()
1152
+
1153
+ assert_np_equal(a.grad.numpy(), 2.0 * np.ones((M, N)))
1154
+ assert_np_equal(b.grad.numpy(), 4.0 * np.ones((M, N)))
1155
+
1156
+ tape.zero()
1157
+
1158
+ a.zero_()
1159
+ b.fill_(1.0)
1160
+ c.zero_()
1161
+ d.zero_()
1162
+
1163
+ with wp.Tape() as tape:
1164
+ wp.launch_tiled(
1165
+ test_tile_sub_inplace_kernel,
1166
+ dim=[int(M / TILE_M), int(N / TILE_N)],
1167
+ inputs=[a, b, c, d],
1168
+ block_dim=TILE_DIM,
1169
+ device=device,
1170
+ )
1171
+
1172
+ assert_np_equal(a.numpy(), np.zeros((M, N)))
1173
+ assert_np_equal(b.numpy(), np.ones((M, N)))
1174
+ assert_np_equal(c.numpy(), -2.0 * np.ones((M, N)))
1175
+ assert_np_equal(d.numpy(), -2.0 * np.ones((M, N)))
1176
+
1177
+ c.grad = wp.ones_like(c, device=device)
1178
+ d.grad = wp.ones_like(d, device=device)
1179
+ tape.backward()
1180
+
1181
+ assert_np_equal(a.grad.numpy(), 2.0 * np.ones((M, N)))
1182
+ assert_np_equal(b.grad.numpy(), -4.0 * np.ones((M, N)))
1183
+
1184
+
774
1185
  devices = get_test_devices()
775
1186
 
776
1187
 
@@ -789,6 +1200,8 @@ add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), d
789
1200
  add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
790
1201
  add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
791
1202
  add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
1203
+ add_function_test(TestTile, "test_tile_tile", test_tile_tile, devices=get_cuda_test_devices())
1204
+ add_function_test(TestTile, "test_tile_untile", test_tile_untile, devices=devices)
792
1205
  add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, check_output=False)
793
1206
  add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
794
1207
  add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
@@ -799,8 +1212,14 @@ add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_ad
799
1212
  add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
800
1213
  add_function_test(TestTile, "test_tile_broadcast_add_4d", test_tile_broadcast_add_4d, devices=devices)
801
1214
  add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
1215
+ add_function_test(TestTile, "test_tile_squeeze", test_tile_squeeze, devices=devices)
1216
+ add_function_test(TestTile, "test_tile_reshape", test_tile_reshape, devices=devices)
802
1217
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
803
1218
  add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
1219
+ add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
1220
+ add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)
1221
+ add_function_test(TestTile, "test_tile_func_return", test_tile_func_return, devices=devices)
1222
+
804
1223
 
805
1224
  if __name__ == "__main__":
806
1225
  wp.clear_kernel_cache()