warp-lang 1.4.1__py3-none-manylinux2014_x86_64.whl → 1.5.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1920 -111
- warp/codegen.py +186 -62
- warp/config.py +2 -2
- warp/context.py +322 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/core/example_dem.py +2 -1
- warp/examples/core/example_mesh_intersect.py +3 -3
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/optim/example_walker.py +2 -2
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_jacobian_ik.py +6 -2
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +55 -40
- warp/native/builtin.h +124 -43
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +90 -17
- warp/stubs.py +651 -85
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +207 -48
- warp/tests/test_closest_point_edge_edge.py +8 -8
- warp/tests/test_codegen.py +120 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fabricarray.py +33 -0
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +48 -1
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +5 -4
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +191 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +23 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +339 -73
- warp/utils.py +22 -1
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import functools
|
|
9
|
+
import unittest
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
import warp as wp
|
|
14
|
+
from warp.tests.unittest_utils import *
|
|
15
|
+
|
|
16
|
+
wp.init()  # For wp.context.runtime.core.is_mathdx_enabled()

# Tile extents for the matmul test: A tiles are TILE_M x TILE_K, B tiles are TILE_K x TILE_N,
# and the fp64 result tile is TILE_M x TILE_N.
TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
TILE_K = wp.constant(8)

# num threads per-tile (passed as block_dim to every launch_tiled call in this file)
TILE_DIM = 32
# Side length of the square complex tile loaded by the vec2f / vec2d FFT kernels respectively
FFT_SIZE_FP32 = 64
FFT_SIZE_FP64 = 64
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@wp.kernel()
def tile_math_matmul_kernel(
    ga: wp.array2d(dtype=wp.float16), gb: wp.array2d(dtype=wp.float32), gc: wp.array2d(dtype=wp.float64)
):
    # Mixed-precision tile GEMM: the fp16 tile of `ga` times the fp32 tile of `gb`
    # is computed into a zero-initialized fp64 tile that is then stored to `gc`.
    tile_i, tile_j = wp.tid()

    lhs = wp.tile_load(ga, tile_i, tile_j, m=TILE_M, n=TILE_K)
    rhs = wp.tile_load(gb, tile_i, tile_j, m=TILE_K, n=TILE_N)

    # output tile starts at zero, so after tile_matmul it holds lhs @ rhs
    acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float64)
    wp.tile_matmul(lhs, rhs, acc)

    wp.tile_store(gc, tile_i, tile_j, acc)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_tile_math_matmul(test, device):
    """Check forward and backward passes of the mixed-precision tile matmul kernel.

    A single tile launch computes C = A @ B with fp16 A, fp32 B and fp64 C. The
    result is compared against NumPy. The gradients obtained by back-propagating
    dL/dC = 1 are compared against the analytic dL/dA = dC @ B^T and dL/dB = A^T @ dC.
    """
    rng = np.random.default_rng(42)

    a_host = rng.random((TILE_M, TILE_K), dtype=np.float64).astype(np.float16)
    b_host = rng.random((TILE_K, TILE_N), dtype=np.float32)
    c_host = np.zeros((TILE_M, TILE_N), dtype=np.float64)

    a_dev = wp.array(a_host, requires_grad=True, device=device)
    b_dev = wp.array(b_host, requires_grad=True, device=device)
    c_dev = wp.array(c_host, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(
            tile_math_matmul_kernel,
            dim=[1, 1],
            inputs=[a_dev, b_dev, c_dev],
            block_dim=TILE_DIM,
            device=device,
        )

    # forward pass
    assert_np_equal(c_dev.numpy(), a_host @ b_host, tol=1e-2)

    # backward pass with an all-ones output adjoint
    c_adj = np.ones_like(c_host)
    tape.backward(grads={c_dev: wp.array(c_adj, device=device)})

    assert_np_equal(a_dev.grad.numpy(), c_adj @ b_host.T, tol=1e-2)
    assert_np_equal(b_dev.grad.numpy(), a_host.T @ c_adj, tol=1e-2)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@wp.kernel()
def tile_math_fft_kernel_vec2f(gx: wp.array2d(dtype=wp.vec2f), gy: wp.array2d(dtype=wp.vec2f)):
    # Single-precision variant: load an FFT_SIZE_FP32 x FFT_SIZE_FP32 tile of complex
    # values stored as vec2f, transform it in place with tile_fft, and store it to gy.
    tile_i, tile_j = wp.tid()
    data = wp.tile_load(gx, tile_i, tile_j, m=FFT_SIZE_FP32, n=FFT_SIZE_FP32)
    wp.tile_fft(data)
    wp.tile_store(gy, tile_i, tile_j, data)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@wp.kernel()
def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dtype=wp.vec2d)):
    # Double-precision variant: load an FFT_SIZE_FP64 x FFT_SIZE_FP64 tile of complex
    # values stored as vec2d, transform it in place with tile_fft, and store it to gy.
    tile_i, tile_j = wp.tid()
    data = wp.tile_load(gx, tile_i, tile_j, m=FFT_SIZE_FP64, n=FFT_SIZE_FP64)
    wp.tile_fft(data)
    wp.tile_store(gy, tile_i, tile_j, data)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def test_tile_math_fft(test, device, wp_dtype):
    """Compare ``wp.tile_fft`` with ``numpy.fft.fft`` along the last axis.

    Args:
        test: The ``unittest.TestCase`` instance (assertions go through ``assert_np_equal``).
        device: Warp device to launch on.
        wp_dtype: ``wp.vec2f`` (complex64 path) or ``wp.vec2d`` (complex128 path).
    """
    np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
    np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]
    kernel = {wp.vec2d: tile_math_fft_kernel_vec2d, wp.vec2f: tile_math_fft_kernel_vec2f}[wp_dtype]
    fft_size = {wp.vec2d: FFT_SIZE_FP64, wp.vec2f: FFT_SIZE_FP32}[wp_dtype]

    rng = np.random.default_rng(42)

    # Warp doesn't have complex types, so each complex number is stored as two
    # consecutive reals (real, imaginary) of np_real_dtype and handed to Warp as a
    # 2-vector: vec2f for the complex64 path, vec2d for the complex128 path.
    X = rng.random((fft_size, 2 * fft_size), dtype=np_real_dtype)
    Y = np.zeros_like(X)

    X_wp = wp.array2d(X, requires_grad=True, dtype=wp_dtype, device=device)
    Y_wp = wp.array2d(Y, requires_grad=True, dtype=wp_dtype, device=device)

    # reinterpret the interleaved reals as complex numbers for the NumPy reference
    X_cplx = X.view(np_cplx_dtype).reshape(fft_size, fft_size)
    Y_cplx = np.fft.fft(X_cplx, axis=-1)

    # recorded on a Tape in preparation for the (not yet implemented) backward check
    with wp.Tape():
        wp.launch_tiled(kernel, dim=[1, 1], inputs=[X_wp, Y_wp], block_dim=TILE_DIM, device=device)

    Y_wp_cplx = Y_wp.numpy().view(np_cplx_dtype).reshape(fft_size, fft_size)

    assert_np_equal(Y_wp_cplx, Y_cplx, tol=1.0e-4)

    # TODO: implement and test backward pass
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# Devices every MathDx test below is registered on (from warp.tests.unittest_utils).
devices = get_cuda_test_devices()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
class TestTileMathDx(unittest.TestCase):
    """Holder for the MathDx tile tests; test methods are attached via add_function_test."""
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# check_output=False so we can enable libmathdx's logging without failing the tests
add_function_test(
    TestTileMathDx,
    "test_tile_math_matmul",
    test_tile_math_matmul,
    devices=devices,
    check_output=False,
)

# one FFT test per supported complex precision
for _fft_test_name, _fft_dtype in (
    ("test_tile_math_fft_vec2f", wp.vec2f),
    ("test_tile_math_fft_vec2d", wp.vec2d),
):
    add_function_test(
        TestTileMathDx,
        _fft_test_name,
        functools.partial(test_tile_math_fft, wp_dtype=_fft_dtype),
        devices=devices,
        check_output=False,
    )
del _fft_test_name, _fft_dtype
|
|
141
|
+
|
|
142
|
+
if __name__ == "__main__":
    # Clear Warp's kernel cache before running (presumably so this file's kernels are rebuilt).
    wp.clear_kernel_cache()
    # Stop at the first failure.
    unittest.main(verbosity=2, failfast=True)
|
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
import warp as wp
|
|
13
|
+
import warp.examples
|
|
14
|
+
import warp.optim
|
|
15
|
+
from warp.tests.unittest_utils import *
|
|
16
|
+
|
|
17
|
+
# Initialize Warp at import time; the skipUnless decorators below query
# wp.context.runtime.core.is_mathdx_enabled() when the module is loaded.
wp.init()

# needs to be constant for the whole module
# (used both as the launch block_dim and as the tile width inside the kernels)
NUM_THREADS = 32
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def create_layer(rng, dim_in, dim_hid, dtype=float):
    """Create randomly initialized parameters for one dense layer.

    Values are drawn with ``rng.uniform`` from [-1/sqrt(dim_in), 1/sqrt(dim_in)).

    Args:
        rng: NumPy ``Generator`` used for sampling.
        dim_in: Number of layer inputs (columns of the weight matrix).
        dim_hid: Number of layer outputs (rows of the weight matrix and bias).
        dtype: Warp dtype of the returned arrays.

    Returns:
        Tuple ``(weights, bias)`` of Warp arrays with shapes ``(dim_hid, dim_in)`` and
        ``(dim_hid, 1)``, both created with ``requires_grad=True``.
    """
    bound = 1.0 / np.sqrt(dim_in)
    weights_np = rng.uniform(-bound, bound, (dim_hid, dim_in))
    bias_np = rng.uniform(-bound, bound, (dim_hid, 1))

    return (
        wp.array(weights_np, dtype=dtype, requires_grad=True),
        wp.array(bias_np, dtype=dtype, requires_grad=True),
    )
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def create_array(rng, dim_in, dim_hid, dtype=float):
    """Create a ``(dim_hid, dim_in)`` Warp array of uniform samples with gradients enabled.

    Samples come from ``rng.uniform`` over [-1/sqrt(dim_in), 1/sqrt(dim_in)).
    """
    bound = 1.0 / np.sqrt(dim_in)
    return wp.array(rng.uniform(-bound, bound, (dim_hid, dim_in)), dtype=dtype, requires_grad=True)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_multi_layer_nn(test, device):
    """Train a 4-layer tile-based MLP on an image and cross-check it against NumPy and Torch.

    Each pixel's (row, col) coordinate is positionally encoded. It is fed through three
    ReLU hidden layers and a ReLU output layer predicting the RGB value stored in
    ``assets/pixel.npy``. The network is trained with Adam on an MSE loss. During the
    first few epochs, every batch's forward output and parameter gradients are compared
    with a NumPy forward pass and a Torch forward/backward pass. At the end, the loss of
    the last batch must fall below a threshold.

    Args:
        test: The ``unittest.TestCase`` instance used for assertions.
        device: Warp device to run on; it must also be addressable from Torch.
    """
    import torch as tc

    NUM_FREQ = wp.constant(8)

    DIM_IN = wp.constant(4 * NUM_FREQ)  # sin,cos for both x,y at each frequency
    DIM_HID = 32
    DIM_OUT = 3

    IMG_WIDTH = 256
    IMG_HEIGHT = 256

    BATCH_SIZE = min(512, int((IMG_WIDTH * IMG_HEIGHT) / 8))

    dtype = wp.float16

    @wp.func
    def relu(x: dtype):
        return wp.max(x, dtype(0.0))

    @wp.kernel
    def compute(
        batches: wp.array(dtype=int),
        input: wp.array2d(dtype=dtype),
        weights_0: wp.array2d(dtype=dtype),
        bias_0: wp.array2d(dtype=dtype),
        weights_1: wp.array2d(dtype=dtype),
        bias_1: wp.array2d(dtype=dtype),
        weights_2: wp.array2d(dtype=dtype),
        bias_2: wp.array2d(dtype=dtype),
        weights_3: wp.array2d(dtype=dtype),
        bias_3: wp.array2d(dtype=dtype),
        reference: wp.array2d(dtype=float),
        loss: wp.array1d(dtype=float),
        out: wp.array2d(dtype=float),
    ):
        # each thread handles one (shuffled) pixel index
        linear = batches[wp.tid()]
        row = linear / IMG_WIDTH
        col = linear % IMG_WIDTH

        # normalize input coordinates to [-1, 1]; rows span the image height,
        # columns span its width
        x = (float(row) / float(IMG_HEIGHT) - 0.5) * 2.0
        y = (float(col) / float(IMG_WIDTH) - 0.5) * 2.0

        local = wp.vector(dtype=dtype, length=DIM_IN)

        # construct positional encoding
        for s in range(NUM_FREQ):
            scale = wp.pow(2.0, float(s)) * wp.pi

            # x-coord
            local[s * 4 + 0] = dtype(wp.sin(x * scale))
            local[s * 4 + 1] = dtype(wp.cos(x * scale))

            # y-coord
            local[s * 4 + 2] = dtype(wp.sin(y * scale))
            local[s * 4 + 3] = dtype(wp.cos(y * scale))

            # write input back to array so that torch can use it
            input[s * 4 + 0, linear] = local[s * 4 + 0]
            input[s * 4 + 1, linear] = local[s * 4 + 1]
            input[s * 4 + 2, linear] = local[s * 4 + 2]
            input[s * 4 + 3, linear] = local[s * 4 + 3]

        # tile feature vectors across the block, returns [dim(f), NUM_THREADS]
        f = wp.tile(local)

        # input layer
        w0 = wp.tile_load(weights_0, 0, 0, m=DIM_HID, n=DIM_IN)
        b0 = wp.tile_load(bias_0, 0, 0, m=DIM_HID, n=1)
        z = wp.tile_map(relu, wp.tile_matmul(w0, f) + wp.tile_broadcast(b0, m=DIM_HID, n=NUM_THREADS))

        # hidden layers
        w1 = wp.tile_load(weights_1, 0, 0, m=DIM_HID, n=DIM_HID)
        b1 = wp.tile_load(bias_1, 0, 0, m=DIM_HID, n=1)
        z = wp.tile_map(relu, wp.tile_matmul(w1, z) + wp.tile_broadcast(b1, m=DIM_HID, n=NUM_THREADS))

        w2 = wp.tile_load(weights_2, 0, 0, m=DIM_HID, n=DIM_HID)
        b2 = wp.tile_load(bias_2, 0, 0, m=DIM_HID, n=1)
        z = wp.tile_map(relu, wp.tile_matmul(w2, z) + wp.tile_broadcast(b2, m=DIM_HID, n=NUM_THREADS))

        # output layer
        w3 = wp.tile_load(weights_3, 0, 0, m=DIM_OUT, n=DIM_HID)
        b3 = wp.tile_load(bias_3, 0, 0, m=DIM_OUT, n=1)
        o = wp.tile_map(relu, wp.tile_matmul(w3, z) + wp.tile_broadcast(b3, m=DIM_OUT, n=NUM_THREADS))

        # untile back to SIMT
        output = wp.untile(o)

        # compute error
        error = wp.vec3(
            float(output[0]) - reference[0, linear],
            float(output[1]) - reference[1, linear],
            float(output[2]) - reference[2, linear],
        )

        # write MSE loss
        wp.atomic_add(loss, 0, wp.length_sq(error) / float(3 * BATCH_SIZE))

        # image output
        for i in range(DIM_OUT):
            out[i, linear] = float(output[i])

    with wp.ScopedDevice(device):
        torch_device = wp.device_to_torch(device)

        rng = np.random.default_rng(45)

        weights_0, bias_0 = create_layer(rng, DIM_IN, DIM_HID, dtype=dtype)
        weights_1, bias_1 = create_layer(rng, DIM_HID, DIM_HID, dtype=dtype)
        weights_2, bias_2 = create_layer(rng, DIM_HID, DIM_HID, dtype=dtype)
        weights_3, bias_3 = create_layer(rng, DIM_HID, DIM_OUT, dtype=dtype)

        input = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_IN, dtype=dtype)
        output = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_OUT)

        # NOTE(review): allow_pickle=True is only needed if pixel.npy holds an object array;
        # this is a bundled test asset, but confirm and drop the flag if it is a plain array.
        reference_np = np.load(os.path.join(os.path.dirname(__file__), "assets/pixel.npy"), allow_pickle=True) / 255.0
        reference = wp.array(reference_np, dtype=float)

        test.assertEqual(reference.shape[1], IMG_WIDTH * IMG_HEIGHT)

        loss = wp.zeros(1, dtype=float, requires_grad=True)

        params = [weights_0, bias_0, weights_1, bias_1, weights_2, bias_2, weights_3, bias_3]

        # Adam operates on flattened views of the parameters and their gradients
        optimizer_grads = [p.grad.flatten() for p in params]
        optimizer_inputs = [p.flatten() for p in params]
        optimizer = warp.optim.Adam(optimizer_inputs, lr=0.01)

        max_epochs = 30
        # check outputs + grads on the first few epochs only, since this is relatively slow
        verify_epochs = 3

        # create randomized batch indices
        batches = np.arange(0, IMG_WIDTH * IMG_HEIGHT, dtype=np.int32)
        rng.shuffle(batches)
        batches = wp.array(batches)

        with wp.ScopedTimer("Training", active=False):
            for epoch in range(max_epochs):
                for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
                    loss.zero_()

                    with wp.Tape() as tape:
                        wp.launch(
                            compute,
                            dim=[BATCH_SIZE],
                            inputs=[
                                batches[b : b + BATCH_SIZE],
                                input,
                                weights_0,
                                bias_0,
                                weights_1,
                                bias_1,
                                weights_2,
                                bias_2,
                                weights_3,
                                bias_3,
                                reference,
                                loss,
                                output,
                            ],
                            block_dim=NUM_THREADS,
                        )

                    tape.backward(loss)

                    if epoch < verify_epochs:
                        indices = batches[b : b + BATCH_SIZE].numpy()

                        z_np = np.maximum(weights_0.numpy() @ input.numpy()[:, indices] + bias_0.numpy(), 0.0)
                        z_np = np.maximum(weights_1.numpy() @ z_np + bias_1.numpy(), 0.0)
                        z_np = np.maximum(weights_2.numpy() @ z_np + bias_2.numpy(), 0.0)
                        z_np = np.maximum(weights_3.numpy() @ z_np + bias_3.numpy(), 0.0)

                        # test numpy forward
                        assert_np_equal(output.numpy()[:, indices], z_np, tol=1.0e-2)

                        # torch
                        input_tc = tc.tensor(input.numpy()[:, indices], requires_grad=True, device=torch_device)

                        weights_0_tc = tc.tensor(weights_0.numpy(), requires_grad=True, device=torch_device)
                        bias_0_tc = tc.tensor(bias_0.numpy(), requires_grad=True, device=torch_device)

                        weights_1_tc = tc.tensor(weights_1.numpy(), requires_grad=True, device=torch_device)
                        bias_1_tc = tc.tensor(bias_1.numpy(), requires_grad=True, device=torch_device)

                        weights_2_tc = tc.tensor(weights_2.numpy(), requires_grad=True, device=torch_device)
                        bias_2_tc = tc.tensor(bias_2.numpy(), requires_grad=True, device=torch_device)

                        weights_3_tc = tc.tensor(weights_3.numpy(), requires_grad=True, device=torch_device)
                        bias_3_tc = tc.tensor(bias_3.numpy(), requires_grad=True, device=torch_device)

                        z_tc = tc.clamp(weights_0_tc @ input_tc + bias_0_tc, min=0.0)
                        z_tc = tc.clamp(weights_1_tc @ z_tc + bias_1_tc, min=0.0)
                        z_tc = tc.clamp(weights_2_tc @ z_tc + bias_2_tc, min=0.0)
                        z_tc = tc.clamp(weights_3_tc @ z_tc + bias_3_tc, min=0.0)

                        # the reference is a target, not a parameter, so it needs no gradient
                        ref_tc = tc.tensor(reference.numpy()[:, indices], device=torch_device)

                        l_tc = tc.mean((z_tc - ref_tc) ** 2)
                        l_tc.backward()

                        # test torch
                        assert_np_equal(z_tc.cpu().detach().numpy(), output.numpy()[:, indices], tol=1.0e-2)
                        assert_np_equal(weights_0.grad.numpy(), weights_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
                        assert_np_equal(bias_0.grad.numpy(), bias_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
                        assert_np_equal(weights_1.grad.numpy(), weights_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
                        assert_np_equal(bias_1.grad.numpy(), bias_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
                        assert_np_equal(weights_2.grad.numpy(), weights_2_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
                        assert_np_equal(bias_2.grad.numpy(), bias_2_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
                        assert_np_equal(weights_3.grad.numpy(), weights_3_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
                        assert_np_equal(bias_3.grad.numpy(), bias_3_tc.grad.cpu().detach().numpy(), tol=1.0e-2)

                    optimizer.step(optimizer_grads)
                    tape.zero()

        # initial loss is ~0.061; this checks the loss of the final batch
        test.assertLess(loss.numpy()[0], 0.002)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
@unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
def test_single_layer_nn(test, device):
    """Check one tile-based dense ReLU layer against NumPy (forward) and Torch (forward + input grad).

    The kernel is launched with ``NUM_BLOCKS`` tiles. Each tile processes
    ``NUM_THREADS`` input columns, computing ``relu(W @ x + b)``. Back-propagation
    uses an all-ones output adjoint.

    Args:
        test: The ``unittest.TestCase`` instance.
        device: Warp device to run on.
    """
    import torch as tc

    DIM_IN = 8
    DIM_OUT = 16

    NUM_BLOCKS = 56

    @wp.func
    def relu(x: float):
        return wp.max(x, 0.0)

    @wp.kernel
    def compute(
        input: wp.array2d(dtype=float),
        weights: wp.array2d(dtype=float),
        bias: wp.array2d(dtype=float),
        out: wp.array2d(dtype=float),
    ):
        i = wp.tid()

        # the i-th DIM_IN x NUM_THREADS tile of the input columns
        f = wp.tile_load(input, 0, i, m=DIM_IN, n=NUM_THREADS)

        w = wp.tile_load(weights, 0, 0, m=DIM_OUT, n=DIM_IN)
        b = wp.tile_load(bias, 0, 0, m=DIM_OUT, n=1)

        o = wp.tile_map(relu, wp.tile_matmul(w, f) + wp.tile_broadcast(b, m=DIM_OUT, n=NUM_THREADS))

        wp.tile_store(out, 0, i, o)

    with wp.ScopedDevice(device):
        rng = np.random.default_rng(45)

        # single layer weights, bias
        weights, bias = create_layer(rng, DIM_IN, DIM_OUT, dtype=float)

        input = create_array(rng, NUM_THREADS * NUM_BLOCKS, DIM_IN)
        output = create_array(rng, NUM_THREADS * NUM_BLOCKS, DIM_OUT)

        with wp.Tape() as tape:
            wp.launch_tiled(compute, dim=[NUM_BLOCKS], inputs=[input, weights, bias, output], block_dim=NUM_THREADS)

        output.grad = wp.ones_like(output)
        tape.backward()

        # numpy
        output_np = np.maximum(weights.numpy() @ input.numpy() + bias.numpy(), 0.0)

        # test numpy forward
        assert_np_equal(output.numpy(), output_np, tol=1.0e-2)

        # torch
        weights_tc = tc.from_numpy(weights.numpy()).requires_grad_(True)  # use .numpy() to avoid any memory aliasing
        input_tc = tc.from_numpy(input.numpy()).requires_grad_(True)
        bias_tc = tc.from_numpy(bias.numpy()).requires_grad_(True)

        output_tc = tc.clamp(weights_tc @ input_tc + bias_tc, min=0.0)
        output_tc.backward(tc.ones_like(output_tc))

        # test torch
        assert_np_equal(output_tc.detach().numpy(), output.numpy(), tol=1.0e-2)
        assert_np_equal(input.grad.numpy(), input_tc.grad.detach().numpy(), tol=1.0e-2)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class TestTileMLP(unittest.TestCase):
    """Holder for the tile MLP tests; methods are attached via add_function_test."""
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
test_devices = get_test_devices()

try:
    import torch

    # Probe each Warp device with a tiny Torch op; CUDA devices may fail here
    # if Torch was not compiled with CUDA support.
    torch_compatible_devices = []
    torch_compatible_cuda_devices = []

    for d in test_devices:
        try:
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
        except Exception as e:
            print(f"Skipping Torch tests on device '{d}' due to exception: {e}")
        else:
            torch_compatible_devices.append(d)
            if d.is_cuda:
                torch_compatible_cuda_devices.append(d)

    # both MLP tests are registered only on Torch-compatible CUDA devices
    for _mlp_test_name, _mlp_test_func in (
        ("test_single_layer_nn", test_single_layer_nn),
        ("test_multi_layer_nn", test_multi_layer_nn),
    ):
        add_function_test(
            TestTileMLP,
            _mlp_test_name,
            _mlp_test_func,
            check_output=False,
            devices=torch_compatible_cuda_devices,
        )
    del _mlp_test_name, _mlp_test_func

except Exception as e:
    print(f"Skipping Torch tests due to exception: {e}")
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
if __name__ == "__main__":
    # Clear Warp's kernel cache before running (presumably so this file's kernels are rebuilt).
    wp.clear_kernel_cache()
    # Stop at the first failure.
    unittest.main(verbosity=2, failfast=True)
|