warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang may be problematic; see the package's registry page for details.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,179 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
# Tile extents shared by the GEMM kernels in this file; wrapped in
# wp.constant so they can be referenced from inside kernels.
TILE_M = wp.constant(8)  # rows of an A / C tile
TILE_N = wp.constant(4)  # columns of a B / C tile
TILE_K = wp.constant(8)  # reduction extent of an A / B tile

# Threads cooperating on one tile (passed as block_dim to the launches).
TILE_DIM = 64
30
+
31
+
32
# Batched small GEMM: each tiled block computes C[batch] = A[batch] @ B[batch]
# for the batch entry selected by its tile index.
@wp.kernel
def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
    batch = wp.tid()

    a_tile = wp.tile_load(A[batch], shape=(TILE_M, TILE_K))
    b_tile = wp.tile_load(B[batch], shape=(TILE_K, TILE_N))

    # tile_matmul accumulates into its third argument, so start from zeros
    acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
    wp.tile_matmul(a_tile, b_tile, acc)

    wp.tile_store(C[batch], acc)
45
+
46
+
47
def test_tile_grouped_gemm(test, device):
    """Compare ``tile_grouped_gemm`` with a batched NumPy matmul over 56 independent products."""
    batch_count = 56
    rng = np.random.default_rng(42)

    lhs = rng.random((batch_count, TILE_M, TILE_K), dtype=np.float32)
    rhs = rng.random((batch_count, TILE_K, TILE_N), dtype=np.float32)
    expected = lhs @ rhs

    lhs_wp = wp.array(lhs, requires_grad=True, device=device)
    rhs_wp = wp.array(rhs, requires_grad=True, device=device)
    out_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)

    # the launch is recorded on a tape, although no backward pass is run here
    with wp.Tape():
        wp.launch_tiled(
            tile_grouped_gemm,
            dim=[batch_count],
            inputs=[lhs_wp, rhs_wp, out_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # TODO: 32 mismatched elements
    assert_np_equal(out_wp.numpy(), expected, 1e-6)
70
+
71
+
72
# Generic tiled GEMM computing C = A @ B, launched on a 2D grid of tiles:
# (i, j) selects the TILE_M x TILE_N output tile. The loop walks
# K // TILE_K slabs of the shared dimension; any remainder of K that is not a
# multiple of TILE_K is not visited.
@wp.kernel
def tile_gemm(A: wp.array2d(dtype=Any), B: wp.array2d(dtype=Any), C: wp.array2d(dtype=Any)):
    i, j = wp.tid()

    acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=A.dtype)

    # number of TILE_K-wide slabs along the shared (K) dimension
    K = A.shape[1]
    count = K // TILE_K

    for k in range(0, count):
        a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
        b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))

        # acc += a @ b
        wp.tile_matmul(a, b, acc)

    wp.tile_store(C, acc, offset=(i * TILE_M, j * TILE_N))
93
+
94
+
95
# Concrete instantiations of the generic ``tile_gemm`` kernel, one per
# floating-point precision exercised by the tests below (same order as before).
for _gemm_dtype in (wp.float16, wp.float32, wp.float64):
    wp.overload(
        tile_gemm,
        {
            "A": wp.array2d(dtype=_gemm_dtype),
            "B": wp.array2d(dtype=_gemm_dtype),
            "C": wp.array2d(dtype=_gemm_dtype),
        },
    )
del _gemm_dtype
104
+
105
+
106
def test_tile_gemm(dtype):
    """Return a per-device test for the ``tile_gemm`` overload matching ``dtype``.

    The returned ``test(test, device)`` multiplies random matrices spanning
    several tiles in every dimension and checks the product against NumPy.
    It then backpropagates an all-ones output gradient and checks both input
    gradients.
    """

    def test(test, device):
        # 7 x 6 x 5 tiles along M, K and N respectively
        rows, inner, cols = TILE_M * 7, TILE_K * 6, TILE_N * 5
        np_dtype = wp.dtype_to_numpy(dtype)

        rng = np.random.default_rng(42)
        lhs = rng.random((rows, inner), dtype=float).astype(np_dtype)
        rhs = rng.random((inner, cols), dtype=float).astype(np_dtype)
        out = np.zeros((rows, cols), dtype=float).astype(np_dtype)

        lhs_wp = wp.array(lhs, requires_grad=True, device=device)
        rhs_wp = wp.array(rhs, requires_grad=True, device=device)
        out_wp = wp.array(out, requires_grad=True, device=device)

        tile_grid = (rows // TILE_M, cols // TILE_N)
        with wp.Tape() as tape:
            wp.launch_tiled(
                tile_gemm,
                dim=tile_grid,
                inputs=[lhs_wp, rhs_wp, out_wp],
                block_dim=TILE_DIM,
                device=device,
            )

        assert_np_equal(out_wp.numpy(), lhs @ rhs, tol=1.0e-1)

        # seed the reverse pass with dL/dC = 1 everywhere
        adj_out = np.ones_like(out)
        tape.backward(grads={out_wp: wp.array(adj_out, device=device)})

        # dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC
        assert_np_equal(lhs_wp.grad.numpy(), adj_out @ rhs.T, tol=1.0e-1)
        assert_np_equal(rhs_wp.grad.numpy(), lhs.T @ adj_out, tol=1.0e-1)

    return test
140
+
141
+
142
# Computes output = input^T @ input for a single TILE_M x TILE_N tile, feeding
# the result of tile_transpose directly into tile_matmul.
@wp.kernel
def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    x = wp.tile_load(input, shape=(TILE_M, TILE_N))
    x_t = wp.tile_transpose(x)

    gram = wp.tile_zeros(dtype=float, shape=(TILE_N, TILE_N))
    wp.tile_matmul(x_t, x, gram)

    wp.tile_store(output, gram)
151
+
152
+
153
def test_tile_transpose_matmul(test, device):
    """Check ``tile_transpose`` + ``tile_matmul`` against NumPy's ``x.T @ x``.

    A small absolute tolerance is used rather than exact equality. The device
    kernel and NumPy may accumulate the float32 dot products in a different
    order, so bit-identical results are not guaranteed.
    """
    rng = np.random.default_rng(42)
    x_np = rng.random((TILE_M, TILE_N), dtype=np.float32)

    x_wp = wp.array(x_np, device=device)
    out_wp = wp.zeros((TILE_N, TILE_N), dtype=float, device=device)

    wp.launch_tiled(
        test_tile_transpose_matmul_kernel, dim=[1], inputs=[x_wp, out_wp], block_dim=TILE_DIM, device=device
    )

    assert_np_equal(out_wp.numpy(), x_np.T @ x_np, tol=1.0e-5)
163
+
164
+
165
class TestTileMatmul(unittest.TestCase):
    """Holder for the tile matmul tests; cases are attached per device with ``add_function_test``."""
167
+
168
+
169
devices = get_test_devices()

# (test name, test function) pairs attached to TestTileMatmul for every device
_TILE_MATMUL_TESTS = (
    ("test_tile_gemm_fp16", test_tile_gemm(wp.float16)),
    ("test_tile_gemm_fp32", test_tile_gemm(wp.float32)),
    ("test_tile_gemm_fp64", test_tile_gemm(wp.float64)),
    ("test_tile_grouped_gemm", test_tile_grouped_gemm),
    ("test_tile_transpose_matmul", test_tile_transpose_matmul),
)
for _name, _func in _TILE_MATMUL_TESTS:
    add_function_test(TestTileMatmul, _name, _func, devices=devices)
del _name, _func

if __name__ == "__main__":
    # clear warp's kernel cache before running the suite directly
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)
@@ -73,6 +73,46 @@ def test_tile_reduce_sum(test, device):
73
73
  assert_np_equal(input_wp.grad.numpy(), np.ones_like(input) * 0.5, tol=1.0e-4)
74
74
 
75
75
 
76
# Writes half the sum of row `row` of `input` into output[row].
@wp.kernel
def tile_sum_to_shared_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
    row, _lane = wp.tid()

    vals = wp.tile_load(input[row], shape=TILE_DIM)
    total = wp.tile_sum(vals)

    # Reading an element of the reduced tile forces shared storage for it,
    # which is the code path this kernel is meant to exercise.
    _first = total[0]

    wp.tile_store(output, total * 0.5, offset=row)
84
+
85
+
86
def test_tile_sum_to_shared(test, device):
    """Forward and backward check for ``tile_sum_to_shared_kernel`` (half of each row sum)."""
    batch_count = 1

    rng = np.random.default_rng(42)
    data = rng.random((batch_count, TILE_DIM), dtype=np.float32)

    data_wp = wp.array(data, requires_grad=True, device=device, dtype=float)
    result_wp = wp.zeros(batch_count, requires_grad=True, device=device, dtype=float)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_sum_to_shared_kernel,
            dim=[batch_count],
            inputs=[data_wp, result_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    result = result_wp.numpy()
    for row, value in zip(data, result):
        assert_np_equal(value, np.sum(row, axis=0) * 0.5, tol=0.0001)

    # d(output)/d(input) is 0.5 for every input element
    result_wp.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(data_wp.grad.numpy(), np.ones_like(data) * 0.5, tol=1.0e-4)
114
+
115
+
76
116
  @wp.kernel
77
117
  def tile_min_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
78
118
  # output tile index
@@ -84,6 +124,13 @@ def tile_min_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float
84
124
  wp.tile_store(output, m, offset=i)
85
125
 
86
126
 
127
+ @wp.kernel
128
+ def tile_min_kernel_edge_case(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
129
+ t = wp.tile_load(x, shape=(3, 3))
130
+ min = wp.tile_min(t)
131
+ wp.tile_store(y, min)
132
+
133
+
87
134
  def test_tile_reduce_min(test, device):
88
135
  batch_count = 56
89
136
 
@@ -105,6 +152,62 @@ def test_tile_reduce_min(test, device):
105
152
  min_np = np.min(input[i])
106
153
  test.assertAlmostEqual(min_wp[i], min_np, places=4)
107
154
 
155
+ # test edge case: tile is multiple warps in size but at least one is empty
156
+ x = wp.array(np.array([[2.0, 2.0, 3.0], [4.0, 1.0, 6.0], [7.0, 3.0, 9.0]]), dtype=float, device=device)
157
+ y = wp.zeros(1, dtype=float, device=device)
158
+
159
+ wp.launch_tiled(tile_min_kernel_edge_case, dim=1, inputs=[x, y], block_dim=64, device=device)
160
+
161
+ assert_np_equal(y.numpy(), np.array([1.0]))
162
+
163
+
164
# One tile per row of `input`: stores the index of the row's minimum into output[row].
@wp.kernel
def tile_argmin_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=int)):
    row = wp.tid()

    values = wp.tile_load(input[row], shape=TILE_DIM)
    min_index = wp.tile_argmin(values)

    wp.tile_store(output, min_index, offset=row)
173
+
174
+
175
# Stores the flat index of the minimum of a 3x3 tile of `x` into y. Used for
# the case where the launch block is larger than the tile.
@wp.kernel
def tile_argmin_kernel_edge_case(x: wp.array2d(dtype=float), y: wp.array(dtype=int)):
    vals = wp.tile_load(x, shape=(3, 3))
    min_index = wp.tile_argmin(vals)
    wp.tile_store(y, min_index)
180
+
181
+
182
def test_tile_reduce_argmin(test, device):
    """Check ``tile_argmin`` against ``np.argmin`` row by row, plus a small-tile edge case.

    Argmin results are integer indices, so they are compared exactly instead of
    with ``assertAlmostEqual``.
    """
    batch_count = 56
    N = TILE_DIM

    rng = np.random.default_rng(42)
    data = rng.random((batch_count, N), dtype=np.float32)

    data_wp = wp.array(data, requires_grad=True, device=device)
    output_wp = wp.zeros(batch_count, dtype=wp.int32, requires_grad=True, device=device)

    # no backward pass is run, so the launch does not need to be recorded on a tape
    wp.launch_tiled(
        tile_argmin_kernel, dim=[batch_count], inputs=[data_wp, output_wp], block_dim=TILE_DIM, device=device
    )

    np.testing.assert_array_equal(output_wp.numpy(), np.argmin(data, axis=1))

    # Edge case: a 64-thread block (multiple warps) reduces a 9-element tile, so
    # at least one warp holds no tile elements. The minimum 1.0 is at flat index 4.
    x = wp.array(np.array([[2.0, 2.0, 3.0], [4.0, 1.0, 6.0], [7.0, 3.0, 9.0]]), dtype=float, device=device)
    y = wp.zeros(1, dtype=int, device=device)

    wp.launch_tiled(tile_argmin_kernel_edge_case, dim=1, inputs=[x, y], block_dim=64, device=device)

    assert_np_equal(y.numpy(), np.array([4]))
210
+
108
211
 
109
212
  @wp.kernel
110
213
  def tile_max_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
@@ -139,6 +242,39 @@ def test_tile_reduce_max(test, device):
139
242
  test.assertAlmostEqual(max_wp[i], max_np, places=4)
140
243
 
141
244
 
245
# One tile per row of `input`: stores the index of the row's maximum into output[row].
@wp.kernel
def tile_argmax_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=int)):
    row = wp.tid()

    values = wp.tile_load(input[row], shape=TILE_DIM)
    max_index = wp.tile_argmax(values)

    wp.tile_store(output, max_index, offset=row)
254
+
255
+
256
def test_tile_reduce_argmax(test, device):
    """Check ``tile_argmax`` against ``np.argmax`` row by row.

    Argmax results are integer indices, so they are compared exactly instead of
    with ``assertAlmostEqual``.
    """
    batch_count = 56
    N = TILE_DIM

    rng = np.random.default_rng(42)
    data = rng.random((batch_count, N), dtype=np.float32)

    data_wp = wp.array(data, requires_grad=True, device=device)
    output_wp = wp.zeros(batch_count, dtype=wp.int32, requires_grad=True, device=device)

    # no backward pass is run, so the launch does not need to be recorded on a tape
    wp.launch_tiled(
        tile_argmax_kernel, dim=[batch_count], inputs=[data_wp, output_wp], block_dim=TILE_DIM, device=device
    )

    np.testing.assert_array_equal(output_wp.numpy(), np.argmax(data, axis=1))
276
+
277
+
142
278
  @wp.kernel
143
279
  def tile_reduce_custom_kernel(input: wp.array2d(dtype=float), output: wp.array(dtype=float)):
144
280
  # output tile index
@@ -176,6 +312,79 @@ def test_tile_reduce_custom(test, device):
176
312
  test.assertAlmostEqual(prod_wp[i], prod_np, places=4)
177
313
 
178
314
 
315
def create_tile_scan_inclusive_kernel(tile_dim: int):
    """Build a kernel that writes the inclusive prefix sum of each row of ``input`` into ``output``.

    ``tile_dim`` is captured by the kernel and used as the 1D tile length.
    Every call defines a fresh kernel declared with ``module="unique"``.
    """

    @wp.kernel(module="unique")
    def tile_scan_inclusive_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
        row = wp.tid()
        loaded = wp.tile_load(input[row], shape=tile_dim)
        scanned = wp.tile_scan_inclusive(loaded)
        wp.tile_store(output[row], scanned)

    return tile_scan_inclusive_kernel
324
+
325
+
326
def test_tile_scan_inclusive(test, device):
    """Compare ``tile_scan_inclusive`` on every row with NumPy's ``cumsum``."""
    batch_count = 56
    N = 1234

    rng = np.random.default_rng(42)
    data = rng.random((batch_count, N), dtype=np.float32)

    data_wp = wp.array2d(data, requires_grad=True, device=device)
    result_wp = wp.zeros_like(data_wp, requires_grad=True, device=device)

    # the factory bakes N into the kernel's tile shape
    scan_kernel = create_tile_scan_inclusive_kernel(N)
    with wp.Tape():
        wp.launch_tiled(
            scan_kernel,
            dim=[batch_count],
            inputs=[data_wp, result_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    result = result_wp.numpy()
    for row, scanned in zip(data, result):
        np.testing.assert_allclose(scanned, np.cumsum(row), rtol=1e-5, atol=1e-6)
349
+
350
+
351
def create_tile_scan_exclusive_kernel(tile_dim: int):
    """Build a kernel that writes the exclusive prefix sum of each row of ``input`` into ``output``.

    ``tile_dim`` is captured by the kernel and used as the 1D tile length.
    Every call defines a fresh kernel declared with ``module="unique"``.
    """

    @wp.kernel(module="unique")
    def tile_scan_exclusive_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
        row = wp.tid()
        loaded = wp.tile_load(input[row], shape=tile_dim)
        scanned = wp.tile_scan_exclusive(loaded)
        wp.tile_store(output[row], scanned)

    return tile_scan_exclusive_kernel
360
+
361
+
362
def test_tile_scan_exclusive(test, device):
    """Compare ``tile_scan_exclusive`` on every row with a shifted NumPy ``cumsum`` starting at 0."""
    batch_count = 56
    N = 1234

    rng = np.random.default_rng(42)
    data = rng.random((batch_count, N), dtype=np.float32)

    data_wp = wp.array2d(data, requires_grad=True, device=device)
    result_wp = wp.zeros_like(data_wp, requires_grad=True, device=device)

    # the factory bakes N into the kernel's tile shape
    scan_kernel = create_tile_scan_exclusive_kernel(N)
    with wp.Tape():
        wp.launch_tiled(
            scan_kernel,
            dim=[batch_count],
            inputs=[data_wp, result_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    result = result_wp.numpy()
    for row, scanned in zip(data, result):
        # exclusive scan: element k holds the sum of elements 0 .. k-1
        expected = np.concatenate((np.zeros(1, dtype=np.float32), np.cumsum(row[:-1])))
        np.testing.assert_allclose(scanned, expected, rtol=1e-5, atol=1e-6)
386
+
387
+
179
388
  @wp.struct
180
389
  class KeyValue:
181
390
  key: wp.int32
@@ -259,7 +468,7 @@ def test_tile_reduce_grouped_sum(test, device):
259
468
 
260
469
  with wp.Tape() as tape:
261
470
  wp.launch_tiled(
262
- tile_sum_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
471
+ tile_grouped_sum_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
263
472
  )
264
473
 
265
474
  sum_wp = output_wp.numpy()
@@ -359,17 +568,17 @@ def test_untile_vector_kernel(input: wp.array(dtype=wp.vec3), output: wp.array(d
359
568
 
360
569
 
361
570
def test_tile_untile_vector(test, device):
    """Check that ``test_untile_vector_kernel`` copies vec3 values and passes unit gradients back."""
    src = wp.full(TILE_DIM, wp.vec3(1.0, 2.0, 3.0), requires_grad=True, device=device)
    dst = wp.zeros_like(src, device=device)

    with wp.Tape() as tape:
        wp.launch(test_untile_vector_kernel, dim=TILE_DIM, inputs=[src, dst], block_dim=TILE_DIM, device=device)

    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_np_equal(dst.numpy(), src.numpy())
    assert_np_equal(src.grad.numpy(), np.ones((TILE_DIM, 3)))
373
582
 
374
583
 
375
584
  @wp.kernel
@@ -423,6 +632,91 @@ def test_tile_arange(test, device):
423
632
  assert_np_equal(output.numpy()[4], np.arange(17, 0, -1))
424
633
 
425
634
 
635
# Block-wide maximum of `arr`, written to max_val[0]. Each thread visits
# elements lane, lane + block_dim, ...; the values from each iteration are
# combined across the block with a tile reduction.
@wp.kernel(module="unique")
def tile_strided_loop_kernel(arr: wp.array(dtype=float), max_val: wp.array(dtype=float)):
    _tile_idx, lane = wp.tid()

    stride = wp.block_dim()
    n = arr.shape[0]

    # Round the loop bound up to a multiple of the block size so every thread
    # runs the same number of iterations and joins each tile_max call; threads
    # past the end of the array contribute -inf.
    padded = ((n + stride - 1) // stride) * stride

    running_max = wp.float32(-wp.inf)
    for idx in range(lane, padded, stride):
        if idx < n:
            val = arr[idx]
        else:
            val = wp.float32(-wp.inf)

        val_tile = wp.tile(val)
        block_max = wp.tile_max(val_tile)

        running_max = wp.max(running_max, block_max[0])

    if lane == 0:
        max_val[0] = running_max
658
+
659
+
660
def test_tile_strided_loop(test, device):
    """Check the block-strided max reduction on an array shorter than the 128-thread block."""
    N = 5  # Length of array

    rng = np.random.default_rng(42)
    values = rng.random(N, dtype=np.float32)

    values_wp = wp.array(values, device=device)
    result_wp = wp.zeros(1, dtype=wp.float32, device=device)

    wp.launch_tiled(
        tile_strided_loop_kernel,
        dim=[1],
        inputs=[values_wp, result_wp],
        device=device,
        block_dim=128,
    )

    test.assertAlmostEqual(result_wp.numpy()[0], np.max(values), places=4)
680
+
681
+
682
# Each thread contributes tid * I (3x3 identity). The block-wide sum from
# tile_reduce is accumulated into y with tile_atomic_add.
@wp.kernel
def test_tile_reduce_matrix_kernel(y: wp.array(dtype=wp.mat33)):
    tid = wp.tid()
    scaled_eye = wp.float32(tid) * wp.identity(3, dtype=wp.float32)

    mat_tile = wp.tile(scaled_eye, preserve_type=True)
    total = wp.tile_reduce(wp.add, mat_tile)

    wp.tile_atomic_add(y, total)
692
+
693
+
694
def test_tile_reduce_matrix(test, device):
    """Sum ``tid * I`` over one block of TILE_DIM threads and compare with the closed form.

    The expected scale is 0 + 1 + ... + (TILE_DIM - 1) = TILE_DIM * (TILE_DIM - 1) / 2.
    It is derived from TILE_DIM rather than hard-coded as 2016, a value that is
    only correct when TILE_DIM == 64.
    """
    y = wp.zeros(shape=1, dtype=wp.mat33, device=device)

    wp.launch(test_tile_reduce_matrix_kernel, dim=TILE_DIM, inputs=[], outputs=[y], block_dim=TILE_DIM, device=device)

    expected_scale = TILE_DIM * (TILE_DIM - 1) // 2
    assert_np_equal(y.numpy().squeeze(), np.eye(3, dtype=np.float32) * expected_scale)
700
+
701
+
702
# Each participating thread contributes the vector (1, 1, 1). The block-wide
# sum from tile_reduce is accumulated into `out` with tile_atomic_add.
@wp.kernel
def test_tile_reduce_vector_kernel(out: wp.array(dtype=wp.vec3)):
    ones = wp.vec3f(1.0)
    ones_tile = wp.tile(ones, preserve_type=True)

    total = wp.tile_reduce(wp.add, ones_tile)

    wp.tile_atomic_add(out, total)
710
+
711
+
712
def test_tile_reduce_vector(test, device):
    """Launch 8 threads of the vec3 reduction and expect every component of the result to be 8."""
    num_threads = 8
    out = wp.zeros(1, dtype=wp.vec3, device=device)

    wp.launch(
        kernel=test_tile_reduce_vector_kernel,
        dim=num_threads,
        inputs=[],
        outputs=[out],
        block_dim=TILE_DIM,
        device=device,
    )

    assert_np_equal(out.numpy(), np.full((1, 3), float(num_threads)))
718
+
719
+
426
720
  devices = get_test_devices()
427
721
 
428
722
 
@@ -431,16 +725,24 @@ class TestTileReduce(unittest.TestCase):
431
725
 
432
726
 
433
727
  add_function_test(TestTileReduce, "test_tile_reduce_sum", test_tile_reduce_sum, devices=devices)
728
+ add_function_test(TestTileReduce, "test_tile_sum_to_shared", test_tile_sum_to_shared, devices=devices)
434
729
  add_function_test(TestTileReduce, "test_tile_reduce_min", test_tile_reduce_min, devices=devices)
435
730
  add_function_test(TestTileReduce, "test_tile_reduce_max", test_tile_reduce_max, devices=devices)
731
+ add_function_test(TestTileReduce, "test_tile_reduce_argmin", test_tile_reduce_argmin, devices=devices)
732
+ add_function_test(TestTileReduce, "test_tile_reduce_argmax", test_tile_reduce_argmax, devices=devices)
436
733
  add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_custom, devices=devices)
437
734
  add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
438
- add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_sum, devices=devices)
735
+ add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_grouped_sum, devices=devices)
439
736
  add_function_test(TestTileReduce, "test_tile_reduce_simt", test_tile_reduce_simt, devices=devices)
440
737
  add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devices)
441
738
  add_function_test(TestTileReduce, "test_tile_arange", test_tile_arange, devices=devices)
442
739
  add_function_test(TestTileReduce, "test_tile_untile_scalar", test_tile_untile_scalar, devices=devices)
443
740
  add_function_test(TestTileReduce, "test_tile_untile_vector", test_tile_untile_vector, devices=devices)
741
+ add_function_test(TestTileReduce, "test_tile_strided_loop", test_tile_strided_loop, devices=devices)
742
+ add_function_test(TestTileReduce, "test_tile_scan_inclusive", test_tile_scan_inclusive, devices=devices)
743
+ add_function_test(TestTileReduce, "test_tile_scan_exclusive", test_tile_scan_exclusive, devices=devices)
744
+ add_function_test(TestTileReduce, "test_tile_reduce_matrix", test_tile_reduce_matrix, devices=devices)
745
+ add_function_test(TestTileReduce, "test_tile_reduce_vector", test_tile_reduce_vector, devices=devices)
444
746
 
445
747
  if __name__ == "__main__":
446
748
  wp.clear_kernel_cache()