warp-lang 1.5.0-py3-none-manylinux2014_x86_64.whl → 1.6.0-py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1124 -497
- warp/codegen.py +261 -136
- warp/config.py +1 -1
- warp/context.py +357 -119
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_torch.py +18 -34
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth.py +3 -1
- warp/examples/sim/example_cloth_self_contact.py +260 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +180 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/fem/geometry/geometry.py +0 -2
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +5 -1
- warp/native/coloring.cpp +5 -1
- warp/native/cuda_util.cpp +91 -53
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +41 -0
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1187 -669
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +130 -64
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +270 -26
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +154 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +134 -72
- warp/sparse.py +1 -1
- warp/stubs.py +265 -132
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +74 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +2 -2
- warp/tests/test_coloring.py +12 -2
- warp/tests/test_examples.py +12 -1
- warp/tests/test_func.py +21 -4
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat.py +138 -167
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +17 -16
- warp/tests/test_matmul_lite.py +10 -15
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +47 -2
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_static.py +19 -3
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +178 -191
- warp/tests/test_tile_load.py +356 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -13
- warp/thirdparty/unittest_parallel.py +2 -2
- warp/types.py +411 -101
- warp/utils.py +10 -7
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/examples/benchmarks/benchmark_tile.py
DELETED
@@ -1,179 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch

import warp as wp

wp.init()
wp.set_module_options({"enable_backward": False, "fast_math": True})
wp.set_device("cuda:0")

wp.build.clear_kernel_cache()


@wp.kernel
def gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
    # output index
    i, j = wp.tid()

    sum = float(0.0)

    for k in range(0, A.shape[1]):
        sum += A[i, k] * B[k, j]

    C[i, j] = sum


TILE_M = wp.constant(64)
TILE_N = wp.constant(64)
TILE_K = wp.constant(8)


@wp.kernel
def gemm_tiled(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
    # output tile index
    i, j = wp.tid()

    sum = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)

    _M = A.shape[0]
    _N = B.shape[1]
    K = A.shape[1]

    count = int(K / 8)  # TODO: code-gen bug if you use a constant before passing it to a kwd arg (in this case TILE_K)

    for k in range(count):
        a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
        b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)

        # sum += a*b
        wp.tile_matmul(a, b, sum)

    wp.tile_store(C, i, j, sum)


def benchmark_numpy(A, B, C):
    timers = {}
    iters = 10

    # warm up
    for _i in range(10):
        _C = A @ B

    with wp.ScopedTimer("NumPy", dict=timers):
        for _i in range(iters):
            _C = A @ B

    return min(timers["NumPy"])


def benchmark_warp_simt(A, B, C):
    timers = {}
    iters = 10

    A_wp = wp.array(A)
    B_wp = wp.array(B)
    C_wp = wp.array(C)

    # warm up
    for _i in range(10):
        wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp])

    with wp.ScopedTimer("Warp (SIMT)", dict=timers, print=False, synchronize=True):
        for _i in range(iters):
            wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp])

    return min(timers["Warp (SIMT)"])


def benchmark_warp_tiled(A, B, C):
    timers = {}
    iters = 10

    # must match with the tile_matmul() partition size
    SUB_TILE_M = 4
    SUB_TILE_N = 4

    num_threads = int(TILE_M / SUB_TILE_M) * int(TILE_N / SUB_TILE_N)
    A_wp = wp.array(A)
    B_wp = wp.array(B)
    C_wp = wp.array(C)

    # warm up
    wp.capture_begin()

    for _i in range(iters):
        wp.launch(gemm_tiled, dim=(int(M / TILE_M), int(N / TILE_N)), inputs=[A_wp, B_wp, C_wp], tile_size=num_threads)

    graph = wp.capture_end()

    with wp.ScopedTimer("Warp (Tiled)", dict=timers, print=False, synchronize=True):
        # for i in range(iters):
        #     wp.launch(gemm_tiled, dim=(int(M/TILE_M), int(N/TILE_N)), inputs=[A_wp, B_wp, C_wp], tile_size=num_threads)
        wp.capture_launch(graph)

    return min(timers["Warp (Tiled)"])


def benchmark_torch(A, B, C):
    A_tc = torch.from_numpy(A).to("cuda:0")
    B_tc = torch.from_numpy(B).to("cuda:0")
    C_tc = torch.from_numpy(C).to("cuda:0")

    # warm-up
    for _i in range(10):
        torch.matmul(A_tc, B_tc, out=C_tc)

    timers = {}
    iters = 10

    torch.cuda.synchronize()

    with wp.ScopedTimer("Torch", dict=timers, print=False):
        for _i in range(iters):
            torch.matmul(A_tc, B_tc)  # , out=C_tc)

    torch.cuda.synchronize()

    return min(timers["Torch"])


results_torch = []
results_warp_simt = []
results_warp_tiled = []

print("{:>8s} {:>8s} {:>8s} {:>8s} {:>8s} {:>8s}".format("M", "N", "K", "Torch", "Warp (SIMT)", "Warp (Tiled)"))
print("--------------------------------------------------------")

for i in range(2, 33):
    # for i in range(8, 9):

    M = i * 128
    N = M
    K = N

    # M = TILE_M*21
    # K = TILE_K*7
    # N = TILE_M*12

    rng = np.random.default_rng(42)

    A = rng.random((M, K), dtype=np.float32)
    B = rng.random((K, N), dtype=np.float32)
    C = np.zeros((M, N), dtype=np.float32)

    results_torch.append(benchmark_torch(A, B, C))
    results_warp_simt.append(0.0)  # benchmark_warp_simt(A, B, C))
    results_warp_tiled.append(benchmark_warp_tiled(A, B, C))

    print(
        "{:>8d} {:>8d} {:>8d} {:>8f} {:>8f} {:>8f}".format(
            M, N, K, results_torch[-1], results_warp_simt[-1], results_warp_tiled[-1]
        )
    )
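For context, a minimal correctness check is sketched below; it is not part of the deleted benchmark, and the check_gemm helper name, shapes, and tolerance are illustrative only. It reuses the SIMT gemm kernel defined above and only standard Warp/NumPy calls (wp.array, wp.launch, wp.synchronize, np.allclose) to confirm the kernel matches a NumPy reference before any timings are trusted.

def check_gemm(M=256, N=256, K=64):
    # illustrative sanity check, not part of the original benchmark_tile.py
    rng = np.random.default_rng(0)
    A = rng.random((M, K), dtype=np.float32)
    B = rng.random((K, N), dtype=np.float32)
    C = np.zeros((M, N), dtype=np.float32)

    # upload inputs and run one thread per output element
    A_wp, B_wp, C_wp = wp.array(A), wp.array(B), wp.array(C)
    wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp])
    wp.synchronize()

    # loose tolerance to allow for float32 accumulation-order differences
    assert np.allclose(C_wp.numpy(), A @ B, atol=1e-3)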
warp/native/tile_gemm.h
DELETED
@@ -1,341 +0,0 @@
/** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
 * NVIDIA CORPORATION and its licensors retain all intellectual property
 * and proprietary rights in and to this software, related documentation
 * and any modifications thereto. Any use, reproduction, disclosure or
 * distribution of this software and related documentation without an express
 * license agreement from NVIDIA CORPORATION is strictly prohibited.
 */

#pragma once

#include "builtin.h"

#define USE_CUTE 0

#if USE_CUTE
#include "cutlass/include/cute/tensor.hpp"
#include "cutlass/include/cute/algorithm/cooperative_gemm.hpp"
#endif // USE_CUTE

namespace wp
{

/*
// 2D tile zero
template <typename T, int M, int N, int Index>
inline CUDA_CALLABLE array_t<T> tile_zeros()
{
    const int length = M*N;

    WP_TILE_SHARED __align__(16) T data[length];

    WP_PRAGMA_UNROLL
    for (int t=threadIdx.x; t < length; t += blockDim.x)
    {
        data[t] = T(0.0);
    }

    return array_t<T>(data, M, N, nullptr);
}

// 2D tile load
template <typename T, int M, int N, int Index>
inline CUDA_CALLABLE array_t<T> tile_load(const array_t<T>& src, int i, int j)
{
    const int length = M*N;

    WP_TILE_SHARED __align__(16) T data[length];

    //---------------
    // naive-synchronous load
    //
    // WP_PRAGMA_UNROLL
    // for (int t=threadIdx.x; t < length; t += blockDim.x)
    // {
    //     data[t] = index(src, i*M + t/N, j*N + t%N);
    // }

    //---------------
    // async 128 bit loads (assumes row-major i.e.: stride 1 on y axis and 4-element alignment on dimension)
    const int s = 4;

    WP_PRAGMA_UNROLL
    for (int t=threadIdx.x*s; t < length; t += blockDim.x*s)
    {
        __pipeline_memcpy_async(&data[t],
                                &index(src, i*M + t/N, j*N + t%N),
                                sizeof(T)*s);
    }

    __pipeline_commit();

    return array_t<T>(data, M, N, nullptr);
}

// 2D tile store
template <typename T>
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int i, int j, const array_t<T>& src)
{
    const int M = src.shape[0];
    const int N = src.shape[1];

    const int length = M*N;

    // cooperatively store the tile, using a block-stride iterator
    WP_PRAGMA_UNROLL
    for (int t=threadIdx.x; t < length; t += blockDim.x)
    {
        index(dest, i*M + t/N, j*N + t%N) = src.data[t];
    }
}
*/

template <typename T>
inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride)
{
    return p[i*stride + j];
}

template <typename T>
inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride)
{
    return p[i*stride + j];
}

template <unsigned M, unsigned N, typename T>
struct partition_t
{
    inline partition_t(array_t<T> A)
    {
        data = A;

        // todo: do ceil div for non-multiples of M,N
        shape[0] = A.shape[0]/M;
        shape[1] = A.shape[1]/N;
    }

    // underlying data
    array_t<T> data;

    // partition dimensions
    int shape[2];
};

template <unsigned M, unsigned N, typename T>
inline int partition_size(const partition_t<M, N, T>& tile)
{
    return tile.shape[0]*tile.shape[1];
}

// returns the x, y coordinates of a tile given a linear index
template <unsigned M, unsigned N, typename T>
inline void partition_coord(const partition_t<M, N, T>& tile, const int t, int& i, int& j)
{
    i = t/tile.shape[1];
    j = t%tile.shape[1];
}

template <unsigned M, unsigned N, typename T>
inline mat_t<M, N, T> partition_load(const partition_t<M, N, T>& tile, int i, int j)
{
    mat_t<M, N, T> out;

    const int tile_i = i*M;
    const int tile_j = j*N;

    WP_PRAGMA_UNROLL
    for (int i=0; i < M; ++i)
    {
        WP_PRAGMA_UNROLL
        for (int j=0; j < N; ++j)
        {
            out.data[i][j] = index(tile.data, tile_i + i, tile_j + j);
        }
    }

    return out;
}

template <unsigned M, unsigned N, typename T>
inline void partition_store(const partition_t<M, N, T>& tile, int i, int j, const mat_t<M, N, T>& value)
{
    mat_t<M, N, T> out;

    const int tile_i = M*i;
    const int tile_j = N*j;

    WP_PRAGMA_UNROLL
    for (int i=0; i < M; ++i)
    {
        WP_PRAGMA_UNROLL
        for (int j=0; j < N; ++j)
        {
            index(tile.data, tile_i + i, tile_j + j) = value.data[i][j];
        }
    }
}


#if !USE_CUTE

template <typename T>
inline CUDA_CALLABLE void gemm(const array_t<T>& A, const array_t<T>& B, const array_t<T>& out)
{
    const int TILE_M = 4;
    const int TILE_N = 4;
    const int TILE_K = 4;

    partition_t A_tile = partition_t<TILE_M, TILE_K, T>(A);
    partition_t B_tile = partition_t<TILE_K, TILE_N, T>(B);
    partition_t C_tile = partition_t<TILE_M, TILE_N, T>(out);

    const int length = partition_size(C_tile);

    __pipeline_wait_prior(0);

    WP_TILE_SYNC();

    for (int t=threadIdx.x; t < length; t += blockDim.x)
    {
        int i, j;
        partition_coord(C_tile, t, i, j);

        // accumulator
        mat_t<TILE_M, TILE_N, T> sum = partition_load(C_tile, i, j);

        WP_PRAGMA_UNROLL
        for (int k=0; k < A_tile.shape[1]; k++)
        {
            const mat_t<TILE_M, TILE_K, T> a = partition_load(A_tile, i, k);
            const mat_t<TILE_K, TILE_N, T> b = partition_load(B_tile, k, j);

            sum += mul(a, b);
        }

        partition_store(C_tile, i, j, sum);
    }

    WP_TILE_SYNC();
}


// 2D gemm accumulate out += A*B
template <typename TileA, typename TileB, typename TileC>
inline CUDA_CALLABLE void tile_matmul_scalar(const TileA& A,
                                             const TileB& B,
                                             TileC& out)
{
    const int length = tile_size(out);

    WP_TILE_SYNC();

    using T = typename TileA::Type;

    WP_PRAGMA_UNROLL
    for (int t=threadIdx.x; t < length; t += WP_TILE_BLOCK_DIM)
    {
        // compute output index
        const int i = t/out.N;
        const int j = t%out.N;

        T sum(0.0);

        WP_PRAGMA_UNROLL
        for (int k=0; k < A.N; ++k)
        {
            T a = A(i,k);
            T b = B(k,j);

            sum += a*b; // todo: use fmaf()
        }

        out(i,j) += sum;
    }

    WP_TILE_SYNC();
}

#else

template <typename T>
inline CUDA_CALLABLE void tile_matmul(const array_t<T>& A, const array_t<T>& B, const array_t<T>& out)
{
    using namespace cute;

    __pipeline_wait_prior(0);

    // ensure smem tile is ready
    WP_TILE_SYNC();

    // Define CTA matrix size (static)
    auto bM = Int<64>{};
    auto bN = Int<64>{};
    auto bK = Int<8>{};

    // Define the smem layouts (static)
    auto sA = make_layout(make_shape(bM, bK), LayoutRight{});
    auto sB = make_layout(make_shape(bN, bK));
    auto sC = make_layout(make_shape(bM, bN), LayoutRight{});

    Tensor s_a_tensor = make_tensor(make_smem_ptr<float>(A.data), sA);
    Tensor s_b_tensor = make_tensor(make_smem_ptr<float>(B.data), sB);
    Tensor s_c_tensor = make_tensor(make_smem_ptr<float>(out.data), sC);

    // TiledMMA tiled_mma = make_tiled_mma(UniversalFMA<float,float,float>{},
    //                                     Layout<Shape<_16,_8,_1>>{});  // 16x8x1 UniversalFMA, assumes blockDim=128

    // TiledMMA tiled_mma = make_tiled_mma(UniversalFMA<float,float,float>{},
    //                                     Layout<Shape<_8,_16>,Stride<_16,_1>>{});  // 8x16x1 UniversalFMA, assumes blockDim=128

    TiledMMA tiled_mma = make_tiled_mma(UniversalFMA<float,float,float>{},
                                        Layout<Shape<_2,_64>,Stride<_64,_1>>{});  // 8x16x1 UniversalFMA, assumes blockDim=128

    cooperative_gemm<AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<float>>,
                     AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<float>>,
                     AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<float>>
                    >(
        threadIdx.x, tiled_mma,
        1.0f, s_a_tensor, s_b_tensor, 1.0f, s_c_tensor,
        cute::identity(), cute::identity(), cute::identity(), cute::identity()
    );

    WP_TILE_SYNC();
}

#endif // USE_CUTE


#if 0

template <typename TileA, typename TileB, typename TileC>
void tile_matmul(TileA& a, TileB& b, TileC& c)
{
    static_assert(wp::is_same<typename TileA::Type, typename TileB::Type>::value, "Error, tile datatypes must match");
    static_assert(TileA::N == TileB::M, "Error, inner dimensions must match");
    static_assert(TileC::M == TileA::M, "Error, first output dimension must match");
    static_assert(TileC::N == TileB::N, "Error, second output dimension must match");

    tile_matmul_scalar(a, b, c);
}


template <typename TileA, typename TileB, typename TileC,
          typename AdjTileA, typename AdjTileB, typename AdjTileC>
void adj_tile_matmul(TileA& a, TileB& b, TileC& c,
                     AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c)
{
    tile_matmul_scalar(adj_c, wp::tile_transpose(b), adj_a);
    tile_matmul_scalar(wp::tile_transpose(a), adj_c, adj_b);
}

#endif // 0

} // namespace wp