warp-lang 1.4.1__py3-none-manylinux2014_aarch64.whl → 1.5.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (164) hide show
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,700 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
# Warp has to be initialized before wp.context.runtime.core.is_mathdx_enabled()
# is evaluated by the skip decorators further down
wp.init()

# tile extents shared by the kernels and tests in this module
TILE_M, TILE_N, TILE_K = wp.constant(8), wp.constant(4), wp.constant(8)

# number of threads per tile (passed as block_dim to wp.launch_tiled)
TILE_DIM = 64
23
+
24
+
25
@wp.kernel
def tile_copy_1d_kernel(A: wp.array(dtype=float), B: wp.array(dtype=float)):
    # tile index
    tile_idx = wp.tid()

    # load one tile of length TILE_N from A and store it at the same tile index in B
    chunk = wp.tile_load(A, tile_idx, n=TILE_N)
    wp.tile_store(B, tile_idx, chunk)
32
+
33
+
34
def test_tile_copy_1d(test, device):
    """Copy a 1D array tile by tile and verify the values and the gradients."""
    length = TILE_N * 5
    tile_count = length // TILE_N

    # draw the source first, then the destination, from one seeded generator
    rng = np.random.default_rng(42)
    src_np = rng.random(length, dtype=np.float32)
    dst_np = rng.random(length, dtype=np.float32)

    src = wp.array(src_np, requires_grad=True, device=device)
    dst = wp.array(dst_np, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_copy_1d_kernel,
            dim=[tile_count],
            inputs=[src, dst],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass: the destination must match the source
    assert_array_equal(dst, src)

    # backward pass: seed the output gradient with ones and propagate
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_array_equal(dst.grad, src.grad)
62
+
63
+
64
@wp.kernel
def tile_copy_2d_kernel(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float)):
    # tile index
    ti, tj = wp.tid()

    # load a TILE_M x TILE_N tile from A and store it at the same tile index in B
    block = wp.tile_load(A, ti, tj, m=TILE_M, n=TILE_N)
    wp.tile_store(B, ti, tj, block)
71
+
72
+
73
def test_tile_copy_2d(test, device):
    """Copy a 2D array tile by tile and verify the values and the gradients."""
    rows = TILE_M * 7
    cols = TILE_N * 5
    grid = [rows // TILE_M, cols // TILE_N]

    # draw the source first, then the destination, from one seeded generator
    rng = np.random.default_rng(42)
    src_np = rng.random((rows, cols), dtype=np.float32)
    dst_np = rng.random((rows, cols), dtype=np.float32)

    src = wp.array(src_np, requires_grad=True, device=device)
    dst = wp.array(dst_np, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_copy_2d_kernel,
            dim=grid,
            inputs=[src, dst],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass: the destination must match the source
    assert_array_equal(dst, src)

    # backward pass: seed the output gradient with ones and propagate
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_array_equal(dst.grad, src.grad)
102
+
103
+
104
@wp.func
def unary_func(x: float):
    # sine of the scalar input
    result = wp.sin(x)
    return result
107
+
108
+
109
@wp.kernel
def tile_unary_map(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # tile index
    ti, tj = wp.tid()

    tile_in = wp.tile_load(input, ti, tj, m=TILE_M, n=TILE_N)

    # map wp.sin over the loaded tile
    tile_out = wp.tile_map(wp.sin, tile_in)

    wp.tile_store(output, ti, tj, tile_out)
119
+
120
+
121
def test_tile_unary_map(test, device):
    """Map sine over a tiled 2D array and compare values and gradients with NumPy."""
    rows = TILE_M * 7
    cols = TILE_N * 5

    rng = np.random.default_rng(42)
    x_np = rng.random((rows, cols), dtype=np.float32)

    # NumPy references: sin(x) for the forward pass, cos(x) for the gradient
    expected = np.sin(x_np)
    expected_grad = np.cos(x_np)

    x_wp = wp.array(x_np, requires_grad=True, device=device)
    y_wp = wp.zeros_like(x_wp, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_unary_map,
            dim=[rows // TILE_M, cols // TILE_N],
            inputs=[x_wp, y_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(y_wp.numpy(), expected, tol=1.0e-4)

    # backward pass with an all-ones output gradient
    y_wp.grad = wp.ones_like(y_wp, device=device)
    tape.backward()

    assert_np_equal(x_wp.grad.numpy(), expected_grad, tol=1.0e-6)
152
+
153
+
154
@wp.func
def binary_func(x: float, y: float):
    # sin(x) + y
    sin_x = wp.sin(x)
    return sin_x + y
157
+
158
+
159
@wp.kernel
def tile_binary_map(
    input_a: wp.array2d(dtype=float), input_b: wp.array2d(dtype=float), output: wp.array2d(dtype=float)
):
    # tile index
    ti, tj = wp.tid()

    lhs = wp.tile_load(input_a, ti, tj, m=TILE_M, n=TILE_N)
    rhs = wp.tile_load(input_b, ti, tj, m=TILE_M, n=TILE_N)

    # map binary_func over the two loaded tiles
    mapped = wp.tile_map(binary_func, lhs, rhs)

    wp.tile_store(output, ti, tj, mapped)
172
+
173
+
174
def test_tile_binary_map(test, device):
    """Map sin(a) + b over two tiled 2D arrays and compare values and gradients with NumPy."""
    rows = TILE_M * 7
    cols = TILE_N * 5

    # draw a first, then b, from one seeded generator
    rng = np.random.default_rng(42)
    a_np = rng.random((rows, cols), dtype=np.float32)
    b_np = rng.random((rows, cols), dtype=np.float32)

    # NumPy references for the output and for both input gradients
    expected = np.sin(a_np) + b_np
    expected_a_grad = np.cos(a_np)
    expected_b_grad = np.ones_like(b_np)

    a_wp = wp.array(a_np, requires_grad=True, device=device)
    b_wp = wp.array(b_np, requires_grad=True, device=device)
    out_wp = wp.zeros_like(a_wp, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_binary_map,
            dim=[rows // TILE_M, cols // TILE_N],
            inputs=[a_wp, b_wp, out_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(out_wp.numpy(), expected, tol=1.0e-6)

    # backward pass with an all-ones output gradient
    out_wp.grad = wp.ones_like(out_wp, device=device)
    tape.backward()

    assert_np_equal(a_wp.grad.numpy(), expected_a_grad, tol=1.0e-6)
    assert_np_equal(b_wp.grad.numpy(), expected_b_grad)
209
+
210
+
211
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_tile_grouped_gemm(test, device):
    """Multiply a batch of independent tile-sized matrix pairs and compare with NumPy.

    Skipped when Warp reports that MathDx is not enabled.
    """

    @wp.kernel
    def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
        # output tile index
        i = wp.tid()

        a = wp.tile_load(A[i], 0, 0, m=TILE_M, n=TILE_K)
        b = wp.tile_load(B[i], 0, 0, m=TILE_K, n=TILE_N)

        # zero-initialized accumulator tile passed to tile_matmul
        acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)

        wp.tile_matmul(a, b, acc)

        wp.tile_store(C[i], 0, 0, acc)

    batch_count = 56

    M = TILE_M
    N = TILE_N
    K = TILE_K

    rng = np.random.default_rng(42)
    A = rng.random((batch_count, M, K), dtype=np.float32)
    B = rng.random((batch_count, K, N), dtype=np.float32)
    C = A @ B

    A_wp = wp.array(A, requires_grad=True, device=device)
    B_wp = wp.array(B, requires_grad=True, device=device)
    C_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)

    # the launch is still recorded on a tape, but no backward pass is checked here
    with wp.Tape():
        wp.launch_tiled(
            tile_grouped_gemm, dim=[batch_count], inputs=[A_wp, B_wp, C_wp], block_dim=TILE_DIM, device=device
        )

    # Compare with a tolerance instead of exact equality: the exact comparison was
    # noted to mismatch on some elements, presumably because the float32
    # accumulation differs between the device GEMM and NumPy.
    assert_np_equal(C_wp.numpy(), C, tol=1.0e-5)
249
+
250
+
251
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_tile_gemm(test, device):
    """Check a tiled GEMM kernel against NumPy, for both the product and its gradients.

    Skipped when Warp reports that MathDx is not enabled.
    """

    @wp.kernel
    def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
        # output tile index
        i, j = wp.tid()

        # accumulator for the (i, j) output tile
        acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)

        # number of TILE_K-wide tiles along the shared dimension
        K = A.shape[1]
        count = int(K / TILE_K)

        for k in range(0, count):
            a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
            b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)

            # acc += a*b
            wp.tile_matmul(a, b, acc)

        wp.tile_store(C, i, j, acc)

    M = TILE_M * 7
    K = TILE_K * 6
    N = TILE_N * 5

    rng = np.random.default_rng(42)
    A = rng.random((M, K), dtype=np.float32)
    B = rng.random((K, N), dtype=np.float32)
    C = np.zeros((M, N), dtype=np.float32)

    A_wp = wp.array(A, requires_grad=True, device=device)
    B_wp = wp.array(B, requires_grad=True, device=device)
    C_wp = wp.array(C, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_gemm,
            dim=(int(M / TILE_M), int(N / TILE_N)),
            inputs=[A_wp, B_wp, C_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-5)

    # backward pass: with an all-ones adjoint on C, dA = adj_C @ B^T and dB = A^T @ adj_C
    adj_C = np.ones_like(C)

    tape.backward(grads={C_wp: wp.array(adj_C, device=device)})

    assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-5)
    assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, tol=1.0e-5)
305
+
306
+
307
@wp.kernel
def tile_operators(input: wp.array3d(dtype=float), output: wp.array3d(dtype=float)):
    # output tile index
    batch = wp.tid()

    src = wp.tile_load(input[batch], 0, 0, m=TILE_M, n=TILE_N)

    # unary negation
    negated = -src

    # tile * scalar
    halved = negated * 0.5

    # scalar * tile
    quartered = 0.5 * halved

    # tile + tile
    combined = src + quartered

    wp.tile_store(output[batch], 0, 0, combined)
327
+
328
+
329
def test_tile_operators(test, device):
    """Run the tile arithmetic operators over a batch of tiles and check values and gradients."""
    batch_count = 56
    shape = (batch_count, TILE_M, TILE_N)

    rng = np.random.default_rng(42)
    data = rng.random(shape, dtype=np.float32)

    # the kernel computes a + 0.5 * ((-a) * 0.5), i.e. 0.75 * a
    expected = data * 0.75

    data_wp = wp.array(data, requires_grad=True, device=device)
    result_wp = wp.zeros_like(data_wp, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_operators, dim=[batch_count], inputs=[data_wp, result_wp], block_dim=TILE_DIM, device=device
        )

    assert_np_equal(result_wp.numpy(), expected)

    # seed the output gradient with ones and propagate
    result_wp.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(data_wp.grad.numpy(), np.ones_like(data) * 0.75)
354
+
355
+
356
@wp.kernel
def tile_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
    # output tile index
    batch = wp.tid()

    src = wp.tile_load(input[batch], 0, 0, m=TILE_M, n=TILE_N)

    # sum the tile, then scale the result by 0.5
    half_total = wp.tile_sum(src) * 0.5

    wp.tile_store(output, batch, half_total)
365
+
366
+
367
def test_tile_sum(test, device):
    """Sum each tile of a batch, scaled by 0.5, and check the sums and the input gradient."""
    batch_count = 56

    rng = np.random.default_rng(42)
    data = rng.random((batch_count, TILE_M, TILE_N), dtype=np.float32)

    data_wp = wp.array(data, requires_grad=True, device=device)
    sums_wp = wp.zeros(batch_count, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_sum_kernel,
            dim=[batch_count],
            inputs=[data_wp, sums_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # compare each per-tile result with the NumPy sum of the same tile
    for actual, tile_np in zip(sums_wp.numpy(), data):
        expected = np.sum(tile_np) * 0.5
        test.assertAlmostEqual(actual, expected, places=5)

    # seed the output gradient with ones and propagate
    sums_wp.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(data_wp.grad.numpy(), np.ones_like(data) * 0.5)
399
+
400
+
401
@wp.kernel
def tile_extract_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # The tile index is not needed: the kernel always loads the tile at (0, 0).
    # The previous `i = wp.tid()` was dead because the loop index below shadowed it.
    t = wp.tile_load(input, 0, 0, m=TILE_M, n=TILE_N)

    # perform a scalar copy, extracting each
    # tile element individually
    for row in range(TILE_M):
        for col in range(TILE_N):
            output[row, col] = t[row, col]
413
+
414
+
415
def test_tile_extract(test, device):
    """Copy a single tile element by element and check the values and the gradient."""
    rng = np.random.default_rng(42)
    data = rng.random((TILE_M, TILE_N), dtype=np.float32)

    data_wp = wp.array(data, requires_grad=True, device=device)
    copy_wp = wp.zeros_like(data_wp, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(tile_extract_kernel, dim=[1], inputs=[data_wp, copy_wp], block_dim=TILE_DIM, device=device)

    assert_array_equal(copy_wp, data_wp)

    # seed the output gradient with ones and propagate
    copy_wp.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(data_wp.grad.numpy(), np.ones_like(data))
435
+
436
+
437
@wp.kernel
def test_tile_transpose_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # load the TILE_M x TILE_N tile at (0, 0), transpose it, and store the result
    src = wp.tile_load(input, 0, 0, m=TILE_M, n=TILE_N)
    flipped = wp.tile_transpose(src)

    wp.tile_store(output, 0, 0, flipped)
443
+
444
+
445
def test_tile_transpose(test, device):
    """Transpose a single tile on the device and compare with NumPy's transpose."""
    rng = np.random.default_rng(42)
    host = rng.random((TILE_M, TILE_N), dtype=np.float32)

    src = wp.array(host, device=device)
    dst = wp.zeros_like(src.transpose(), device=device)

    wp.launch_tiled(test_tile_transpose_kernel, dim=[1], inputs=[src, dst], block_dim=32, device=device)

    assert_np_equal(dst.numpy(), src.numpy().T)
453
+
454
+
455
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_tile_transpose_matmul(test, device):
    """Compute x^T @ x for a single tile on the device and compare with NumPy.

    Skipped when Warp reports that MathDx is not enabled.
    """

    @wp.kernel
    def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
        x = wp.tile_load(input, 0, 0, m=TILE_M, n=TILE_N)
        y = wp.tile_transpose(x)

        # zero-initialized TILE_N x TILE_N accumulator passed to tile_matmul
        z = wp.tile_zeros(dtype=float, m=TILE_N, n=TILE_N)
        wp.tile_matmul(y, x, z)

        wp.tile_store(output, 0, 0, z)

    rng = np.random.default_rng(42)
    input = wp.array(rng.random((TILE_M, TILE_N), dtype=np.float32), device=device)
    output = wp.zeros((TILE_N, TILE_N), dtype=float, device=device)

    wp.launch_tiled(test_tile_transpose_matmul_kernel, dim=[1], inputs=[input, output], block_dim=32, device=device)

    # Compare with a tolerance rather than exact equality, matching test_tile_gemm:
    # float32 matmul results from the device and from NumPy may not be bit-identical.
    assert_np_equal(output.numpy(), input.numpy().T @ input.numpy(), tol=1.0e-5)
474
+
475
+
476
@wp.kernel
def test_tile_broadcast_add_kernel(
    input_a: wp.array2d(dtype=float), input_b: wp.array(dtype=float), output: wp.array2d(dtype=float)
):
    matrix = wp.tile_load(input_a, 0, 0, m=10, n=10)
    vector = wp.tile_load(input_b, 0, n=10)

    # broadcast the length-10 tile to 10 x 10 before adding
    expanded = wp.tile_broadcast(vector, 10, 10)
    total = matrix + expanded

    wp.tile_store(output, 0, 0, total)
487
+
488
+
489
def test_tile_broadcast_add(test, device):
    """Add a broadcast 1D tile to a 2D tile and compare with NumPy broadcasting."""
    rows, cols = 10, 10

    matrix = wp.array(np.ones((rows, cols), dtype=np.float32), device=device)
    vector = wp.array(np.arange(0, cols, dtype=np.float32), device=device)
    result = wp.zeros((rows, cols), dtype=float, device=device)

    wp.launch_tiled(
        test_tile_broadcast_add_kernel, dim=[1], inputs=[matrix, vector, result], block_dim=32, device=device
    )

    assert_np_equal(result.numpy(), matrix.numpy() + vector.numpy())
500
+
501
+
502
@wp.kernel
def test_tile_broadcast_grad_kernel(a: wp.array(dtype=float), b: wp.array2d(dtype=float)):
    # load a length-5 tile and broadcast it to 5 x 5
    vector = wp.tile_load(a, i=0, n=5)
    expanded = wp.tile_broadcast(vector, m=5, n=5)

    # add a 5 x 5 tile of ones
    ones = wp.tile_ones(dtype=float, m=5, n=5)
    total = ones + expanded

    wp.tile_store(b, 0, 0, total)
511
+
512
+
513
def test_tile_broadcast_grad(test, device):
    """Check the output of a broadcast-and-add kernel and the gradient of its 1D input."""
    vector_np = np.arange(0, 5, dtype=np.float32)
    matrix_np = np.ones((5, 5), dtype=np.float32)

    a = wp.array(vector_np, requires_grad=True, device=device)
    b = wp.array(matrix_np, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(test_tile_broadcast_grad_kernel, dim=[1], inputs=[a, b], block_dim=32, device=device)

    # seed the output gradient with ones and propagate
    b.grad = wp.ones_like(b, device=device)
    tape.backward()

    assert_np_equal(b.numpy(), a.numpy() + np.ones((5, 5)))
    # the test expects a gradient of 5 for every element of the 1D input
    assert_np_equal(a.grad.numpy(), np.ones(5) * 5.0)
525
+
526
+
527
# tile extents used by the view and assign tests
TILE_VIEW_M, TILE_VIEW_N = 16, 128
529
+
530
+
531
@wp.kernel
def test_tile_view_kernel(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # load the whole source as one TILE_VIEW_M x TILE_VIEW_N tile
    whole = wp.tile_load(src, 0, 0, TILE_VIEW_M, TILE_VIEW_N)

    # copy the source array one row at a time
    for r in range(TILE_VIEW_M):
        # index the tile by row and store that row into dst
        row_view = whole[r]
        wp.tile_store(dst, r, 0, row_view)
541
+
542
+
543
def test_tile_view(test, device):
    """Copy an array row by row through tile views and check the values and the gradient."""
    shape = (TILE_VIEW_M, TILE_VIEW_N)
    rng = np.random.default_rng(42)

    src = wp.array(rng.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(test_tile_view_kernel, dim=[1], inputs=[src, dst], block_dim=32, device=device)

    assert_np_equal(dst.numpy(), src.numpy())

    # seed the output gradient with ones and propagate
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_np_equal(src.grad.numpy(), np.ones_like(src.numpy()))
558
+
559
+
560
@wp.kernel
def test_tile_assign_kernel(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # load the whole source as one tile and create a zero tile of the same extents
    loaded = wp.tile_load(src, 0, 0, m=TILE_VIEW_M, n=TILE_VIEW_N)
    staged = wp.tile_zeros(dtype=float, m=TILE_VIEW_M, n=TILE_VIEW_N)

    # copy the source tile one row at a time
    for r in range(TILE_VIEW_M):
        # row views of the source and destination tiles
        from_row = loaded[r]
        to_row = staged[r]

        # assign the source row onto the destination row
        wp.tile_assign(to_row, 0, 0, from_row)

    wp.tile_store(dst, 0, 0, staged)
576
+
577
+
578
def test_tile_assign(test, device):
    """Copy an array row by row with tile_assign and check the values and the gradient."""
    shape = (TILE_VIEW_M, TILE_VIEW_N)
    rng = np.random.default_rng(42)

    src = wp.array(rng.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(test_tile_assign_kernel, dim=[1], inputs=[src, dst], block_dim=32, device=device)

    assert_np_equal(dst.numpy(), src.numpy())

    # seed the output gradient with ones and propagate
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_np_equal(src.grad.numpy(), np.ones_like(src.numpy()))
593
+
594
+
595
+ # #-----------------------------------------
596
+ # # center of mass computation
597
+
598
+ # start = offset[i]
599
+ # end = offset[i+1]
600
+
601
+ # com = wp.tile_zeros(dtype=wp.vec3, M=1)
602
+
603
+ # # load chunks of indices
604
+ # for i in range(start, end, N):
605
+
606
+ # count = wp.min(N, end-i)
607
+
608
+ # idx = wp.tile_load(indices, i, N, max_col=count)
609
+ # p = wp.tile_load(points, idx, max_col=count)
610
+
611
+ # com += wp.tile_sum(p)
612
+
613
+
614
+ # wp.tile_store(out[i], com)
615
+
616
+
617
+ # #-------------------------------------------
618
+ # # compute deformation gradient
619
+
620
+ # i =
621
+ # j =
622
+ # k =
623
+ # l =
624
+
625
+ # f = wp.tile(F) # generate a block size tile of feature vectors
626
+
627
+ # # layer 1
628
+ # w1 = wp.tile_load(weights)
629
+ # b1 = wp.tile_load(bias)
630
+
631
+ # z = wp.tile_matmul(w1, f) + b1
632
+ # z = wp.tile_map(relu, z)
633
+
634
+ # # layer 2
635
+ # w2 = wp.tile_load(weights)
636
+ # b2 = wp.tile_load(bias)
637
+
638
+ # z = wp.tile_matmul(w2, z) + b2
639
+ # z = wp.tile_map(relu, z)
640
+
641
+ # o = wp.untile(f)
642
+
643
+
644
+ # #----------------------------------
645
+ # # MLP with helper function for linear layers
646
+ # # where shape is only partially known
647
+ # # at compile time, and the other dims
648
+ # # are inferred from the input vector
649
+
650
+ # f = wp.tile(F)
651
+
652
+ # z = wp.tile_linear(weights1, bias1, f, hidden=16)
653
+ # z = wp.tile_map(relu, z)
654
+
655
+ # z = wp.tile_linear(weights2, bias2, f, hidden=8)
656
+ # z = wp.tile_map(relu, z)
657
+
658
+ # z = wp.tile_linear(weights3, bias3, f, hidden=4)
659
+ # z = wp.tile_map(relu, z)
660
+
661
+ # o = wp.untile(z)
662
+
663
+
664
+ # #----------------------------------
665
+ # # softmax
666
+
667
+ # def softmax(z: Any):
668
+
669
+ # e = wp.tile_map(wp.exp, z)
670
+ # s = wp.tile_sum(e, dim=0)
671
+
672
+ # return z/s[0]
673
+
674
devices = get_cuda_test_devices()


class TestTile(unittest.TestCase):
    """Empty test case; the tile tests are attached below via add_function_test."""


# register each test on TestTile under the function's own name, in the original order
for _tile_test in (
    test_tile_copy_1d,
    test_tile_copy_2d,
    test_tile_unary_map,
    test_tile_binary_map,
    test_tile_grouped_gemm,
    test_tile_gemm,
    test_tile_transpose,
    test_tile_transpose_matmul,
    test_tile_operators,
    test_tile_sum,
    test_tile_extract,
    test_tile_broadcast_add,
    test_tile_broadcast_grad,
    test_tile_view,
    test_tile_assign,
):
    add_function_test(TestTile, _tile_test.__name__, _tile_test, devices=devices)
del _tile_test
696
+
697
+
698
if __name__ == "__main__":
    # clear Warp's kernel cache before running the suite
    wp.clear_kernel_cache()
    unittest.main(failfast=True, verbosity=2)