warp-lang 1.4.1-py3-none-manylinux2014_x86_64.whl → 1.5.0-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (164)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
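
Much of this release is the new tile programming API (warp/native/tile.h, the warp/examples/tile/ examples, and the test_tile*.py files listed above). The sketch below is distilled from the new tests shown later in this diff and only illustrates the basic pattern (wp.tile_load, wp.tile_matmul, wp.tile_store, launched with wp.launch_tiled); the array names and tile shapes here are illustrative and are not part of the package contents.

import numpy as np
import warp as wp

# illustrative tile shapes, mirroring the constants used in the new tests
TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
TILE_K = wp.constant(8)
TILE_DIM = 32  # threads cooperating on each logical tile

@wp.kernel
def gemm_tile(ga: wp.array2d(dtype=wp.float32),
              gb: wp.array2d(dtype=wp.float32),
              gc: wp.array2d(dtype=wp.float32)):
    i, j = wp.tid()
    a = wp.tile_load(ga, i, j, m=TILE_M, n=TILE_K)   # cooperative load of an 8x8 tile of A
    b = wp.tile_load(gb, i, j, m=TILE_K, n=TILE_N)   # cooperative load of an 8x4 tile of B
    c = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)
    wp.tile_matmul(a, b, c)                          # block-wide matmul accumulating into c
    wp.tile_store(gc, i, j, c)

# the new tile tests run these kernels on CUDA devices
rng = np.random.default_rng(0)
A = wp.array(rng.random((8, 8), dtype=np.float32))
B = wp.array(rng.random((8, 4), dtype=np.float32))
C = wp.zeros((8, 4), dtype=wp.float32)

# launch_tiled assigns TILE_DIM threads to each logical (i, j) tile index
wp.launch_tiled(gemm_tile, dim=[1, 1], inputs=[A, B, C], block_dim=TILE_DIM)
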
warp/tests/test_tile_mathdx.py (new file)
@@ -0,0 +1,144 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import functools
+ import unittest
+
+ import numpy as np
+
+ import warp as wp
+ from warp.tests.unittest_utils import *
+
+ wp.init()  # For wp.context.runtime.core.is_mathdx_enabled()
+
+ TILE_M = wp.constant(8)
+ TILE_N = wp.constant(4)
+ TILE_K = wp.constant(8)
+
+ # num threads per-tile
+ TILE_DIM = 32
+ FFT_SIZE_FP32 = 64
+ FFT_SIZE_FP64 = 64
+
+
+ @wp.kernel()
+ def tile_math_matmul_kernel(
+     ga: wp.array2d(dtype=wp.float16), gb: wp.array2d(dtype=wp.float32), gc: wp.array2d(dtype=wp.float64)
+ ):
+     i, j = wp.tid()
+     a = wp.tile_load(ga, i, j, m=TILE_M, n=TILE_K)
+     b = wp.tile_load(gb, i, j, m=TILE_K, n=TILE_N)
+     c = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float64)
+     wp.tile_matmul(a, b, c)
+     wp.tile_store(gc, i, j, c)
+
+
+ def test_tile_math_matmul(test, device):
+     rng = np.random.default_rng(42)
+
+     A = rng.random((TILE_M, TILE_K), dtype=np.float64).astype(np.float16)
+     B = rng.random((TILE_K, TILE_N), dtype=np.float32)
+     C = np.zeros((TILE_M, TILE_N), dtype=np.float64)
+
+     A_wp = wp.array(A, requires_grad=True, device=device)
+     B_wp = wp.array(B, requires_grad=True, device=device)
+     C_wp = wp.array(C, requires_grad=True, device=device)
+
+     with wp.Tape() as tape:
+         wp.launch_tiled(
+             tile_math_matmul_kernel,
+             dim=[1, 1],
+             inputs=[A_wp, B_wp, C_wp],
+             block_dim=TILE_DIM,
+             device=device,
+         )
+
+     # verify forward pass
+     assert_np_equal(C_wp.numpy(), A @ B, tol=1e-2)
+
+     adj_C = np.ones_like(C)
+
+     tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
+
+     assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1e-2)
+     assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, tol=1e-2)
+
+
+ @wp.kernel()
+ def tile_math_fft_kernel_vec2f(gx: wp.array2d(dtype=wp.vec2f), gy: wp.array2d(dtype=wp.vec2f)):
+     i, j = wp.tid()
+     xy = wp.tile_load(gx, i, j, m=FFT_SIZE_FP32, n=FFT_SIZE_FP32)
+     wp.tile_fft(xy)
+     wp.tile_store(gy, i, j, xy)
+
+
+ @wp.kernel()
+ def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dtype=wp.vec2d)):
+     i, j = wp.tid()
+     xy = wp.tile_load(gx, i, j, m=FFT_SIZE_FP64, n=FFT_SIZE_FP64)
+     wp.tile_fft(xy)
+     wp.tile_store(gy, i, j, xy)
+
+
+ def test_tile_math_fft(test, device, wp_dtype):
+     np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
+     np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]
+     kernel = {wp.vec2d: tile_math_fft_kernel_vec2d, wp.vec2f: tile_math_fft_kernel_vec2f}[wp_dtype]
+     fft_size = {wp.vec2d: FFT_SIZE_FP64, wp.vec2f: FFT_SIZE_FP32}[wp_dtype]
+
+     rng = np.random.default_rng(42)
+
+     # Warp doesn't really have a complex64 type,
+     # so we use 2 float32 to represent a single complex64 number and then convert it to vec2f
+
+     X = rng.random((fft_size, 2 * fft_size), dtype=np_real_dtype)
+     Y = np.zeros_like(X)
+
+     X_wp = wp.array2d(X, requires_grad=True, dtype=wp_dtype, device=device)
+     Y_wp = wp.array2d(Y, requires_grad=True, dtype=wp_dtype, device=device)
+
+     X_c64 = X.view(np_cplx_dtype).reshape(fft_size, fft_size)
+     Y_c64 = np.fft.fft(X_c64, axis=-1)
+
+     with wp.Tape() as tape:
+         wp.launch_tiled(kernel, dim=[1, 1], inputs=[X_wp, Y_wp], block_dim=TILE_DIM, device=device)
+
+     Y_wp_c64 = Y_wp.numpy().view(np_cplx_dtype).reshape(fft_size, fft_size)
+
+     assert_np_equal(Y_wp_c64, Y_c64, tol=1.0e-4)
+
+     # TODO: implement and test backward pass
+
+
+ devices = get_cuda_test_devices()
+
+
+ @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
+ class TestTileMathDx(unittest.TestCase):
+     pass
+
+
+ # check_output=False so we can enable libmathdx's logging without failing the tests
+ add_function_test(TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=devices, check_output=False)
+ add_function_test(
+     TestTileMathDx,
+     "test_tile_math_fft_vec2f",
+     functools.partial(test_tile_math_fft, wp_dtype=wp.vec2f),
+     devices=devices,
+     check_output=False,
+ )
+ add_function_test(
+     TestTileMathDx,
+     "test_tile_math_fft_vec2d",
+     functools.partial(test_tile_math_fft, wp_dtype=wp.vec2d),
+     devices=devices,
+     check_output=False,
+ )
+
+ if __name__ == "__main__":
+     wp.clear_kernel_cache()
+     unittest.main(verbosity=2, failfast=True)
warp/tests/test_tile_mlp.py (new file)
@@ -0,0 +1,383 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import os
+
+ import numpy as np
+
+ import warp as wp
+ import warp.examples
+ import warp.optim
+ from warp.tests.unittest_utils import *
+
+ wp.init()
+
+ # needs to be constant for the whole module
+ NUM_THREADS = 32
+
+
+ def create_layer(rng, dim_in, dim_hid, dtype=float):
+     w = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
+     b = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, 1))
+
+     weights = wp.array(w, dtype=dtype, requires_grad=True)
+     bias = wp.array(b, dtype=dtype, requires_grad=True)
+
+     return (weights, bias)
+
+
+ def create_array(rng, dim_in, dim_hid, dtype=float):
+     s = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
+     a = wp.array(s, dtype=dtype, requires_grad=True)
+
+     return a
+
+
+ @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
+ def test_multi_layer_nn(test, device):
+     import torch as tc
+
+     NUM_FREQ = wp.constant(8)
+
+     DIM_IN = wp.constant(4 * NUM_FREQ)  # sin,cos for both x,y at each frequency
+     DIM_HID = 32
+     DIM_OUT = 3
+
+     IMG_WIDTH = 256
+     IMG_HEIGHT = 256
+
+     BATCH_SIZE = min(512, int((IMG_WIDTH * IMG_HEIGHT) / 8))
+
+     dtype = wp.float16
+
+     @wp.func
+     def relu(x: dtype):
+         return wp.max(x, dtype(0.0))
+
+     @wp.func
+     def sigmoid(x: dtype):
+         return dtype(1.0 / (1.0 + wp.exp(-float(x))))
+
+     @wp.kernel
+     def zero(loss: wp.array(dtype=float)):
+         loss[0] = 0.0
+
+     @wp.kernel
+     def compute(
+         batches: wp.array(dtype=int),
+         input: wp.array2d(dtype=dtype),
+         weights_0: wp.array2d(dtype=dtype),
+         bias_0: wp.array2d(dtype=dtype),
+         weights_1: wp.array2d(dtype=dtype),
+         bias_1: wp.array2d(dtype=dtype),
+         weights_2: wp.array2d(dtype=dtype),
+         bias_2: wp.array2d(dtype=dtype),
+         weights_3: wp.array2d(dtype=dtype),
+         bias_3: wp.array2d(dtype=dtype),
+         reference: wp.array2d(dtype=float),
+         loss: wp.array1d(dtype=float),
+         out: wp.array2d(dtype=float),
+     ):
+         linear = batches[wp.tid()]
+         row = linear / IMG_WIDTH
+         col = linear % IMG_WIDTH
+
+         # normalize input coordinates to [-1, 1]
+         x = (float(row) / float(IMG_WIDTH) - 0.5) * 2.0
+         y = (float(col) / float(IMG_HEIGHT) - 0.5) * 2.0
+
+         local = wp.vector(dtype=dtype, length=DIM_IN)
+
+         # construct positional encoding
+         for s in range(NUM_FREQ):
+             scale = wp.pow(2.0, float(s)) * wp.pi
+
+             # x-coord
+             local[s * 4 + 0] = dtype(wp.sin(x * scale))
+             local[s * 4 + 1] = dtype(wp.cos(x * scale))
+
+             # y-coord
+             local[s * 4 + 2] = dtype(wp.sin(y * scale))
+             local[s * 4 + 3] = dtype(wp.cos(y * scale))
+
+             # write input back to array so that torch can use it
+             input[s * 4 + 0, linear] = local[s * 4 + 0]
+             input[s * 4 + 1, linear] = local[s * 4 + 1]
+             input[s * 4 + 2, linear] = local[s * 4 + 2]
+             input[s * 4 + 3, linear] = local[s * 4 + 3]
+
+         # tile feature vectors across the block, returns [dim(f), NUM_THREADS]
+         f = wp.tile(local)
+
+         # input layer
+         w0 = wp.tile_load(weights_0, 0, 0, m=DIM_HID, n=DIM_IN)
+         b0 = wp.tile_load(bias_0, 0, 0, m=DIM_HID, n=1)
+         z = wp.tile_map(relu, wp.tile_matmul(w0, f) + wp.tile_broadcast(b0, m=DIM_HID, n=NUM_THREADS))
+
+         # hidden layer
+         w1 = wp.tile_load(weights_1, 0, 0, m=DIM_HID, n=DIM_HID)
+         b1 = wp.tile_load(bias_1, 0, 0, m=DIM_HID, n=1)
+         z = wp.tile_map(relu, wp.tile_matmul(w1, z) + wp.tile_broadcast(b1, m=DIM_HID, n=NUM_THREADS))
+
+         w2 = wp.tile_load(weights_2, 0, 0, m=DIM_HID, n=DIM_HID)
+         b2 = wp.tile_load(bias_2, 0, 0, m=DIM_HID, n=1)
+         z = wp.tile_map(relu, wp.tile_matmul(w2, z) + wp.tile_broadcast(b2, m=DIM_HID, n=NUM_THREADS))
+
+         # output layer
+         w3 = wp.tile_load(weights_3, 0, 0, m=DIM_OUT, n=DIM_HID)
+         b3 = wp.tile_load(bias_3, 0, 0, m=DIM_OUT, n=1)
+         o = wp.tile_map(relu, wp.tile_matmul(w3, z) + wp.tile_broadcast(b3, m=DIM_OUT, n=NUM_THREADS))
+
+         # untile back to SIMT
+         output = wp.untile(o)
+
+         # compute error
+         error = wp.vec3(
+             float(output[0]) - reference[0, linear],
+             float(output[1]) - reference[1, linear],
+             float(output[2]) - reference[2, linear],
+         )
+
+         # write MSE loss
+         wp.atomic_add(loss, 0, wp.length_sq(error) / float(3 * BATCH_SIZE))
+
+         # image output
+         for i in range(DIM_OUT):
+             out[i, linear] = float(output[i])
+
+     with wp.ScopedDevice(device):
+         torch_device = wp.device_to_torch(device)
+
+         rng = np.random.default_rng(45)
+
+         weights_0, bias_0 = create_layer(rng, DIM_IN, DIM_HID, dtype=dtype)
+         weights_1, bias_1 = create_layer(rng, DIM_HID, DIM_HID, dtype=dtype)
+         weights_2, bias_2 = create_layer(rng, DIM_HID, DIM_HID, dtype=dtype)
+         weights_3, bias_3 = create_layer(rng, DIM_HID, DIM_OUT, dtype=dtype)
+
+         input = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_IN, dtype=dtype)
+         output = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_OUT)
+
+         reference_np = np.load(os.path.join(os.path.dirname(__file__), "assets/pixel.npy"), allow_pickle=True) / 255.0
+         reference = wp.array(reference_np, dtype=float)
+
+         assert reference.shape[1] == IMG_WIDTH * IMG_HEIGHT
+
+         loss = wp.zeros(1, dtype=float, requires_grad=True)
+
+         params = [weights_0, bias_0, weights_1, bias_1, weights_2, bias_2, weights_3, bias_3]
+
+         optimizer_grads = [p.grad.flatten() for p in params]
+         optimizer_inputs = [p.flatten() for p in params]
+         optimizer = warp.optim.Adam(optimizer_inputs, lr=0.01)
+
+         num_batches = int((IMG_WIDTH * IMG_HEIGHT) / BATCH_SIZE)
+         max_epochs = 30
+
+         # create randomized batch indices
+         batches = np.arange(0, IMG_WIDTH * IMG_HEIGHT, dtype=np.int32)
+         rng.shuffle(batches)
+         batches = wp.array(batches)
+
+         with wp.ScopedTimer("Training", active=False):
+             for epoch in range(max_epochs):
+                 for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
+                     loss.zero_()
+
+                     with wp.Tape() as tape:
+                         wp.launch(
+                             compute,
+                             dim=[BATCH_SIZE],
+                             inputs=[
+                                 batches[b : b + BATCH_SIZE],
+                                 input,
+                                 weights_0,
+                                 bias_0,
+                                 weights_1,
+                                 bias_1,
+                                 weights_2,
+                                 bias_2,
+                                 weights_3,
+                                 bias_3,
+                                 reference,
+                                 loss,
+                                 output,
+                             ],
+                             block_dim=NUM_THREADS,
+                         )
+
+                     tape.backward(loss)
+
+                     # check outputs + grads on the first few epoch only
+                     # since this is a relatively slow operation
+                     verify = True
+                     if verify and epoch < 3:
+                         indices = batches[b : b + BATCH_SIZE].numpy()
+
+                         z_np = np.maximum(weights_0.numpy() @ input.numpy()[:, indices] + bias_0.numpy(), 0.0)
+                         z_np = np.maximum(weights_1.numpy() @ z_np + bias_1.numpy(), 0.0)
+                         z_np = np.maximum(weights_2.numpy() @ z_np + bias_2.numpy(), 0.0)
+                         z_np = np.maximum(weights_3.numpy() @ z_np + bias_3.numpy(), 0.0)
+
+                         # test numpy forward
+                         assert_np_equal(output.numpy()[:, indices], z_np, tol=1.0e-2)
+
+                         # torch
+                         input_tc = tc.tensor(input.numpy()[:, indices], requires_grad=True, device=torch_device)
+
+                         weights_0_tc = tc.tensor(weights_0.numpy(), requires_grad=True, device=torch_device)
+                         bias_0_tc = tc.tensor(bias_0.numpy(), requires_grad=True, device=torch_device)
+
+                         weights_1_tc = tc.tensor(weights_1.numpy(), requires_grad=True, device=torch_device)
+                         bias_1_tc = tc.tensor(bias_1.numpy(), requires_grad=True, device=torch_device)
+
+                         weights_2_tc = tc.tensor(weights_2.numpy(), requires_grad=True, device=torch_device)
+                         bias_2_tc = tc.tensor(bias_2.numpy(), requires_grad=True, device=torch_device)
+
+                         weights_3_tc = tc.tensor(weights_3.numpy(), requires_grad=True, device=torch_device)
+                         bias_3_tc = tc.tensor(bias_3.numpy(), requires_grad=True, device=torch_device)
+
+                         z_tc = tc.clamp(weights_0_tc @ input_tc + bias_0_tc, min=0.0)
+                         z_tc = tc.clamp(weights_1_tc @ z_tc + bias_1_tc, min=0.0)
+                         z_tc = tc.clamp(weights_2_tc @ z_tc + bias_2_tc, min=0.0)
+                         z_tc = tc.clamp(weights_3_tc @ z_tc + bias_3_tc, min=0.0)
+
+                         ref_tc = tc.tensor(reference.numpy()[:, indices], requires_grad=True, device=torch_device)
+
+                         l_tc = tc.mean((z_tc - ref_tc) ** 2)
+                         l_tc.backward()
+
+                         # test torch
+                         assert_np_equal(z_tc.cpu().detach().numpy(), output.numpy()[:, indices], tol=1.0e-2)
+                         assert_np_equal(weights_0.grad.numpy(), weights_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+                         assert_np_equal(bias_0.grad.numpy(), bias_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+                         assert_np_equal(weights_1.grad.numpy(), weights_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+                         assert_np_equal(bias_1.grad.numpy(), bias_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+                         assert_np_equal(weights_2.grad.numpy(), weights_2_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+                         assert_np_equal(bias_2.grad.numpy(), bias_2_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+                         assert_np_equal(weights_3.grad.numpy(), weights_3_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+                         assert_np_equal(bias_3.grad.numpy(), bias_3_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
+
+                     optimizer.step(optimizer_grads)
+                     tape.zero()
+
+         # initial loss is ~0.061
+         test.assertLess(loss.numpy()[0], 0.002)
+
+
+ @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
+ def test_single_layer_nn(test, device):
+     import torch as tc
+
+     DIM_IN = 8
+     DIM_HID = 32
+     DIM_OUT = 16
+
+     NUM_BLOCKS = 56
+
+     @wp.func
+     def relu(x: float):
+         return wp.max(x, 0.0)
+
+     @wp.kernel
+     def compute(
+         input: wp.array2d(dtype=float),
+         weights: wp.array2d(dtype=float),
+         bias: wp.array2d(dtype=float),
+         out: wp.array2d(dtype=float),
+     ):
+         i = wp.tid()
+
+         f = wp.tile_load(input, 0, i, m=DIM_IN, n=NUM_THREADS)
+
+         w = wp.tile_load(weights, 0, 0, DIM_OUT, DIM_IN)
+         b = wp.tile_load(bias, 0, 0, m=DIM_OUT, n=1)
+
+         o = wp.tile_map(relu, wp.tile_matmul(w, f) + wp.tile_broadcast(b, m=DIM_OUT, n=NUM_THREADS))
+
+         wp.tile_store(out, 0, i, o)
+
+     with wp.ScopedDevice(device):
+         rng = np.random.default_rng(45)
+
+         # single layer weights, bias
+         weights, bias = create_layer(rng, DIM_IN, DIM_OUT, dtype=float)
+
+         input = create_array(rng, NUM_THREADS * NUM_BLOCKS, DIM_IN)
+         output = create_array(rng, NUM_THREADS * NUM_BLOCKS, DIM_OUT)
+
+         with wp.Tape() as tape:
+             wp.launch_tiled(compute, dim=[NUM_BLOCKS], inputs=[input, weights, bias, output], block_dim=NUM_THREADS)
+
+         output.grad = wp.ones_like(output)
+         tape.backward()
+
+         # numpy
+         output_np = np.maximum(weights.numpy() @ input.numpy() + bias.numpy(), 0.0)
+
+         # test numpy forward
+         assert_np_equal(output.numpy(), output_np, tol=1.0e-2)
+
+         # torch
+         weights_tc = tc.from_numpy(weights.numpy()).requires_grad_(True)  # use .numpy() to avoid any memory aliasing
+         input_tc = tc.from_numpy(input.numpy()).requires_grad_(True)
+         bias_tc = tc.from_numpy(bias.numpy()).requires_grad_(True)
+
+         output_tc = tc.clamp(weights_tc @ input_tc + bias_tc, min=0.0)
+         output_tc.backward(tc.ones_like(output_tc))
+
+         # test torch
+         assert_np_equal(output_tc.detach().numpy(), output.numpy(), tol=1.0e-2)
+         assert_np_equal(input.grad.numpy(), input_tc.grad.detach().numpy(), tol=1.0e-2)
+
+
+ class TestTileMLP(unittest.TestCase):
+     pass
+
+
+ test_devices = get_test_devices()
+
+ try:
+     import torch
+
+     # check which Warp devices work with Torch
+     # CUDA devices may fail if Torch was not compiled with CUDA support
+     torch_compatible_devices = []
+     torch_compatible_cuda_devices = []
+
+     for d in test_devices:
+         try:
+             t = torch.arange(10, device=wp.device_to_torch(d))
+             t += 1
+             torch_compatible_devices.append(d)
+             if d.is_cuda:
+                 torch_compatible_cuda_devices.append(d)
+         except Exception as e:
+             print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
+
+     add_function_test(
+         TestTileMLP,
+         "test_single_layer_nn",
+         test_single_layer_nn,
+         check_output=False,
+         devices=torch_compatible_cuda_devices,
+     )
+     add_function_test(
+         TestTileMLP,
+         "test_multi_layer_nn",
+         test_multi_layer_nn,
+         check_output=False,
+         devices=torch_compatible_cuda_devices,
+     )
+
+ except Exception as e:
+     print(f"Skipping Torch tests due to exception: {e}")
+
+
+ if __name__ == "__main__":
+     wp.clear_kernel_cache()
+     unittest.main(verbosity=2, failfast=True)