warp-lang 1.4.2__py3-none-win_amd64.whl → 1.5.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (166) hide show
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,744 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+
10
+ import numpy as np
11
+
12
+ import warp as wp
13
+ from warp.tests.unittest_utils import *
14
+
15
# Initialize Warp at import time: the unittest.skipUnless decorators further down
# call wp.context.runtime.core.is_mathdx_enabled() while this module is loading.
wp.init()  # For wp.context.runtime.core.is_mathdx_enabled()

# Tile extents shared by the kernels and host-side tests in this file.
TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
TILE_K = wp.constant(8)

# num threads per-tile (passed as block_dim to most wp.launch_tiled calls below)
TILE_DIM = 64
23
+
24
+
25
# Copies 1D tile number `i` (size TILE_N) of A into the same tile position of B.
@wp.kernel
def tile_copy_1d_kernel(A: wp.array(dtype=float), B: wp.array(dtype=float)):
    # tile index
    i = wp.tid()

    # load tile i from A and store it unchanged at tile i of B
    a = wp.tile_load(A, i, n=TILE_N)
    wp.tile_store(B, i, a)
32
+
33
+
34
def test_tile_copy_1d(test, device):
    """Copy a 1D array tile-by-tile and check both the values and the gradients.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    tile_count = 5
    length = TILE_N * tile_count

    # Draw source first, then destination, from one seeded generator.
    generator = np.random.default_rng(42)
    src_np = generator.random(length, dtype=np.float32)
    dst_np = generator.random(length, dtype=np.float32)

    src = wp.array(src_np, requires_grad=True, device=device)
    dst = wp.array(dst_np, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_copy_1d_kernel,
            dim=[tile_count],
            inputs=[src, dst],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward: the destination must now hold the source values
    assert_array_equal(dst, src)

    # backward: seed the output adjoint with ones and propagate
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_array_equal(dst.grad, src.grad)
62
+
63
+
64
# Copies the TILE_M x TILE_N tile at tile coordinates (i, j) of A into the same
# tile position of B.
@wp.kernel
def tile_copy_2d_kernel(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float)):
    # tile index
    i, j = wp.tid()

    # load tile (i, j) from A and store it unchanged at tile (i, j) of B
    a = wp.tile_load(A, i, j, m=TILE_M, n=TILE_N)
    wp.tile_store(B, i, j, a)
71
+
72
+
73
def test_tile_copy_2d(test, device):
    """Copy a 2D array tile-by-tile and check both the values and the gradients.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    tiles_m, tiles_n = 7, 5
    shape = (TILE_M * tiles_m, TILE_N * tiles_n)

    # Draw source first, then destination, from one seeded generator.
    generator = np.random.default_rng(42)
    src_np = generator.random(shape, dtype=np.float32)
    dst_np = generator.random(shape, dtype=np.float32)

    src = wp.array(src_np, requires_grad=True, device=device)
    dst = wp.array(dst_np, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_copy_2d_kernel,
            dim=[tiles_m, tiles_n],
            inputs=[src, dst],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward: the destination must now hold the source values
    assert_array_equal(dst, src)

    # backward: seed the output adjoint with ones and propagate
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_array_equal(dst.grad, src.grad)
102
+
103
+
104
# Scalar helper returning sin(x).
# NOTE(review): not referenced in the visible part of this file; tile_unary_map
# passes wp.sin to wp.tile_map directly.
@wp.func
def unary_func(x: float):
    return wp.sin(x)
107
+
108
+
109
# Applies wp.sin element-wise to the TILE_M x TILE_N tile (i, j) of `input` and
# writes the result to the same tile position of `output`.
@wp.kernel
def tile_unary_map(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # tile index
    i, j = wp.tid()

    a = wp.tile_load(input, i, j, m=TILE_M, n=TILE_N)

    # map the unary function over every element of the tile
    sa = wp.tile_map(wp.sin, a)

    wp.tile_store(output, i, j, sa)
119
+
120
+
121
def test_tile_unary_map(test, device):
    """Map sin over a tiled 2D array and compare values and gradients with NumPy.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    tiles_m, tiles_n = 7, 5
    shape = (TILE_M * tiles_m, TILE_N * tiles_n)

    generator = np.random.default_rng(42)
    host_in = generator.random(shape, dtype=np.float32)

    # Reference output and its derivative with respect to the input.
    expected_out = np.sin(host_in)
    expected_grad = np.cos(host_in)

    dev_in = wp.array(host_in, requires_grad=True, device=device)
    dev_out = wp.zeros_like(dev_in, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_unary_map,
            dim=[tiles_m, tiles_n],
            inputs=[dev_in, dev_out],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(dev_out.numpy(), expected_out, tol=1.0e-4)

    # backward pass: d(sin x)/dx = cos x when the output adjoint is all ones
    dev_out.grad = wp.ones_like(dev_out, device=device)
    tape.backward()

    assert_np_equal(dev_in.grad.numpy(), expected_grad, tol=1.0e-6)
152
+
153
+
154
# Scalar helper returning sin(x) + y; passed to wp.tile_map by tile_binary_map.
@wp.func
def binary_func(x: float, y: float):
    return wp.sin(x) + y
157
+
158
+
159
# Applies binary_func element-wise to tile (i, j) of `input_a` and `input_b`
# and writes the result to the same tile position of `output`.
@wp.kernel
def tile_binary_map(
    input_a: wp.array2d(dtype=float), input_b: wp.array2d(dtype=float), output: wp.array2d(dtype=float)
):
    # tile index
    i, j = wp.tid()

    a = wp.tile_load(input_a, i, j, m=TILE_M, n=TILE_N)
    b = wp.tile_load(input_b, i, j, m=TILE_M, n=TILE_N)

    # map the user-defined binary function over both tiles
    sa = wp.tile_map(binary_func, a, b)

    wp.tile_store(output, i, j, sa)
172
+
173
+
174
def test_tile_binary_map(test, device):
    """Map sin(a) + b over two tiled 2D arrays and compare values and gradients with NumPy.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    tiles_m, tiles_n = 7, 5
    shape = (TILE_M * tiles_m, TILE_N * tiles_n)

    # Draw the first operand, then the second, from one seeded generator.
    generator = np.random.default_rng(42)
    lhs_np = generator.random(shape, dtype=np.float32)
    rhs_np = generator.random(shape, dtype=np.float32)

    # Reference output and its partial derivatives.
    expected_out = np.sin(lhs_np) + rhs_np
    expected_lhs_grad = np.cos(lhs_np)
    expected_rhs_grad = np.ones_like(rhs_np)

    lhs = wp.array(lhs_np, requires_grad=True, device=device)
    rhs = wp.array(rhs_np, requires_grad=True, device=device)
    out = wp.zeros_like(lhs, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_binary_map,
            dim=[tiles_m, tiles_n],
            inputs=[lhs, rhs, out],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(out.numpy(), expected_out, tol=1.0e-6)

    # backward pass with an all-ones output adjoint
    out.grad = wp.ones_like(out, device=device)
    tape.backward()

    assert_np_equal(lhs.grad.numpy(), expected_lhs_grad, tol=1.0e-6)
    assert_np_equal(rhs.grad.numpy(), expected_rhs_grad)
209
+
210
+
211
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_tile_grouped_gemm(test, device):
    """Multiply a batch of independent matrix pairs with tile_matmul and compare with NumPy.

    Each batch entry is a (TILE_M x TILE_K) by (TILE_K x TILE_N) product handled by
    one tile launch index.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """

    @wp.kernel
    def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
        # output tile index
        i = wp.tid()

        a = wp.tile_load(A[i], 0, 0, m=TILE_M, n=TILE_K)
        b = wp.tile_load(B[i], 0, 0, m=TILE_K, n=TILE_N)

        sum = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)

        wp.tile_matmul(a, b, sum)

        wp.tile_store(C[i], 0, 0, sum)

    batch_count = 56

    M = TILE_M
    N = TILE_N
    K = TILE_K

    rng = np.random.default_rng(42)
    A = rng.random((batch_count, M, K), dtype=np.float32)
    B = rng.random((batch_count, K, N), dtype=np.float32)
    C = A @ B

    A_wp = wp.array(A, requires_grad=True, device=device)
    B_wp = wp.array(B, requires_grad=True, device=device)
    C_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)

    # The launch is still recorded on a tape, but no backward pass is run here,
    # so the tape object itself does not need a name.
    with wp.Tape():
        wp.launch_tiled(
            tile_grouped_gemm, dim=[batch_count], inputs=[A_wp, B_wp, C_wp], block_dim=TILE_DIM, device=device
        )

    # The original exact comparison was annotated "TODO: 32 mismatched elements".
    # Float32 matrix products are not guaranteed to match NumPy bit-for-bit, so
    # compare with the same tolerance test_tile_gemm uses.
    assert_np_equal(C_wp.numpy(), C, tol=1.0e-5)
249
+
250
+
251
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_tile_gemm(test, device):
    """Compute C = A @ B with a tiled kernel and check values and gradients against NumPy.

    Each launch index (i, j) accumulates one TILE_M x TILE_N output tile by looping
    over the K dimension in TILE_K-wide steps.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """

    @wp.kernel
    def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
        # output tile index
        i, j = wp.tid()

        sum = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)

        # only the inner dimension is needed to size the accumulation loop
        K = A.shape[1]

        count = int(K / TILE_K)

        for k in range(0, count):
            a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
            b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)

            # sum += a*b
            wp.tile_matmul(a, b, sum)

        wp.tile_store(C, i, j, sum)

    M = TILE_M * 7
    K = TILE_K * 6
    N = TILE_N * 5

    rng = np.random.default_rng(42)
    A = rng.random((M, K), dtype=np.float32)
    B = rng.random((K, N), dtype=np.float32)
    C = np.zeros((M, N), dtype=np.float32)

    A_wp = wp.array(A, requires_grad=True, device=device)
    B_wp = wp.array(B, requires_grad=True, device=device)
    C_wp = wp.array(C, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_gemm,
            # M and N are exact multiples of the tile sizes; floor division keeps
            # the grid computation in integers.
            dim=(M // TILE_M, N // TILE_N),
            inputs=[A_wp, B_wp, C_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-5)

    # backward pass: with adjoint adj_C on C, dA = adj_C @ B^T and dB = A^T @ adj_C
    adj_C = np.ones_like(C)

    tape.backward(grads={C_wp: wp.array(adj_C, device=device)})

    assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-5)
    assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, tol=1.0e-5)
305
+
306
+
307
# Exercises tile arithmetic operators on batch entry `i`: negation, scalar
# multiplication from either side, and tile addition.
# The net result is e = a + 0.5 * ((-a) * 0.5) = 0.75 * a.
@wp.kernel
def tile_operators(input: wp.array3d(dtype=float), output: wp.array3d(dtype=float)):
    # output tile index
    i = wp.tid()

    a = wp.tile_load(input[i], 0, 0, m=TILE_M, n=TILE_N)

    # neg
    b = -a

    # right scalar multiply
    c = b * 0.5

    # left scalar multiply
    d = 0.5 * c

    # add tiles
    e = a + d

    wp.tile_store(output[i], 0, 0, e)
327
+
328
+
329
def test_tile_operators(test, device):
    """Run the tile operator kernel over a batch and check it scales by 0.75.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    batch_count = 56
    shape = (batch_count, TILE_M, TILE_N)

    generator = np.random.default_rng(42)
    host_in = generator.random(shape, dtype=np.float32)
    expected_out = host_in * 0.75

    dev_in = wp.array(host_in, requires_grad=True, device=device)
    dev_out = wp.zeros_like(dev_in, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_operators,
            dim=[batch_count],
            inputs=[dev_in, dev_out],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(dev_out.numpy(), expected_out)

    # backward pass: an all-ones output adjoint yields 0.75 for every input element
    dev_out.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(dev_in.grad.numpy(), np.full_like(host_in, 0.75))
354
+
355
+
356
# Sums the TILE_M x TILE_N tile of batch entry `i`, halves the result, and
# stores it at index `i` of `output`.
@wp.kernel
def tile_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
    # output tile index
    i = wp.tid()

    a = wp.tile_load(input[i], 0, 0, m=TILE_M, n=TILE_N)
    # reduce the tile, then scale the reduction by 0.5
    s = wp.tile_sum(a) * 0.5

    wp.tile_store(output, i, s)
365
+
366
+
367
def test_tile_sum(test, device):
    """Check the half-sum of each batch entry and its gradient via a tape.

    Args:
        test: The unittest.TestCase instance used for the scalar comparisons.
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    batch_count = 56
    shape = (batch_count, TILE_M, TILE_N)

    generator = np.random.default_rng(42)
    host_in = generator.random(shape, dtype=np.float32)

    dev_in = wp.array(host_in, requires_grad=True, device=device)
    dev_out = wp.zeros(batch_count, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_sum_kernel,
            dim=[batch_count],
            inputs=[dev_in, dev_out],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass: compare every batch entry against the NumPy reduction
    results = dev_out.numpy()
    for batch_index in range(batch_count):
        expected = np.sum(host_in[batch_index]) * 0.5
        test.assertAlmostEqual(results[batch_index], expected, places=5)

    # backward pass: each input element contributes 0.5 to its batch sum
    dev_out.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(dev_in.grad.numpy(), np.full_like(host_in, 0.5))
399
+
400
+
401
def test_tile_sum_launch(test, device):
    """Check the half-sum kernel via a recorded launch command and an explicit adjoint launch.

    The forward pass is issued through the object returned by
    wp.launch_tiled(..., record_cmd=True); the backward pass is a second
    wp.launch_tiled call with adjoint=True instead of a tape.

    Args:
        test: The unittest.TestCase instance used for the scalar comparisons.
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    batch_count = 56
    shape = (batch_count, TILE_M, TILE_N)

    generator = np.random.default_rng(42)
    host_in = generator.random(shape, dtype=np.float32)

    dev_in = wp.array(host_in, requires_grad=True, device=device)
    dev_out = wp.zeros(batch_count, requires_grad=True, device=device)

    # forward pass through a recorded launch command
    recorded = wp.launch_tiled(
        tile_sum_kernel,
        dim=[batch_count],
        inputs=[dev_in, dev_out],
        block_dim=TILE_DIM,
        device=device,
        record_cmd=True,
    )
    recorded.launch()

    results = dev_out.numpy()
    for batch_index in range(batch_count):
        expected = np.sum(host_in[batch_index]) * 0.5
        test.assertAlmostEqual(results[batch_index], expected, places=5)

    # backward pass through an explicit adjoint launch
    dev_out.grad.fill_(1.0)

    wp.launch_tiled(
        tile_sum_kernel,
        dim=[batch_count],
        inputs=[dev_in, dev_out],
        adj_inputs=[dev_in.grad, dev_out.grad],
        block_dim=TILE_DIM,
        device=device,
        adjoint=True,
    )

    assert_np_equal(dev_in.grad.numpy(), np.full_like(host_in, 0.5))
442
+
443
+
444
# Loads one TILE_M x TILE_N tile from `input` and copies it to `output` one
# element at a time, exercising scalar indexing (t[i, j]) into a tile.
@wp.kernel
def tile_extract_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # output tile index
    # NOTE(review): this value is never read; `i` is rebound by the loop below.
    i = wp.tid()

    t = wp.tile_load(input, 0, 0, m=TILE_M, n=TILE_N)

    # perform a scalar copy, extracting each
    # tile element individually
    for i in range(TILE_M):
        for j in range(TILE_N):
            output[i, j] = t[i, j]
456
+
457
+
458
def test_tile_extract(test, device):
    """Copy a single tile element-by-element and check values and gradients.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    generator = np.random.default_rng(42)
    host_in = generator.random((TILE_M, TILE_N), dtype=np.float32)

    dev_in = wp.array(host_in, requires_grad=True, device=device)
    dev_out = wp.zeros_like(dev_in, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_extract_kernel,
            dim=[1],
            inputs=[dev_in, dev_out],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_array_equal(dev_out, dev_in)

    # backward pass with an all-ones output adjoint
    dev_out.grad.fill_(1.0)
    tape.backward()

    assert_np_equal(dev_in.grad.numpy(), np.ones_like(host_in))
478
+
479
+
480
# Loads one TILE_M x TILE_N tile from `input`, transposes it, and stores the
# transposed tile at the origin of `output`.
@wp.kernel
def test_tile_transpose_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    x = wp.tile_load(input, 0, 0, m=TILE_M, n=TILE_N)
    y = wp.tile_transpose(x)

    wp.tile_store(output, 0, 0, y)
486
+
487
+
488
def test_tile_transpose(test, device):
    """Transpose a single tile on the device and compare with NumPy's transpose.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    generator = np.random.default_rng(42)
    host_values = generator.random((TILE_M, TILE_N), dtype=np.float32)

    src = wp.array(host_values, device=device)
    # the destination takes its layout from the transposed source
    dst = wp.zeros_like(src.transpose(), device=device)

    wp.launch_tiled(
        test_tile_transpose_kernel,
        dim=[1],
        inputs=[src, dst],
        block_dim=32,
        device=device,
    )

    assert_np_equal(dst.numpy(), src.numpy().T)
496
+
497
+
498
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_tile_transpose_matmul(test, device):
    """Compute x^T @ x for a single tile and compare with NumPy.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """

    @wp.kernel
    def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
        x = wp.tile_load(input, 0, 0, m=TILE_M, n=TILE_N)
        y = wp.tile_transpose(x)

        # accumulate y @ x into a zero-initialized TILE_N x TILE_N tile
        z = wp.tile_zeros(dtype=float, m=TILE_N, n=TILE_N)
        wp.tile_matmul(y, x, z)

        wp.tile_store(output, 0, 0, z)

    generator = np.random.default_rng(42)
    host_values = generator.random((TILE_M, TILE_N), dtype=np.float32)

    src = wp.array(host_values, device=device)
    dst = wp.zeros((TILE_N, TILE_N), dtype=float, device=device)

    wp.launch_tiled(
        test_tile_transpose_matmul_kernel,
        dim=[1],
        inputs=[src, dst],
        block_dim=32,
        device=device,
    )

    assert_np_equal(dst.numpy(), src.numpy().T @ src.numpy())
517
+
518
+
519
# Loads a 10 x 10 tile from `input_a` and a 10-element tile from `input_b`,
# broadcasts the latter to 10 x 10, adds the two, and stores the sum in `output`.
@wp.kernel
def test_tile_broadcast_add_kernel(
    input_a: wp.array2d(dtype=float), input_b: wp.array(dtype=float), output: wp.array2d(dtype=float)
):
    a = wp.tile_load(input_a, 0, 0, m=10, n=10)
    b = wp.tile_load(input_b, 0, n=10)

    # expand the 1D tile to the shape of `a` before adding
    c = wp.tile_broadcast(b, 10, 10)
    d = a + c

    wp.tile_store(output, 0, 0, d)
530
+
531
+
532
def test_tile_broadcast_add(test, device):
    """Add a broadcast 1D tile to a 2D tile and compare with NumPy broadcasting.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    rows, cols = 10, 10

    matrix = wp.array(np.ones((rows, cols), dtype=np.float32), device=device)
    vector = wp.array(np.arange(0, cols, dtype=np.float32), device=device)
    result = wp.zeros((rows, cols), dtype=float, device=device)

    wp.launch_tiled(
        test_tile_broadcast_add_kernel,
        dim=[1],
        inputs=[matrix, vector, result],
        block_dim=32,
        device=device,
    )

    # NumPy broadcasts the vector across the rows of the matrix
    expected = matrix.numpy() + vector.numpy()
    assert_np_equal(result.numpy(), expected)
543
+
544
+
545
# Loads a 5-element tile from `a`, broadcasts it to 5 x 5, adds a 5 x 5 tile of
# ones, and stores the result in `b`.
@wp.kernel
def test_tile_broadcast_grad_kernel(a: wp.array(dtype=float), b: wp.array2d(dtype=float)):
    x = wp.tile_load(a, i=0, n=5)
    y = wp.tile_broadcast(x, m=5, n=5)

    w = wp.tile_ones(dtype=float, m=5, n=5)
    z = w + y

    wp.tile_store(b, 0, 0, z)
554
+
555
+
556
def test_tile_broadcast_grad(test, device):
    """Check values and gradients of a broadcast 1D tile added to a tile of ones.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    vector_np = np.arange(0, 5, dtype=np.float32)
    matrix_np = np.ones((5, 5), dtype=np.float32)

    a = wp.array(vector_np, requires_grad=True, device=device)
    b = wp.array(matrix_np, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            test_tile_broadcast_grad_kernel,
            dim=[1],
            inputs=[a, b],
            block_dim=32,
            device=device,
        )

    # seed the output adjoint with ones and propagate
    b.grad = wp.ones_like(b, device=device)
    tape.backward()

    # forward: every row of b is a + 1
    assert_np_equal(b.numpy(), a.numpy() + np.ones((5, 5)))
    # backward: the expected gradient of each element of a is 5.0
    assert_np_equal(a.grad.numpy(), np.ones(5) * 5.0)
568
+
569
+
570
# Tile extents used by the tile view / tile assign kernels and tests below.
TILE_VIEW_M = 16
TILE_VIEW_N = 128
572
+
573
+
574
# Loads a TILE_VIEW_M x TILE_VIEW_N tile from `src` and writes it to `dst` one
# row at a time, where each row is obtained by indexing the tile (a[i]).
@wp.kernel
def test_tile_view_kernel(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # load whole source into local memory
    a = wp.tile_load(src, 0, 0, TILE_VIEW_M, TILE_VIEW_N)

    # copy the source array row by row
    for i in range(TILE_VIEW_M):
        # create a view on original array and store
        row = a[i]
        wp.tile_store(dst, i, 0, row)
584
+
585
+
586
def test_tile_view(test, device):
    """Copy an array through per-row tile views and check values and gradients.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    shape = (TILE_VIEW_M, TILE_VIEW_N)
    generator = np.random.default_rng(42)

    src = wp.array(generator.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            test_tile_view_kernel,
            dim=[1],
            inputs=[src, dst],
            block_dim=32,
            device=device,
        )

    # forward pass
    assert_np_equal(dst.numpy(), src.numpy())

    # backward pass with an all-ones output adjoint
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_np_equal(src.grad.numpy(), np.ones_like(src.numpy()))
601
+
602
+
603
# Loads a TILE_VIEW_M x TILE_VIEW_N tile from `src`, copies it row by row into a
# zero-initialized tile with wp.tile_assign, and stores that tile in `dst`.
@wp.kernel
def test_tile_assign_kernel(src: wp.array2d(dtype=float), dst: wp.array2d(dtype=float)):
    # load whole source into local memory
    a = wp.tile_load(src, 0, 0, m=TILE_VIEW_M, n=TILE_VIEW_N)
    b = wp.tile_zeros(dtype=float, m=TILE_VIEW_M, n=TILE_VIEW_N)

    # copy the source array row by row
    for i in range(TILE_VIEW_M):
        # create views onto source and dest rows
        row_src = a[i]
        row_dst = b[i]

        # copy onto dest row
        wp.tile_assign(row_dst, 0, 0, row_src)

    wp.tile_store(dst, 0, 0, b)
619
+
620
+
621
def test_tile_assign(test, device):
    """Copy an array through per-row tile assignment and check values and gradients.

    Args:
        test: The unittest.TestCase instance (unused here).
        device: Warp device on which the arrays are allocated and the kernel is launched.
    """
    shape = (TILE_VIEW_M, TILE_VIEW_N)
    generator = np.random.default_rng(42)

    src = wp.array(generator.random(shape, dtype=np.float32), requires_grad=True, device=device)
    dst = wp.array(np.zeros(shape, dtype=np.float32), requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            test_tile_assign_kernel,
            dim=[1],
            inputs=[src, dst],
            block_dim=32,
            device=device,
        )

    # forward pass
    assert_np_equal(dst.numpy(), src.numpy())

    # backward pass with an all-ones output adjoint
    dst.grad = wp.ones_like(dst, device=device)
    tape.backward()

    assert_np_equal(src.grad.numpy(), np.ones_like(src.numpy()))
636
+
637
+
638
+ # #-----------------------------------------
639
+ # # center of mass computation
640
+
641
+ # start = offset[i]
642
+ # end = offset[i+1]
643
+
644
+ # com = wp.tile_zeros(dtype=wp.vec3, M=1)
645
+
646
+ # # load chunks of indices
647
+ # for i in range(start, end, N):
648
+
649
+ # count = wp.min(N, end-i)
650
+
651
+ # idx = wp.tile_load(indices, i, N, max_col=count)
652
+ # p = wp.tile_load(points, idx, max_col=count)
653
+
654
+ # com += wp.tile_sum(p)
655
+
656
+
657
+ # wp.tile_store(out[i], com)
658
+
659
+
660
+ # #-------------------------------------------
661
+ # # compute deformation gradient
662
+
663
+ # i =
664
+ # j =
665
+ # k =
666
+ # l =
667
+
668
+ # f = wp.tile(F) # generate a block size tile of feature vectors
669
+
670
+ # # layer 1
671
+ # w1 = wp.tile_load(weights)
672
+ # b1 = wp.tile_load(bias)
673
+
674
+ # z = wp.tile_matmul(w1, f) + b1
675
+ # z = wp.tile_map(relu, z)
676
+
677
+ # # layer 2
678
+ # w2 = wp.tile_load(weights)
679
+ # b2 = wp.tile_load(bias)
680
+
681
+ # z = wp.tile_matmul(w2, z) + b2
682
+ # z = wp.tile_map(relu, z)
683
+
684
+ # o = wp.untile(f)
685
+
686
+
687
+ # #----------------------------------
688
+ # # MLP with helper function for linear layers
689
+ # # where shape is only partially known
690
+ # # at compile time, and the other dims
691
+ # # are inferred from the input vector
692
+
693
+ # f = wp.tile(F)
694
+
695
+ # z = wp.tile_linear(weights1, bias1, f, hidden=16)
696
+ # z = wp.tile_map(relu, z)
697
+
698
+ # z = wp.tile_linear(weights2, bias2, f, hidden=8)
699
+ # z = wp.tile_map(relu, z)
700
+
701
+ # z = wp.tile_linear(weights3, bias3, f, hidden=4)
702
+ # z = wp.tile_map(relu, z)
703
+
704
+ # o = wp.untile(z)
705
+
706
+
707
+ # #----------------------------------
708
+ # # softmax
709
+
710
+ # def softmax(z: Any):
711
+
712
+ # e = wp.tile_map(wp.exp, z)
713
+ # s = wp.tile_sum(e, dim=0)
714
+
715
+ # return z/s[0]
716
+
717
# The tests are registered for the CUDA test devices only.
devices = get_cuda_test_devices()


class TestTile(unittest.TestCase):
    """Test case container; the per-device test methods are attached below."""


# Registered in this order, each under its own function name.
_TILE_TESTS = (
    test_tile_copy_1d,
    test_tile_copy_2d,
    test_tile_unary_map,
    test_tile_binary_map,
    test_tile_grouped_gemm,
    test_tile_gemm,
    test_tile_transpose,
    test_tile_transpose_matmul,
    test_tile_operators,
    test_tile_sum,
    test_tile_sum_launch,
    test_tile_extract,
    test_tile_broadcast_add,
    test_tile_broadcast_grad,
    test_tile_view,
    test_tile_assign,
)

for _tile_test in _TILE_TESTS:
    add_function_test(TestTile, _tile_test.__name__, _tile_test, devices=devices)


if __name__ == "__main__":
    # clear the kernel cache before running the suite as a script
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)