warp-lang 1.4.1-py3-none-manylinux2014_aarch64.whl → 1.5.0-py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1920 -111
- warp/codegen.py +186 -62
- warp/config.py +2 -2
- warp/context.py +322 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/core/example_dem.py +2 -1
- warp/examples/core/example_mesh_intersect.py +3 -3
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/optim/example_walker.py +2 -2
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_jacobian_ik.py +6 -2
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +55 -40
- warp/native/builtin.h +124 -43
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +90 -17
- warp/stubs.py +651 -85
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +207 -48
- warp/tests/test_closest_point_edge_edge.py +8 -8
- warp/tests/test_codegen.py +120 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fabricarray.py +33 -0
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +48 -1
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +5 -4
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +191 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +23 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +339 -73
- warp/utils.py +22 -1
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
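Version 1.5.0 also introduces a tile programming model: the new warp/native/tile.h header shown below, the warp/examples/tile/ examples, and the test_tile*.py tests listed above. As orientation for the diff that follows, here is a minimal sketch of how the tile builtins are typically used from Python, modeled on warp/examples/tile/example_tile_matmul.py; the exact signatures (wp.tile_load, wp.tile_matmul, wp.tile_store, wp.launch_tiled, the block_dim argument) are assumptions based on the 1.5.0 tile API and are not verified against this wheel:

```python
import numpy as np
import warp as wp

# Tile sizes and the number of threads cooperating on each tile
# (illustrative values; any sizes that evenly divide the arrays work).
TILE_M, TILE_N, TILE_K = 8, 4, 8
TILE_THREADS = 64


@wp.kernel
def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
    # each block computes one TILE_M x TILE_N output tile of C
    i, j = wp.tid()
    acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)
    for k in range(int(A.shape[1] / TILE_K)):
        a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)  # cooperative load of one input tile
        b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)
        wp.tile_matmul(a, b, acc)                      # acc += a @ b
    wp.tile_store(C, i, j, acc)


wp.init()
M, N, K = 64, 32, 64
A = wp.array(np.random.rand(M, K), dtype=float)
B = wp.array(np.random.rand(K, N), dtype=float)
C = wp.zeros((M, N), dtype=float)

# one launch dimension per output tile, TILE_THREADS threads per tile
wp.launch_tiled(tile_gemm, dim=[M // TILE_M, N // TILE_N], inputs=[A, B, C], block_dim=TILE_THREADS)
print(np.allclose(C.numpy(), A.numpy() @ B.numpy()))
```

Per the notes and TODO list in tile.h, cooperative loads and stores synchronize the block, and kernels launched through the regular wp.launch() without a block_dim continue to execute in the existing SIMT mode.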
warp/native/tile.h
ADDED
|
@@ -0,0 +1,1857 @@
|
|
|
1
|
+
/** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
* NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
* and proprietary rights in and to this software, related documentation
|
|
4
|
+
* and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
* distribution of this software and related documentation without an express
|
|
6
|
+
* license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
#include "builtin.h"
|
|
12
|
+
|
|
13
|
+
#if !defined(__CUDA_ARCH__)
|
|
14
|
+
#define WP_TILE_SHARED static
|
|
15
|
+
#define WP_TILE_SYNC void
|
|
16
|
+
#else
|
|
17
|
+
#define WP_TILE_SHARED __shared__
|
|
18
|
+
#define WP_TILE_SYNC __syncthreads
|
|
19
|
+
#endif
|
|
20
|
+
|
|
21
|
+
#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
|
|
22
|
+
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
|
|
23
|
+
#define WP_PRAGMA_UNROLL _Pragma("unroll")
|
|
24
|
+
#define WP_PRAGMA_NO_UNROLL _Pragma("unroll 1")
|
|
25
|
+
#else
|
|
26
|
+
#define WP_PRAGMA_UNROLL #pragma unroll
|
|
27
|
+
#define WP_PRAGMA_NO_UNROLL #pragma unroll 1
|
|
28
|
+
#endif
|
|
29
|
+
|
|
30
|
+
#else
|
|
31
|
+
|
|
32
|
+
#define WP_PRAGMA_UNROLL
|
|
33
|
+
#define WP_PRAGMA_NO_UNROLL
|
|
34
|
+
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
#define WP_USE_ASYNC_PIPELINE 0
|
|
38
|
+
#if WP_USE_ASYNC_PIPELINE
|
|
39
|
+
#include "cuda_pipeline_primitives.h"
|
|
40
|
+
#endif // WP_USE_ASYNC_PIPELINE
|
|
41
|
+
|
|
42
|
+
#define WP_USE_REGISTER_GEMM 0
|
|
43
|
+
|
|
44
|
+
/* Tile Expressions
|
|
45
|
+
|
|
46
|
+
[ ] Tiles
|
|
47
|
+
[x] Register, Shared, Global
|
|
48
|
+
[ ] Layouts
|
|
49
|
+
[x] Simple
|
|
50
|
+
[ ] Cute
|
|
51
|
+
[x] Remove Alloc type from tile_shared_t
|
|
52
|
+
[x] wp.launch_tiled() helper
|
|
53
|
+
[ ] Creation
|
|
54
|
+
[x] zeros
|
|
55
|
+
[x] ones
|
|
56
|
+
[x] arange
|
|
57
|
+
[x] tile()
|
|
58
|
+
[x] untile()
|
|
59
|
+
[ ] fromfunction()
|
|
60
|
+
[ ] explicit storage
|
|
61
|
+
[ ] Load/Store
|
|
62
|
+
[ ] 1D load/store variants
|
|
63
|
+
[ ] max_coord option for non-aligned loads
|
|
64
|
+
[ ] Indexed load
|
|
65
|
+
[x] wp.tile_atomic_add()
|
|
66
|
+
[ ] Maps
|
|
67
|
+
[x] Support user functions
|
|
68
|
+
[x] Support built-in functions
|
|
69
|
+
[ ] Support for lambda functions
|
|
70
|
+
[ ] Infer tile_map() output from operator type (e.g.: dot for each element)
|
|
71
|
+
[ ] Reductions
|
|
72
|
+
[x] Sum
|
|
73
|
+
[x] Forward
|
|
74
|
+
[x] Reverse
|
|
75
|
+
[x] Min
|
|
76
|
+
[x] Max
|
|
77
|
+
[x] Custom
|
|
78
|
+
[x] MatMul
|
|
79
|
+
[x] Forward
|
|
80
|
+
[x] Reverse
|
|
81
|
+
[ ] Operators
|
|
82
|
+
[ ] +, -, *, /, @?
|
|
83
|
+
[ ] += for matmul, e.g.: c += a@b, or c = a@b
|
|
84
|
+
[ ] Reshape
|
|
85
|
+
[ ] Broadcasting
|
|
86
|
+
[ ] Transpose
|
|
87
|
+
[x] Shared
|
|
88
|
+
[ ] Register
|
|
89
|
+
[ ] Slice
|
|
90
|
+
[ ] Runtime
|
|
91
|
+
[x] Compile-time block dimensions
|
|
92
|
+
[x] Switch between SIMT / Tile based execution if `block_dim` not provided to wp.launch()
|
|
93
|
+
[ ] Examples
|
|
94
|
+
[ ] Point registration
|
|
95
|
+
[ ] GEMM
|
|
96
|
+
[ ] MLP
|
|
97
|
+
[ ] LayerNorm
|
|
98
|
+
[ ] SoftMax
|
|
99
|
+
[ ] GEMM
|
|
100
|
+
[ ] warp.sim (CRBA)
|
|
101
|
+
[ ] Batched MLP
|
|
102
|
+
[ ] Layer norm
|
|
103
|
+
[ ] FNO + Burgers equation
|
|
104
|
+
[ ] Stochastic financial modeling
|
|
105
|
+
[ ] Convolution: https://github.com/NVIDIA/MinkowskiEngine/blob/master/src/convolution_kernel.cu#L123
|
|
106
|
+
[ ] MeshCNN (Modulus, Oliver)
|
|
107
|
+
[ ] BioNemo (Ali)
|
|
108
|
+
[ ] Skinning (David/Or/Vismay)
|
|
109
|
+
[ ] warp.sim (VBD)
|
|
110
|
+
[ ] Error checking
|
|
111
|
+
[ ] Ensure functions passed to tile_map() are compatible with tile type
|
|
112
|
+
[ ] Ensure that args passed to tile ops are compatible
|
|
113
|
+
[ ] Ensure tile load/store operations don't go out of bounds of arrays in debug mode
|
|
114
|
+
|
|
115
|
+
*/
|
|
116
|
+
|
|
117
|
+
/*
|
|
118
|
+
Notes on shared memory synchronization
|
|
119
|
+
======================================
|
|
120
|
+
|
|
121
|
+
Currently operations that write to shared memory tiles (e.g.: tile_load())
|
|
122
|
+
must synchronize before they return through WP_TILE_SYNC(), this
|
|
123
|
+
ensures subsequent read operations from the tile do not cause a race condition.
|
|
124
|
+
|
|
125
|
+
For tile_shared_t adjoints, the gradient accumulation is done through shared
|
|
126
|
+
memory atomics, i.e.: atomic_add(), since for broadcast tiles multiple threads
|
|
127
|
+
may map to the same location. Synchronization is still required after these
|
|
128
|
+
updates, since subsequent operations e.g.: adj_tile_load() will store the
|
|
129
|
+
gradients to memory, and all updates must be visible at that point, e.g.:
|
|
130
|
+
|
|
131
|
+
a = wp.tile_load(...)
|
|
132
|
+
b = wp.tile_load(...)
|
|
133
|
+
c = wp.tile_matmul(a, b)
|
|
134
|
+
wp.tile_store(c)
|
|
135
|
+
|
|
136
|
+
// loads incoming adjoints from global -> shared
|
|
137
|
+
wp.adj_tile_store(c, adj_c)
|
|
138
|
+
// consumes adj_c, requires synchronization
|
|
139
|
+
wp.adj_tile_matmul(a, b, adj_a, adj_b, adj_c)
|
|
140
|
+
// consumes adj_b, requires synchronization
|
|
141
|
+
wp.adj_tile_load(..., adj_b)
|
|
142
|
+
// consumes adj_b, requires synchronization
|
|
143
|
+
wp.adj_tile_load(..., adj_a)
|
|
144
|
+
|
|
145
|
+
Generally synchronization to adjoint tiles will happen through the
|
|
146
|
+
tile_shared_t::add() and tile_shared_t::assign() function automatically,
|
|
147
|
+
but in some cases e.g.: tile_matmul() it is done manually.
|
|
148
|
+
|
|
149
|
+
The current synchronization strategy is conservative, and can lead to more
|
|
150
|
+
synchronization than necessary. A more sophisticated strategy would be
|
|
151
|
+
to track the 'dirty' state of shared tiles, and synchronize only when
|
|
152
|
+
necessary. In addition, custom synchronization for e.g.: tile_load()
|
|
153
|
+
operations could be added through a SyncProvider template parameter on
|
|
154
|
+
the tile_shared_t type, for example to support barrier synchronization
|
|
155
|
+
for asynchronous global to shared loads.
|
|
156
|
+
*/
|
|
157
|
+
|
|
158
|
+
namespace wp
|
|
159
|
+
{
|
|
160
|
+
|
|
161
|
+
// Primary template
|
|
162
|
+
template <typename T, typename U>
|
|
163
|
+
struct is_same {
|
|
164
|
+
static constexpr bool value = false;
|
|
165
|
+
};
|
|
166
|
+
|
|
167
|
+
// Specialization for the case when T and U are the same type
|
|
168
|
+
template <typename T>
|
|
169
|
+
struct is_same<T, T> {
|
|
170
|
+
static constexpr bool value = true;
|
|
171
|
+
};
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
template <typename Tile>
|
|
175
|
+
constexpr int tile_size(Tile& t) { return Tile::M*Tile::N; }
|
|
176
|
+
|
|
177
|
+
constexpr int tile_regcount(int m, int n) {
|
|
178
|
+
return (m*n + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
struct coord_t
|
|
182
|
+
{
|
|
183
|
+
int i;
|
|
184
|
+
int j;
|
|
185
|
+
};
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
// represents a tile stored in global memory with dynamic strides
|
|
189
|
+
// only used to represent the source for tile loads to register/shared
|
|
190
|
+
template <typename T>
|
|
191
|
+
struct tile_global_t
|
|
192
|
+
{
|
|
193
|
+
using Type = T;
|
|
194
|
+
|
|
195
|
+
array_t<T> data;
|
|
196
|
+
int x;
|
|
197
|
+
int y;
|
|
198
|
+
|
|
199
|
+
tile_global_t(array_t<T>& a, int x, int y) : data(a), x(x), y(y)
|
|
200
|
+
{
|
|
201
|
+
}
|
|
202
|
+
};
|
|
203
|
+
|
|
204
|
+
// represents a tile stored in registers across a block
|
|
205
|
+
template <typename T, int M_, int N_>
|
|
206
|
+
struct tile_register_t
|
|
207
|
+
{
|
|
208
|
+
using Type = T;
|
|
209
|
+
static constexpr int M = M_;
|
|
210
|
+
static constexpr int N = N_;
|
|
211
|
+
static constexpr int Size = M*N;
|
|
212
|
+
|
|
213
|
+
static constexpr int NumRegs = tile_regcount(M, N);
|
|
214
|
+
|
|
215
|
+
static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
|
|
216
|
+
|
|
217
|
+
T data[NumRegs];
|
|
218
|
+
|
|
219
|
+
inline CUDA_CALLABLE tile_register_t(T value=T(0.0))
|
|
220
|
+
{
|
|
221
|
+
// zero-initialize by default necessary for tile adjoints
|
|
222
|
+
// need to check if this results in worse codegen
|
|
223
|
+
// than doing adj_var = tile_zeros() explicitly
|
|
224
|
+
// in backwards pass and letting default constructor
|
|
225
|
+
// avoid initialization
|
|
226
|
+
|
|
227
|
+
for (int i=0; i < NumRegs; ++i)
|
|
228
|
+
data[i] = value;
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
|
|
232
|
+
{
|
|
233
|
+
if (t.data.ndim == 1)
|
|
234
|
+
copy_from_global(t.data, t.x); // 1d load
|
|
235
|
+
else
|
|
236
|
+
copy_from_global(t.data, t.x, t.y); // 2d load
|
|
237
|
+
|
|
238
|
+
return *this;
|
|
239
|
+
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
// define the += operator which is used during backward pass codegen
|
|
243
|
+
// when returning a register tile from a user defined function
|
|
244
|
+
inline CUDA_CALLABLE auto& operator += (tile_register_t<T, M, N>& rhs)
|
|
245
|
+
{
|
|
246
|
+
this->grad_add(rhs);
|
|
247
|
+
return *this;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
inline CUDA_CALLABLE T& operator()(int index)
|
|
251
|
+
{
|
|
252
|
+
assert(index < NumRegs);
|
|
253
|
+
return data[index];
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
inline CUDA_CALLABLE const T& operator()(int index) const
|
|
257
|
+
{
|
|
258
|
+
assert(index < NumRegs);
|
|
259
|
+
return data[index];
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
// compute linear tile index from a local register index
|
|
264
|
+
inline CUDA_CALLABLE int index(int reg) const
|
|
265
|
+
{
|
|
266
|
+
return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
// compute tile coordinate from linear index
|
|
270
|
+
inline CUDA_CALLABLE coord_t coord(int index) const
|
|
271
|
+
{
|
|
272
|
+
return {index/N, index%N};
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
// Returns the number of valid registers for this tile
|
|
276
|
+
// i.e.: how many registers map to a valid coordinate.
|
|
277
|
+
// When a tile's size is not aligned to the block dimension
|
|
278
|
+
// some of the trailing registers may lie outside the valid range
|
|
279
|
+
inline CUDA_CALLABLE int valid() const
|
|
280
|
+
{
|
|
281
|
+
return (Size - threadIdx.x)/WP_TILE_BLOCK_DIM;
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
inline CUDA_CALLABLE void assign(const tile_register_t<T, M, N>& tile)
|
|
285
|
+
{
|
|
286
|
+
for (int i=0; i < NumRegs; ++i)
|
|
287
|
+
data[i] = tile.data[i];
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
inline CUDA_CALLABLE void zero()
|
|
291
|
+
{
|
|
292
|
+
for (int i=0; i < NumRegs; ++i)
|
|
293
|
+
data[i] = T(0);
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
// extract a single tile element to a native type
|
|
297
|
+
inline CUDA_CALLABLE Type extract(int i, int j)
|
|
298
|
+
{
|
|
299
|
+
// map from logical coords (i, j) -> (thread, reg)
|
|
300
|
+
const int linear = i*N + j;
|
|
301
|
+
|
|
302
|
+
const int thread = linear/NumRegs;
|
|
303
|
+
const int reg = linear%NumRegs;
|
|
304
|
+
|
|
305
|
+
WP_TILE_SHARED Type scratch;
|
|
306
|
+
|
|
307
|
+
// ensure any previously scheduled threads have finished reading from scratch
|
|
308
|
+
WP_TILE_SYNC();
|
|
309
|
+
|
|
310
|
+
if (threadIdx.x == thread)
|
|
311
|
+
{
|
|
312
|
+
scratch = data[reg];
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
// ensure extraction thread has updated smem
|
|
316
|
+
WP_TILE_SYNC();
|
|
317
|
+
|
|
318
|
+
return scratch;
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
// backward version of scalar extract
|
|
323
|
+
inline CUDA_CALLABLE void adj_extract(int i, int j, Type adj_ret)
|
|
324
|
+
{
|
|
325
|
+
// map from logical coords (i, j) -> (thread, reg)
|
|
326
|
+
const int linear = i*N + j;
|
|
327
|
+
|
|
328
|
+
const int thread = linear/NumRegs;
|
|
329
|
+
const int reg = linear%NumRegs;
|
|
330
|
+
|
|
331
|
+
if (threadIdx.x == thread)
|
|
332
|
+
{
|
|
333
|
+
data[reg] += adj_ret;
|
|
334
|
+
}
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
inline CUDA_CALLABLE void print() const;
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
// return the in-register version of this tile (nop)
|
|
341
|
+
inline CUDA_CALLABLE auto& copy_to_register()
|
|
342
|
+
{
|
|
343
|
+
return *this;
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
inline CUDA_CALLABLE const auto& copy_to_register() const
|
|
347
|
+
{
|
|
348
|
+
return *this;
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
// in-place gradient zero
|
|
352
|
+
inline CUDA_CALLABLE void grad_zero()
|
|
353
|
+
{
|
|
354
|
+
zero();
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
// accumulate gradients onto this tile
|
|
358
|
+
inline CUDA_CALLABLE void grad_add(const tile_register_t<T, M, N>& tile)
|
|
359
|
+
{
|
|
360
|
+
for (int i=0; i < NumRegs; ++i)
|
|
361
|
+
data[i] += tile.data[i];
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
// copy shared tile to register
|
|
365
|
+
inline CUDA_CALLABLE auto& grad_to_register()
|
|
366
|
+
{
|
|
367
|
+
return *this;
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
void copy_to_global(array_t<T> dest, int x)
|
|
371
|
+
{
|
|
372
|
+
assert(dest.ndim == 1);
|
|
373
|
+
|
|
374
|
+
const int tile_i = x*N;
|
|
375
|
+
|
|
376
|
+
WP_PRAGMA_UNROLL
|
|
377
|
+
for (int i=0; i < NumRegs; ++i)
|
|
378
|
+
{
|
|
379
|
+
// handle case where tile size is not
|
|
380
|
+
// aligned to block dimensions
|
|
381
|
+
int linear = index(i);
|
|
382
|
+
if (!Aligned && linear >= Size)
|
|
383
|
+
break;
|
|
384
|
+
|
|
385
|
+
wp::index(dest, tile_i + linear) = data[i];
|
|
386
|
+
}
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
void copy_to_global(array_t<T> dest, int x, int y)
|
|
390
|
+
{
|
|
391
|
+
assert(dest.ndim == 2);
|
|
392
|
+
|
|
393
|
+
const int tile_i = x*M;
|
|
394
|
+
const int tile_j = y*N;
|
|
395
|
+
|
|
396
|
+
// wp.array() indexing generates poor code due to char* casting
|
|
397
|
+
// here we unroll some of the ops, note this assumes byte strides are
|
|
398
|
+
// aligned to the element size
|
|
399
|
+
T* ptr = &wp::index(dest, tile_i, tile_j);
|
|
400
|
+
const int stride_i = dest.strides[0]/sizeof(T);
|
|
401
|
+
const int stride_j = dest.strides[1]/sizeof(T);
|
|
402
|
+
|
|
403
|
+
WP_PRAGMA_UNROLL
|
|
404
|
+
for (int i=0; i < NumRegs; ++i)
|
|
405
|
+
{
|
|
406
|
+
// handle case where tile size is not
|
|
407
|
+
// aligned to block dimensions
|
|
408
|
+
int linear = index(i);
|
|
409
|
+
if (!Aligned && linear >= Size)
|
|
410
|
+
break;
|
|
411
|
+
|
|
412
|
+
coord_t c = coord(linear);
|
|
413
|
+
ptr[c.i*stride_i + c.j*stride_j] = data[i];
|
|
414
|
+
}
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x)
|
|
418
|
+
{
|
|
419
|
+
// todo: use async pipelines or TMA here
|
|
420
|
+
const int tile_i = x*N;
|
|
421
|
+
|
|
422
|
+
WP_PRAGMA_UNROLL
|
|
423
|
+
for (int i=0; i < NumRegs; ++i)
|
|
424
|
+
{
|
|
425
|
+
int linear = index(i);
|
|
426
|
+
if (!Aligned && linear >= Size)
|
|
427
|
+
break;
|
|
428
|
+
|
|
429
|
+
data[i] = wp::index(src, tile_i + linear);
|
|
430
|
+
}
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x, int y)
|
|
434
|
+
{
|
|
435
|
+
// todo: use async pipelines or TMA here
|
|
436
|
+
const int tile_i = x*M;
|
|
437
|
+
const int tile_j = y*N;
|
|
438
|
+
|
|
439
|
+
// wp.array() indexing generates poor code due to char* casting
|
|
440
|
+
// here we unroll some of the ops, note this assumes array byte strides are
|
|
441
|
+
// aligned to the element size
|
|
442
|
+
const T* ptr = &wp::index(src, tile_i, tile_j);
|
|
443
|
+
|
|
444
|
+
assert(src.strides[0]%sizeof(T) == 0);
|
|
445
|
+
assert(src.strides[1]%sizeof(T) == 0);
|
|
446
|
+
|
|
447
|
+
const int stride_i = src.strides[0]/sizeof(T);
|
|
448
|
+
const int stride_j = src.strides[1]/sizeof(T);
|
|
449
|
+
|
|
450
|
+
WP_PRAGMA_UNROLL
|
|
451
|
+
for (int i=0; i < NumRegs; ++i)
|
|
452
|
+
{
|
|
453
|
+
int linear = index(i);
|
|
454
|
+
if (!Aligned && linear >= Size)
|
|
455
|
+
break;
|
|
456
|
+
|
|
457
|
+
coord_t c = coord(linear);
|
|
458
|
+
data[i] = ptr[c.i*stride_i + c.j*stride_j];
|
|
459
|
+
}
|
|
460
|
+
}
|
|
461
|
+
};
|
|
462
|
+
|
|
463
|
+
// helper to allocate a register tile like another tile
|
|
464
|
+
template<typename Tile>
|
|
465
|
+
auto tile_register_like()
|
|
466
|
+
{
|
|
467
|
+
using T = typename Tile::Type;
|
|
468
|
+
|
|
469
|
+
return tile_register_t<T, Tile::M, Tile::N>(T(0.0));
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
inline CUDA_CALLABLE int tile_align(int num_bytes)
|
|
473
|
+
{
|
|
474
|
+
// note this much match value in Python types.py
|
|
475
|
+
const int alignment = 16;
|
|
476
|
+
|
|
477
|
+
return ((num_bytes + alignment - 1) / alignment) * alignment;
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
|
|
481
|
+
{
|
|
482
|
+
// we maintain a per-thread offset into dynamic
|
|
483
|
+
// shared memory that allows us to keep track of
|
|
484
|
+
// current use across dynamic function calls
|
|
485
|
+
__shared__ int smem_base[WP_TILE_BLOCK_DIM];
|
|
486
|
+
|
|
487
|
+
if (init)
|
|
488
|
+
{
|
|
489
|
+
smem_base[threadIdx.x] = 0;
|
|
490
|
+
return NULL;
|
|
491
|
+
}
|
|
492
|
+
else
|
|
493
|
+
{
|
|
494
|
+
const int offset = smem_base[threadIdx.x];
|
|
495
|
+
|
|
496
|
+
// one entry per-thread so no need for synchronization
|
|
497
|
+
smem_base[threadIdx.x] += tile_align(num_bytes);
|
|
498
|
+
|
|
499
|
+
extern __shared__ char dynamic_smem_base[];
|
|
500
|
+
return &(dynamic_smem_base[offset]);
|
|
501
|
+
}
|
|
502
|
+
}
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
template <typename T, int M_, int N_, int StrideM_=N_, int StrideN_=1, bool Owner_=true>
|
|
507
|
+
struct tile_shared_t
|
|
508
|
+
{
|
|
509
|
+
using Type = T;
|
|
510
|
+
static constexpr int M = M_;
|
|
511
|
+
static constexpr int N = N_;
|
|
512
|
+
static constexpr int Size = M*N;
|
|
513
|
+
|
|
514
|
+
static constexpr int StrideM = StrideM_;
|
|
515
|
+
static constexpr int StrideN = StrideN_;
|
|
516
|
+
|
|
517
|
+
static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
|
|
518
|
+
static constexpr bool Unique = (StrideM >= N) && (StrideN >= 1);
|
|
519
|
+
static constexpr bool Owner = Owner_;
|
|
520
|
+
|
|
521
|
+
struct Storage
|
|
522
|
+
{
|
|
523
|
+
T* ptr;
|
|
524
|
+
|
|
525
|
+
Storage(T* p) : ptr(p) {}
|
|
526
|
+
|
|
527
|
+
inline CUDA_CALLABLE T& operator()(int i, int j)
|
|
528
|
+
{
|
|
529
|
+
assert(i < M);
|
|
530
|
+
assert(j < N);
|
|
531
|
+
|
|
532
|
+
return ptr[i*StrideM + j*StrideN];
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
inline CUDA_CALLABLE const T& operator()(int i, int j) const
|
|
536
|
+
{
|
|
537
|
+
assert(i < M);
|
|
538
|
+
assert(j < N);
|
|
539
|
+
|
|
540
|
+
return ptr[i*StrideM + j*StrideN];
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
inline CUDA_CALLABLE T& operator()(int index)
|
|
544
|
+
{
|
|
545
|
+
assert(index < M*N);
|
|
546
|
+
|
|
547
|
+
// unravel
|
|
548
|
+
int i = index/N;
|
|
549
|
+
int j = index%N;
|
|
550
|
+
|
|
551
|
+
return (*this)(i,j);
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
inline CUDA_CALLABLE const T& operator()(int index) const
|
|
555
|
+
{
|
|
556
|
+
assert(index < M*N);
|
|
557
|
+
|
|
558
|
+
// unravel
|
|
559
|
+
int i = index/N;
|
|
560
|
+
int j = index%N;
|
|
561
|
+
|
|
562
|
+
return (*this)(i,j);
|
|
563
|
+
}
|
|
564
|
+
};
|
|
565
|
+
|
|
566
|
+
Storage data;
|
|
567
|
+
Storage grad;
|
|
568
|
+
|
|
569
|
+
// default initialization (non-initialized)
|
|
570
|
+
inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL)
|
|
571
|
+
{
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
// initialize from an existing tile's memory
|
|
575
|
+
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL) : data(data), grad(grad)
|
|
576
|
+
{
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
inline CUDA_CALLABLE ~tile_shared_t()
|
|
580
|
+
{
|
|
581
|
+
if (Owner)
|
|
582
|
+
{
|
|
583
|
+
// update our per-thread shared memory allocator
|
|
584
|
+
if (data.ptr)
|
|
585
|
+
tile_alloc_shared(-M*N*int(sizeof(T)));
|
|
586
|
+
|
|
587
|
+
if (grad.ptr)
|
|
588
|
+
tile_alloc_shared(-M*N*int(sizeof(T)));
|
|
589
|
+
}
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
// assign from a register tile
|
|
593
|
+
template <typename Tile>
|
|
594
|
+
inline CUDA_CALLABLE auto& operator=(const Tile& t)
|
|
595
|
+
{
|
|
596
|
+
assign(t);
|
|
597
|
+
return *this;
|
|
598
|
+
}
|
|
599
|
+
|
|
600
|
+
// construct from another shared tile, this constructor
|
|
601
|
+
// is invoked for reshape operations like `wp.tile_transpose()`
|
|
602
|
+
template <typename OtherT, int OtherM, int OtherN, int OtherStrideM, int OtherStrideN>
|
|
603
|
+
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherM, OtherN, OtherStrideM, OtherStrideN>& rhs)
|
|
604
|
+
{
|
|
605
|
+
using OtherTile = tile_shared_t<OtherT, OtherM, OtherN, OtherStrideM, OtherStrideN>;
|
|
606
|
+
|
|
607
|
+
// check dimensions are compatible
|
|
608
|
+
static_assert(Size == OtherTile::Size);
|
|
609
|
+
|
|
610
|
+
// alias tile directly
|
|
611
|
+
data = rhs.data;
|
|
612
|
+
grad = rhs.grad;
|
|
613
|
+
|
|
614
|
+
return *this;
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
// assign from a global tile (load)
|
|
618
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
|
|
619
|
+
{
|
|
620
|
+
if (t.data.ndim == 1)
|
|
621
|
+
copy_from_global(t.data, t.x); // 1d load
|
|
622
|
+
else
|
|
623
|
+
copy_from_global(t.data, t.x, t.y); // 2d load
|
|
624
|
+
|
|
625
|
+
// synchronization happens in copy functions above
|
|
626
|
+
|
|
627
|
+
return *this;
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
// assign from a constant value
|
|
631
|
+
inline CUDA_CALLABLE auto& operator=(const T& x)
|
|
632
|
+
{
|
|
633
|
+
for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
|
|
634
|
+
data(i) = x;
|
|
635
|
+
|
|
636
|
+
WP_TILE_SYNC();
|
|
637
|
+
return *this;
|
|
638
|
+
}
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
// compute tile coordinate from linear index
|
|
642
|
+
inline CUDA_CALLABLE coord_t coord(int index) const
|
|
643
|
+
{
|
|
644
|
+
return {index/N, index%N};
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
// in-place zero
|
|
648
|
+
inline CUDA_CALLABLE void zero()
|
|
649
|
+
{
|
|
650
|
+
for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
|
|
651
|
+
data(i) = T(0);
|
|
652
|
+
|
|
653
|
+
WP_TILE_SYNC();
|
|
654
|
+
}
|
|
655
|
+
|
|
656
|
+
// extract a single tile element to a native type
|
|
657
|
+
inline CUDA_CALLABLE Type extract(int i, int j)
|
|
658
|
+
{
|
|
659
|
+
return data(i, j);
|
|
660
|
+
}
|
|
661
|
+
|
|
662
|
+
// backward of scalar extraction
|
|
663
|
+
inline CUDA_CALLABLE void adj_extract(int i, int j, Type adj_ret)
|
|
664
|
+
{
|
|
665
|
+
if (threadIdx.x == 0)
|
|
666
|
+
data(i, j) += adj_ret;
|
|
667
|
+
|
|
668
|
+
WP_TILE_SYNC();
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
// copy register tile to shared
|
|
673
|
+
inline CUDA_CALLABLE void assign(const tile_register_t<T, M, N>& tile)
|
|
674
|
+
{
|
|
675
|
+
WP_PRAGMA_UNROLL
|
|
676
|
+
for (int i=0; i < tile.NumRegs; ++i)
|
|
677
|
+
{
|
|
678
|
+
const int linear = tile.index(i);
|
|
679
|
+
|
|
680
|
+
// handle case where tile size is not
|
|
681
|
+
// aligned to block dimensions
|
|
682
|
+
if (!Aligned && linear >= Size)
|
|
683
|
+
break;
|
|
684
|
+
|
|
685
|
+
data(linear) = tile.data[i];
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
WP_TILE_SYNC();
|
|
689
|
+
}
|
|
690
|
+
|
|
691
|
+
// in-place gradient zero
|
|
692
|
+
inline CUDA_CALLABLE void grad_zero()
|
|
693
|
+
{
|
|
694
|
+
// todo: make this subtile (stride aware)
|
|
695
|
+
for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
|
|
696
|
+
grad(i) = T(0);
|
|
697
|
+
|
|
698
|
+
WP_TILE_SYNC();
|
|
699
|
+
}
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
// accumulate gradients onto this tile
|
|
703
|
+
inline CUDA_CALLABLE void grad_add(const tile_register_t<T, M, N>& tile)
|
|
704
|
+
{
|
|
705
|
+
WP_PRAGMA_UNROLL
|
|
706
|
+
for (int i=0; i < tile.NumRegs; ++i)
|
|
707
|
+
{
|
|
708
|
+
const int linear = tile.index(i);
|
|
709
|
+
|
|
710
|
+
// handle case where tile size is not
|
|
711
|
+
// aligned to block dimensions
|
|
712
|
+
if (!Aligned && linear >= Size)
|
|
713
|
+
break;
|
|
714
|
+
|
|
715
|
+
if (Unique)
|
|
716
|
+
grad(linear) += tile.data[i];
|
|
717
|
+
else
|
|
718
|
+
// use shared memory atomics to accumulate gradients
|
|
719
|
+
// since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
|
|
720
|
+
// may map to a single location in shared memory
|
|
721
|
+
atomic_add(&grad(linear), tile.data[i]);
|
|
722
|
+
|
|
723
|
+
}
|
|
724
|
+
|
|
725
|
+
WP_TILE_SYNC();
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
// copy shared tile to register
|
|
729
|
+
inline CUDA_CALLABLE tile_register_t<T, M, N> grad_to_register()
|
|
730
|
+
{
|
|
731
|
+
tile_register_t<T, M, N> out;
|
|
732
|
+
|
|
733
|
+
WP_PRAGMA_UNROLL
|
|
734
|
+
for (int i=0; i < out.NumRegs; ++i)
|
|
735
|
+
{
|
|
736
|
+
const int linear = out.index(i);
|
|
737
|
+
|
|
738
|
+
// handle case where tile size is not
|
|
739
|
+
// aligned to block dimensions
|
|
740
|
+
if (!Aligned && linear >= Size)
|
|
741
|
+
break;
|
|
742
|
+
|
|
743
|
+
out(i) = grad(linear);
|
|
744
|
+
}
|
|
745
|
+
|
|
746
|
+
return out;
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
inline CUDA_CALLABLE void print() const
|
|
750
|
+
{
|
|
751
|
+
if (threadIdx.x == 0)
|
|
752
|
+
{
|
|
753
|
+
printf("tile(m=%d, n=%d, storage=shared) = [", M, N);
|
|
754
|
+
for (int i=0; i < M; ++i)
|
|
755
|
+
{
|
|
756
|
+
printf("%*s[", i>0, "");
|
|
757
|
+
for (int j=0; j < N; ++j)
|
|
758
|
+
{
|
|
759
|
+
printf("%g ", double(data(i, j)));
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
if (i == M-1)
|
|
763
|
+
printf("]]\n");
|
|
764
|
+
else
|
|
765
|
+
printf("]\n");
|
|
766
|
+
}
|
|
767
|
+
}
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
// copy shared tile to register
|
|
771
|
+
inline CUDA_CALLABLE tile_register_t<T, M, N> copy_to_register() const
|
|
772
|
+
{
|
|
773
|
+
tile_register_t<T, M, N> out;
|
|
774
|
+
|
|
775
|
+
WP_PRAGMA_UNROLL
|
|
776
|
+
for (int i=0; i < out.NumRegs; ++i)
|
|
777
|
+
{
|
|
778
|
+
const int linear = out.index(i);
|
|
779
|
+
|
|
780
|
+
// handle case where tile size is not
|
|
781
|
+
// aligned to block dimensions
|
|
782
|
+
if (!Aligned && linear >= Size)
|
|
783
|
+
break;
|
|
784
|
+
|
|
785
|
+
out(i) = data(linear);
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
return out;
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
inline CUDA_CALLABLE void copy_to_global(array_t<T> dest, int x) const
|
|
792
|
+
{
|
|
793
|
+
assert(dest.ndim == 1);
|
|
794
|
+
|
|
795
|
+
// todo: use TMA here
|
|
796
|
+
const int tile_i = x*N;
|
|
797
|
+
|
|
798
|
+
WP_PRAGMA_UNROLL
|
|
799
|
+
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
800
|
+
{
|
|
801
|
+
wp::index(dest, tile_i + i) = data(i);
|
|
802
|
+
}
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
inline CUDA_CALLABLE void copy_to_global(array_t<T> dest, int x, int y)
|
|
806
|
+
{
|
|
807
|
+
// todo: use TMA here
|
|
808
|
+
const int tile_i = x*M;
|
|
809
|
+
const int tile_j = y*N;
|
|
810
|
+
|
|
811
|
+
// check each row is contiguous and 128bit aligned
|
|
812
|
+
if (StrideN == 1 && dest.strides[1] == sizeof(T) && (N*sizeof(T))%sizeof(float4) == 0)
|
|
813
|
+
{
|
|
814
|
+
constexpr int num_rows = M;
|
|
815
|
+
constexpr int num_cols = (N*sizeof(T))/sizeof(float4);
|
|
816
|
+
|
|
817
|
+
tile_shared_t<float4, num_rows, num_cols> src128((float4*)data.ptr);
|
|
818
|
+
|
|
819
|
+
// alias of shared tile with 128bit type
|
|
820
|
+
float4* ptr = (float4*)&wp::index(dest, tile_i, tile_j);
|
|
821
|
+
|
|
822
|
+
assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
|
|
823
|
+
assert(((uint64_t)(ptr))%sizeof(float4) == 0);
|
|
824
|
+
|
|
825
|
+
const int stride_i = dest.strides[0]/sizeof(float4);
|
|
826
|
+
const int stride_j = 1;
|
|
827
|
+
|
|
828
|
+
WP_PRAGMA_UNROLL
|
|
829
|
+
for (int i=threadIdx.x; i < src128.Size; i += WP_TILE_BLOCK_DIM)
|
|
830
|
+
{
|
|
831
|
+
coord_t c = src128.coord(i);
|
|
832
|
+
ptr[c.i*stride_i + c.j*stride_j] = src128.data(i);
|
|
833
|
+
}
|
|
834
|
+
}
|
|
835
|
+
else
|
|
836
|
+
{
|
|
837
|
+
// wp.array() indexing generates poor code due to char* casting
|
|
838
|
+
// here we unroll some of the ops, note this assumes byte strides are
|
|
839
|
+
// aligned to the element size
|
|
840
|
+
T* ptr = &wp::index(dest, tile_i, tile_j);
|
|
841
|
+
const int stride_i = dest.strides[0]/sizeof(T);
|
|
842
|
+
const int stride_j = dest.strides[1]/sizeof(T);
|
|
843
|
+
|
|
844
|
+
WP_PRAGMA_UNROLL
|
|
845
|
+
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
846
|
+
{
|
|
847
|
+
coord_t c = coord(i);
|
|
848
|
+
ptr[c.i*stride_i + c.j*stride_j] = data(c.i, c.j);
|
|
849
|
+
}
|
|
850
|
+
}
|
|
851
|
+
}
|
|
852
|
+
|
|
853
|
+
inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x)
|
|
854
|
+
{
|
|
855
|
+
// todo: use async pipelines or TMA here
|
|
856
|
+
const int tile_i = x*N;
|
|
857
|
+
|
|
858
|
+
WP_PRAGMA_UNROLL
|
|
859
|
+
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
860
|
+
{
|
|
861
|
+
data(i) = wp::index(src, tile_i + i);
|
|
862
|
+
}
|
|
863
|
+
|
|
864
|
+
WP_TILE_SYNC();
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x, int y)
|
|
868
|
+
{
|
|
869
|
+
// todo: use async pipelines or TMA here
|
|
870
|
+
const int tile_i = x*M;
|
|
871
|
+
const int tile_j = y*N;
|
|
872
|
+
|
|
873
|
+
// check each row is contiguous and 128bit aligned
|
|
874
|
+
if (StrideN == 1 && src.strides[1] == sizeof(T) && (N*sizeof(T))%sizeof(float4) == 0)
|
|
875
|
+
{
|
|
876
|
+
constexpr int num_rows = M;
|
|
877
|
+
constexpr int num_cols = (N*sizeof(T))/sizeof(float4);
|
|
878
|
+
|
|
879
|
+
// alias of shared tile with 128bit type
|
|
880
|
+
tile_shared_t<float4, num_rows, num_cols> dest128((float4*)data.ptr);
|
|
881
|
+
|
|
882
|
+
const float4* ptr = (const float4*)&wp::index(src, tile_i, tile_j);
|
|
883
|
+
|
|
884
|
+
assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
|
|
885
|
+
assert(((uint64_t)(ptr))%sizeof(float4) == 0);
|
|
886
|
+
|
|
887
|
+
const int stride_i = src.strides[0]/sizeof(float4);
|
|
888
|
+
//const int stride_j = 1;
|
|
889
|
+
|
|
890
|
+
WP_PRAGMA_UNROLL
|
|
891
|
+
for (int i=threadIdx.x; i < dest128.Size; i += WP_TILE_BLOCK_DIM)
|
|
892
|
+
{
|
|
893
|
+
coord_t c = dest128.coord(i);
|
|
894
|
+
|
|
895
|
+
#if WP_USE_ASYNC_PIPELINE
|
|
896
|
+
__pipeline_memcpy_async(&dest128.data(i),
|
|
897
|
+
&ptr[c.i*stride_i + c.j],
|
|
898
|
+
sizeof(float4));
|
|
899
|
+
#else
|
|
900
|
+
dest128.data(i) = ptr[c.i*stride_i + c.j];
|
|
901
|
+
#endif // WP_USE_ASYNC_PIPELINE
|
|
902
|
+
}
|
|
903
|
+
|
|
904
|
+
#if WP_USE_ASYNC_PIPELINE
|
|
905
|
+
__pipeline_commit();
|
|
906
|
+
#endif // WP_USE_ASYNC_PIPELINE
|
|
907
|
+
|
|
908
|
+
}
|
|
909
|
+
else
|
|
910
|
+
{
|
|
911
|
+
// wp.array() indexing generates poor code due to char* casting
|
|
912
|
+
// here we unroll some of the ops, note this assumes array byte strides are
|
|
913
|
+
// aligned to the element size
|
|
914
|
+
const T* ptr = &wp::index(src, tile_i, tile_j);
|
|
915
|
+
|
|
916
|
+
assert(src.strides[0]%sizeof(T) == 0);
|
|
917
|
+
assert(src.strides[1]%sizeof(T) == 0);
|
|
918
|
+
|
|
919
|
+
const int stride_i = src.strides[0]/sizeof(T);
|
|
920
|
+
const int stride_j = src.strides[1]/sizeof(T);
|
|
921
|
+
|
|
922
|
+
WP_PRAGMA_UNROLL
|
|
923
|
+
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
924
|
+
{
|
|
925
|
+
coord_t c = coord(i);
|
|
926
|
+
data(c.i, c.j) = ptr[c.i*stride_i + c.j*stride_j];
|
|
927
|
+
}
|
|
928
|
+
}
|
|
929
|
+
|
|
930
|
+
#if !WP_USE_ASYNC_PIPELINE
|
|
931
|
+
WP_TILE_SYNC();
|
|
932
|
+
#endif
|
|
933
|
+
|
|
934
|
+
}
|
|
935
|
+
};
|
|
936
|
+
|
|
937
|
+
template <typename T, int M, int N>
|
|
938
|
+
void tile_register_t<T, M, N>::print() const
|
|
939
|
+
{
|
|
940
|
+
// create a temporary shared tile so that
|
|
941
|
+
// we can print it deterministically
|
|
942
|
+
WP_TILE_SHARED T smem[M*N];
|
|
943
|
+
|
|
944
|
+
tile_shared_t<T, M, N> scratch(smem, NULL);
|
|
945
|
+
scratch.assign(*this);
|
|
946
|
+
|
|
947
|
+
WP_TILE_SYNC();
|
|
948
|
+
|
|
949
|
+
if (threadIdx.x == 0)
|
|
950
|
+
{
|
|
951
|
+
printf("tile(m=%d, n=%d, storage=register) = [", M, N);
|
|
952
|
+
for (int i=0; i < M; ++i)
|
|
953
|
+
{
|
|
954
|
+
printf("%*s[", i>0, "");
|
|
955
|
+
for (int j=0; j < N; ++j)
|
|
956
|
+
{
|
|
957
|
+
printf("%g ", double(scratch.data(i, j)));
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
if (i == M-1)
|
|
961
|
+
printf("]]\n");
|
|
962
|
+
else
|
|
963
|
+
printf("]\n");
|
|
964
|
+
}
|
|
965
|
+
}
|
|
966
|
+
|
|
967
|
+
WP_TILE_SYNC();
|
|
968
|
+
}
|
|
969
|
+
|
|
970
|
+
template <typename T, int M, int N>
|
|
971
|
+
inline CUDA_CALLABLE void print(const tile_register_t<T, M, N>& t)
|
|
972
|
+
{
|
|
973
|
+
t.print();
|
|
974
|
+
}
|
|
975
|
+
|
|
976
|
+
template <typename T, int M, int N>
|
|
977
|
+
inline CUDA_CALLABLE void adj_print(const tile_register_t<T, M, N>& t, const tile_register_t<T, M, N>& a)
|
|
978
|
+
{
|
|
979
|
+
a.print();
|
|
980
|
+
}
|
|
981
|
+
|
|
982
|
+
template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
|
|
983
|
+
inline CUDA_CALLABLE void print(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t)
|
|
984
|
+
{
|
|
985
|
+
t.print();
|
|
986
|
+
}
|
|
987
|
+
|
|
988
|
+
template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
|
|
989
|
+
inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t, const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& a)
|
|
990
|
+
{
|
|
991
|
+
a.print();
|
|
992
|
+
}
|
|
993
|
+
|
|
994
|
+
// helpers to allocate shared tiles
|
|
995
|
+
template <typename T, int M, int N, bool RequiresGrad>
|
|
996
|
+
inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
997
|
+
|
|
998
|
+
{ constexpr int Len = M*N;
|
|
999
|
+
T* data = (T*)tile_alloc_shared(Len*sizeof(T));
|
|
1000
|
+
T* grad = NULL;
|
|
1001
|
+
|
|
1002
|
+
#if FP_CHECK
|
|
1003
|
+
|
|
1004
|
+
for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
|
|
1005
|
+
data[i] = T(nanf(""));
|
|
1006
|
+
|
|
1007
|
+
WP_TILE_SYNC();
|
|
1008
|
+
|
|
1009
|
+
#endif // FP_CHECK
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
if (RequiresGrad)
|
|
1013
|
+
{
|
|
1014
|
+
grad = (T*)tile_alloc_shared(Len*sizeof(T));
|
|
1015
|
+
|
|
1016
|
+
for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
|
|
1017
|
+
grad[i] = T(0);
|
|
1018
|
+
|
|
1019
|
+
WP_TILE_SYNC();
|
|
1020
|
+
}
|
|
1021
|
+
|
|
1022
|
+
return tile_shared_t<T, M, N>(data, grad);
|
|
1023
|
+
}
|
|
1024
|
+
|
|
1025
|
+
template <typename T, int M, int N, bool RequiresGrad>
|
|
1026
|
+
inline CUDA_CALLABLE auto tile_alloc_zeros()
|
|
1027
|
+
{
|
|
1028
|
+
// compute the total storage required for the tile (may be different from M*N) for broadcast tiles
|
|
1029
|
+
constexpr int Len = M*N;
|
|
1030
|
+
T* data = (T*)tile_alloc_shared(Len*sizeof(T));
|
|
1031
|
+
T* grad = NULL;
|
|
1032
|
+
|
|
1033
|
+
for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
|
|
1034
|
+
data[i] = T(0);
|
|
1035
|
+
|
|
1036
|
+
if (RequiresGrad)
|
|
1037
|
+
{
|
|
1038
|
+
grad = (T*)tile_alloc_shared(Len*sizeof(T));
|
|
1039
|
+
|
|
1040
|
+
for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
|
|
1041
|
+
grad[i] = T(0);
|
|
1042
|
+
}
|
|
1043
|
+
|
|
1044
|
+
WP_TILE_SYNC();
|
|
1045
|
+
|
|
1046
|
+
return tile_shared_t<T, M, N, StrideM, StrideN>(data, grad);
|
|
1047
|
+
}
|
|
1048
|
+
|
|
1049
|
+
|
|
1050
|
+
//-----------------------------------------------------------------------------------------------------
|
|
1051
|
+
// High level entry points for each op (correspond to one Warp builtin)
|
|
1052
|
+
|
|
1053
|
+
// construct a tile from a local SIMT value (one per-thread)
|
|
1054
|
+
template <typename T>
|
|
1055
|
+
inline CUDA_CALLABLE auto tile(const T& x)
|
|
1056
|
+
{
|
|
1057
|
+
tile_register_t<T, 1, WP_TILE_BLOCK_DIM> result;
|
|
1058
|
+
|
|
1059
|
+
static_assert(result.NumRegs == 1);
|
|
1060
|
+
|
|
1061
|
+
result.data[0] = x;
|
|
1062
|
+
return result;
|
|
1063
|
+
}
|
|
1064
|
+
|
|
1065
|
+
// overload for constructing a tile from a per-thread vector
|
|
1066
|
+
template <typename T, unsigned Length>
|
|
1067
|
+
inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
|
|
1068
|
+
{
|
|
1069
|
+
tile_register_t<T, Length, WP_TILE_BLOCK_DIM> result;
|
|
1070
|
+
|
|
1071
|
+
static_assert(result.NumRegs == Length);
|
|
1072
|
+
|
|
1073
|
+
for (int i=0; i < Length; ++i)
|
|
1074
|
+
result.data[i] = x[i];
|
|
1075
|
+
|
|
1076
|
+
return result;
|
|
1077
|
+
}
|
|
1078
|
+
|
|
1079
|
+
// construct a tile from a local SIMT value (one per-thread)
|
|
1080
|
+
template <typename T, typename AdjTile>
|
|
1081
|
+
inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
|
|
1082
|
+
{
|
|
1083
|
+
static_assert(AdjTile::M == 1);
|
|
1084
|
+
static_assert(AdjTile::N == WP_TILE_BLOCK_DIM);
|
|
1085
|
+
|
|
1086
|
+
auto adj_reg = adj_ret.copy_to_register();
|
|
1087
|
+
|
|
1088
|
+
adj_x += adj_reg.data[0];
|
|
1089
|
+
}
|
|
1090
|
+
|
|
1091
|
+
template <typename T, unsigned Length, typename AdjTile>
|
|
1092
|
+
inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
|
|
1093
|
+
{
|
|
1094
|
+
static_assert(AdjTile::M == Length);
|
|
1095
|
+
static_assert(AdjTile::N == WP_TILE_BLOCK_DIM);
|
|
1096
|
+
|
|
1097
|
+
auto adj_reg = adj_ret.copy_to_register();
|
|
1098
|
+
|
|
1099
|
+
for (int i=0; i < Length; ++i)
|
|
1100
|
+
adj_x[i] += adj_reg.data[i];
|
|
1101
|
+
}
|
|
1102
|
+
|
|
1103
|
+
template <typename Tile>
|
|
1104
|
+
inline CUDA_CALLABLE auto untile(Tile& tile)
|
|
1105
|
+
{
|
|
1106
|
+
// code-gen should have set the tile to
|
|
1107
|
+
// have exactly the block dimension so
|
|
1108
|
+
// there is exactly one value per-thread
|
|
1109
|
+
auto reg = tile.copy_to_register();
|
|
1110
|
+
|
|
1111
|
+
// scalar case
|
|
1112
|
+
if constexpr(Tile::M == 1)
|
|
1113
|
+
{
|
|
1114
|
+
return reg.data[0];
|
|
1115
|
+
}
|
|
1116
|
+
|
|
1117
|
+
// vector case
|
|
1118
|
+
if constexpr(Tile::M > 1)
|
|
1119
|
+
{
|
|
1120
|
+
wp::vec_t<Tile::M, typename Tile::Type> v;
|
|
1121
|
+
for (int i=0; i < Tile::M; ++i)
|
|
1122
|
+
v[i] = reg.data[i];
|
|
1123
|
+
|
|
1124
|
+
return v;
|
|
1125
|
+
}
|
|
1126
|
+
}
|
|
1127
|
+
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
template <typename Tile, typename Value>
|
|
1131
|
+
inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
|
|
1132
|
+
{
|
|
1133
|
+
auto adj = adj_tile.copy_to_register();
|
|
1134
|
+
|
|
1135
|
+
// scalar case
|
|
1136
|
+
if constexpr(Tile::M == 1)
|
|
1137
|
+
{
|
|
1138
|
+
adj.data[0] += adj_ret;
|
|
1139
|
+
}
|
|
1140
|
+
|
|
1141
|
+
// vector case
|
|
1142
|
+
if constexpr(Tile::M > 1)
|
|
1143
|
+
{
|
|
1144
|
+
for (int i=0; i < Tile::M; ++i)
|
|
1145
|
+
adj.data[i] = adj_ret[i];
|
|
1146
|
+
}
|
|
1147
|
+
|
|
1148
|
+
adj_tile.assign(adj);
|
|
1149
|
+
}
|
|
1150
|
+
|
|
1151
|
+
// zero initialized tile
|
|
1152
|
+
template <typename T, int M, int N>
|
|
1153
|
+
inline CUDA_CALLABLE auto tile_zeros()
|
|
1154
|
+
{
|
|
1155
|
+
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
1156
|
+
return T(0);
|
|
1157
|
+
}
|
|
1158
|
+
|
|
1159
|
+
// zero initialized tile
|
|
1160
|
+
template <typename T, int M, int N>
|
|
1161
|
+
inline CUDA_CALLABLE auto tile_ones()
|
|
1162
|
+
{
|
|
1163
|
+
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
1164
|
+
return T(1);
|
|
1165
|
+
}
|
|
1166
|
+
|
|
1167
|
+
// zero initialized tile
|
|
1168
|
+
template <typename T, int M, int N>
|
|
1169
|
+
inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
|
|
1170
|
+
{
|
|
1171
|
+
tile_register_t<T, M, N> out;
|
|
1172
|
+
|
|
1173
|
+
WP_PRAGMA_UNROLL
|
|
1174
|
+
for (int i=0; i < out.NumRegs; ++i)
|
|
1175
|
+
{
|
|
1176
|
+
const int linear = out.index(i);
|
|
1177
|
+
|
|
1178
|
+
// handle case where tile size is not
|
|
1179
|
+
// aligned to block dimensions
|
|
1180
|
+
if (!out.Aligned && linear >= out.Size)
|
|
1181
|
+
break;
|
|
1182
|
+
|
|
1183
|
+
out.data[i] = start + linear*step;
|
|
1184
|
+
}
|
|
1185
|
+
|
|
1186
|
+
return out;
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
template <typename T, typename AdjTile>
|
|
1190
|
+
inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
|
|
1191
|
+
T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
|
|
1192
|
+
|
|
1193
|
+
// entry point for 1d load
|
|
1194
|
+
template <typename T, int N>
|
|
1195
|
+
inline CUDA_CALLABLE auto tile_load(array_t<T>& src, int x)
|
|
1196
|
+
{
|
|
1197
|
+
return tile_global_t<T>(src, x, 0);
|
|
1198
|
+
}
|
|
1199
|
+
|
|
1200
|
+
// entry point for 2d load
|
|
1201
|
+
template <typename T, int M, int N>
|
|
1202
|
+
inline CUDA_CALLABLE auto tile_load(array_t<T>& src, int x, int y)
|
|
1203
|
+
{
|
|
1204
|
+
return tile_global_t<T>(src, x, y);
|
|
1205
|
+
}
|
|
1206
|
+
|
|
1207
|
+
// entry point for 1d store
|
|
1208
|
+
template <typename T, typename Tile>
|
|
1209
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src)
|
|
1210
|
+
{
|
|
1211
|
+
// dispatch to tile type
|
|
1212
|
+
src.copy_to_global(dest, x);
|
|
1213
|
+
}
|
|
1214
|
+
|
|
1215
|
+
// entry point for 2d store
|
|
1216
|
+
template <typename T, typename Tile>
|
|
1217
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src)
|
|
1218
|
+
{
|
|
1219
|
+
// dispatch to tile type
|
|
1220
|
+
src.copy_to_global(dest, x, y);
|
|
1221
|
+
}
|
|
1222
|
+
|
|
1223
|
+
// entry point for store
|
|
1224
|
+
template <typename T, typename Tile>
|
|
1225
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src)
|
|
1226
|
+
{
|
|
1227
|
+
auto src_reg = src.copy_to_register();
|
|
1228
|
+
|
|
1229
|
+
const int tile_i = x*src_reg.M;
|
|
1230
|
+
const int tile_j = y*src_reg.N;
|
|
1231
|
+
|
|
1232
|
+
tile_register_t<T, src_reg.M, src_reg.N> previous;
|
|
1233
|
+
|
|
1234
|
+
WP_PRAGMA_UNROLL
|
|
1235
|
+
for (int i=0; i < src_reg.NumRegs; ++i)
|
|
1236
|
+
{
|
|
1237
|
+
// handle case where tile size is not
|
|
1238
|
+
// aligned to block dimensions
|
|
1239
|
+
int linear = src_reg.index(i);
|
|
1240
|
+
if (!src_reg.Aligned && linear >= src_reg.Size)
|
|
1241
|
+
break;
|
|
1242
|
+
|
|
1243
|
+
coord_t c = src_reg.coord(linear);
|
|
1244
|
+
previous.data[i] = atomic_add(dest, tile_i + c.i, tile_j + c.j, src_reg.data[i]);
|
|
1245
|
+
}
|
|
1246
|
+
|
|
1247
|
+
return previous;
|
|
1248
|
+
}
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
|
|
1252
|
+
//-------------------------------------
|
|
1253
|
+
// Adjoints
|
|
1254
|
+
|
|
1255
|
+
template <typename T, typename AdjTile>
|
|
1256
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x,
|
|
1257
|
+
array_t<T>& adj_src, int adj_x,
|
|
1258
|
+
AdjTile& adj_ret)
|
|
1259
|
+
{
|
|
1260
|
+
// early out
|
|
1261
|
+
// if (!src.grad)
|
|
1262
|
+
// return;
|
|
1263
|
+
|
|
1264
|
+
auto adj_reg = adj_ret.grad_to_register();
|
|
1265
|
+
|
|
1266
|
+
const int tile_i = x*adj_reg.N;
|
|
1267
|
+
|
|
1268
|
+
// add gradients to src array
|
|
1269
|
+
WP_PRAGMA_UNROLL
|
|
1270
|
+
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1271
|
+
{
|
|
1272
|
+
int linear = adj_reg.index(i);
|
|
1273
|
+
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1274
|
+
break;
|
|
1275
|
+
|
|
1276
|
+
auto grad = adj_reg.data[i];
|
|
1277
|
+
|
|
1278
|
+
if (adj_src.data)
|
|
1279
|
+
adj_atomic_add(&index(adj_src, tile_i + linear), grad);
|
|
1280
|
+
else if (src.grad)
|
|
1281
|
+
adj_atomic_add(&index_grad(src, tile_i + linear), grad);
|
|
1282
|
+
}
|
|
1283
|
+
}
|
|
1284
|
+
|
|
1285
|
+
template <typename T, typename AdjTile>
|
|
1286
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y,
|
|
1287
|
+
array_t<T>& adj_src, int adj_x, int adj_y,
|
|
1288
|
+
AdjTile& adj_ret)
|
|
1289
|
+
{
|
|
1290
|
+
// early out
|
|
1291
|
+
// if (!src.grad)
|
|
1292
|
+
// return;
|
|
1293
|
+
|
|
1294
|
+
auto adj_reg = adj_ret.grad_to_register();
|
|
1295
|
+
|
|
1296
|
+
const int tile_i = x*adj_reg.M;
|
|
1297
|
+
const int tile_j = y*adj_reg.N;
|
|
1298
|
+
|
|
1299
|
+
// add gradients to src array
|
|
1300
|
+
WP_PRAGMA_UNROLL
|
|
1301
|
+
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1302
|
+
{
|
|
1303
|
+
int linear = adj_reg.index(i);
|
|
1304
|
+
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1305
|
+
break;
|
|
1306
|
+
|
|
1307
|
+
coord_t coord = adj_reg.coord(linear);
|
|
1308
|
+
|
|
1309
|
+
auto grad = adj_reg.data[i];
|
|
1310
|
+
|
|
1311
|
+
if (adj_src.data)
|
|
1312
|
+
adj_atomic_add(&index(adj_src, tile_i + coord.i, tile_j + coord.j), grad);
|
|
1313
|
+
else if (src.grad)
|
|
1314
|
+
adj_atomic_add(&index_grad(src, tile_i + coord.i, tile_j + coord.j), grad);
|
|
1315
|
+
}
|
|
1316
|
+
}
|
|
1317
|
+
|
|
1318
|
+
|
|
1319
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1320
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t)
|
|
1321
|
+
{
|
|
1322
|
+
// convert to register if necessary
|
|
1323
|
+
tile_register_t<T, AdjTile::M, AdjTile::N> adj_reg;
|
|
1324
|
+
|
|
1325
|
+
const int tile_i = x*adj_reg.N;
|
|
1326
|
+
|
|
1327
|
+
// load gradients from output
|
|
1328
|
+
WP_PRAGMA_UNROLL
|
|
1329
|
+
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1330
|
+
{
|
|
1331
|
+
int linear = adj_reg.index(i);
|
|
1332
|
+
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1333
|
+
break;
|
|
1334
|
+
|
|
1335
|
+
if (adj_dest.data)
|
|
1336
|
+
adj_reg.data[i] = index(adj_dest, tile_i + linear);
|
|
1337
|
+
else if (dest.grad)
|
|
1338
|
+
adj_reg.data[i] = index_grad(dest, tile_i + linear);
|
|
1339
|
+
}
|
|
1340
|
+
|
|
1341
|
+
// store adjoint back to tile
|
|
1342
|
+
adj_t.grad_add(adj_reg);
|
|
1343
|
+
}
|
|
1344
|
+
|
|
1345
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1346
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t)
|
|
1347
|
+
{
|
|
1348
|
+
// allocate register tile to load grads into
|
|
1349
|
+
tile_register_t<T, AdjTile::M, AdjTile::N> adj_reg;
|
|
1350
|
+
|
|
1351
|
+
const int tile_i = x*adj_reg.M;
|
|
1352
|
+
const int tile_j = y*adj_reg.N;
|
|
1353
|
+
|
|
1354
|
+
// load gradients from output
|
|
1355
|
+
WP_PRAGMA_UNROLL
|
|
1356
|
+
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1357
|
+
{
|
|
1358
|
+
int linear = adj_reg.index(i);
|
|
1359
|
+
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1360
|
+
break;
|
|
1361
|
+
|
|
1362
|
+
coord_t coord = adj_reg.coord(linear);
|
|
1363
|
+
|
|
1364
|
+
if (adj_dest.data)
|
|
1365
|
+
adj_reg.data[i] = index(adj_dest, tile_i + coord.i, tile_j + coord.j);
|
|
1366
|
+
else if (dest.grad)
|
|
1367
|
+
adj_reg.data[i] = index_grad(dest, tile_i + coord.i, tile_j + coord.j);
|
|
1368
|
+
}
|
|
1369
|
+
|
|
1370
|
+
// store adjoint back to tile
|
|
1371
|
+
adj_t.grad_add(adj_reg);
|
|
1372
|
+
}
|
|
1373
|
+
|
|
1374
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1375
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret)
|
|
1376
|
+
{
|
|
1377
|
+
adj_tile_store(dest, x, y, t, adj_dest, adj_x, adj_y, adj_t);
|
|
1378
|
+
}
|
|
1379
|
+
|
|
1380
|
+
|
|
1381
|
+
// unary map
template <typename Tile, typename Fwd>
inline CUDA_CALLABLE auto tile_map(Fwd op,
                                   Tile &a)
{
    auto out = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
    auto a_reg = a.copy_to_register();

    WP_PRAGMA_UNROLL
    for (int i=0; i < out.NumRegs; ++i)
    {
        out.data[i] = op(a_reg.data[i]);
    }

    return out;
}


template <typename Tile, typename AdjTile, typename Fwd, typename Adj>
inline CUDA_CALLABLE void adj_tile_map(Fwd op,
                                       Tile& a,
                                       Adj adj_op,
                                       Tile& adj_a,
                                       AdjTile& adj_ret)
{
    auto a_reg = a.copy_to_register();
    auto adj_a_reg = tile_register_like<Tile>();
    auto adj_ret_reg = adj_ret.grad_to_register();

    WP_PRAGMA_UNROLL
    for (int i=0; i < a_reg.NumRegs; ++i)
    {
        adj_op(a_reg.data[i], adj_a_reg.data[i], adj_ret_reg.data[i]);
    }

    // write adjoints back
    adj_a.grad_add(adj_a_reg);
}

// binary map
template <typename TileA, typename TileB, typename Fwd>
inline CUDA_CALLABLE auto tile_map(Fwd op,
                                   TileA& a,
                                   TileB& b)
{
    auto out = tile_register_t<typename TileA::Type, TileA::M, TileA::N>();

    auto a_reg = a.copy_to_register();
    auto b_reg = b.copy_to_register();

    WP_PRAGMA_UNROLL
    for (int i=0; i < out.NumRegs; ++i)
        out.data[i] = op(a_reg.data[i], b_reg.data[i]);

    return out;
}


template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_map(Fwd op,
                                       TileA &a,
                                       TileB &b,
                                       Adj adj_op,
                                       TileA &adj_a,
                                       TileB &adj_b,
                                       AdjTile &adj_ret)
{
    auto a_reg = a.copy_to_register();
    auto b_reg = b.copy_to_register();

    // allocate storage for adjoints
    auto adj_a_reg = tile_register_like<TileA>();
    auto adj_b_reg = tile_register_like<TileB>();

    auto adj_ret_reg = adj_ret.grad_to_register();

    WP_PRAGMA_UNROLL
    for (int i=0; i < a_reg.NumRegs; ++i)
    {
        adj_op(a_reg.data[i], b_reg.data[i], adj_a_reg.data[i], adj_b_reg.data[i], adj_ret_reg.data[i]);
    }

    adj_a.grad_add(adj_a_reg);
    adj_b.grad_add(adj_b_reg);
}

// wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
// this is important because many of the builtin operators don't follow particular conventions on references for
// the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
#define tile_unary_map(op, a) tile_map([](auto x) { return op(x);}, a)
#define adj_tile_unary_map(op, a, adj_op, adj_a, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)

#define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
#define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)

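To see why these macros wrap `op` in a generic lambda rather than passing the name directly, consider this minimal standalone sketch (the `scale` overload set is illustrative only): an overloaded function name cannot be deduced as a template argument, but a lambda that forwards to it defers overload resolution to the call site inside tile_map.

#include <cstdio>

// an overloaded "builtin": the bare name `scale` has no single type
float  scale(float x)  { return 2.0f * x; }
double scale(double x) { return 2.0  * x; }

// generic map, analogous to tile_map
template <typename Fwd>
float apply(Fwd op, float x) { return op(x); }

int main()
{
    // apply(scale, 3.0f);                                   // error: cannot deduce Fwd from an overload set
    float y = apply([](auto x) { return scale(x); }, 3.0f);  // ok: the lambda defers overload resolution
    printf("%f\n", y);  // 6.000000
    return 0;
}
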
// -tile (unary neg)
template <typename Tile>
inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }

template <typename Tile, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }


// tile + tile
template <typename TileA, typename TileB>
inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
{
    return tile_binary_map(add, a, b);
}

// // tile + tile, we implement this
// template <typename TileA, typename TileB>
// inline CUDA_CALLABLE auto add(TileA& a, TileB& b)
// {
//     return tile_binary_map(add, a, b);
// }


template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
{
    adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
}

// tile*scalar
template <typename Tile>
inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
{
    // promote scalar to a constant tile
    auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);

    return tile_binary_map(mul, a, s_tile);
}

template <typename Tile, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
                                       Tile& adj_a, typename Tile::Type& adj_s,
                                       AdjTile& adj_c)
{
    auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
    auto adj_s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();

    adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);

    for (int i=0; i < adj_s_tile.NumRegs; ++i)
    {
        adj_s += adj_s_tile.data[i];
    }
}


// scalar*tile
template <typename Tile>
inline CUDA_CALLABLE auto tile_mul(const typename Tile::Type& s, Tile& a)
{
    // promote scalar to a constant tile
    auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);

    return tile_binary_map(mul, s_tile, a);
}

template <typename Tile, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
                                       typename Tile::Type& adj_s, Tile& adj_a,
                                       AdjTile& adj_c)
{
    auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
    auto adj_s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();

    adj_tile_binary_map(mul, s_tile, a, adj_mul, adj_s_tile, adj_a, adj_c);

    for (int i=0; i < adj_s_tile.NumRegs; ++i)
    {
        adj_s += adj_s_tile.data[i];
    }
}

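The scalar overloads above promote `s` to a constant tile, so the per-element adjoints of `s` land in `adj_s_tile` and are then reduced into the single scalar `adj_s`. A minimal standalone sketch of that reduction for c = s*a on plain arrays (names illustrative):

#include <cstdio>

int main()
{
    const int n = 3;
    float a[n]     = {2.f, 3.f, 4.f};
    float s        = 5.f;
    float adj_c[n] = {1.f, 1.f, 1.f};    // incoming gradient of c = s*a

    float adj_a[n] = {};
    float adj_s    = 0.f;

    for (int i = 0; i < n; ++i)
    {
        adj_a[i] += s * adj_c[i];        // d(c[i])/d(a[i]) = s
        adj_s    += a[i] * adj_c[i];     // d(c[i])/d(s) = a[i], summed over elements
    }

    printf("adj_s = %f\n", adj_s);       // 2 + 3 + 4 = 9
    return 0;
}
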
template<typename Tile>
typename Tile::Type tile_extract(Tile& t, int i, int j)
{
    assert(i < Tile::M);
    assert(j < Tile::N);

    return t.extract(i, j);
}

template<typename Tile, typename AdjTile>
void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret)
{
    assert(i < Tile::M);
    assert(j < Tile::N);

    adj_t.adj_extract(i, j, adj_ret);
}

namespace partitioned_gemm
{

template <typename T>
inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride)
{
    return p[i*stride + j];
}

template <typename T>
inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride)
{
    return p[i*stride + j];
}

template <int PartitionM, int PartitionN, typename Tile>
struct partition_t
{
    static constexpr int M = PartitionM;
    static constexpr int N = PartitionN;
    static constexpr int Stride = Tile::N;

    using T = typename Tile::Type;

    inline partition_t(Tile& A)
    {
        data = A.data.ptr;

        // todo: do ceil div for non-multiples of M,N
        shape[0] = Tile::M/PartitionM;
        shape[1] = Tile::N/PartitionN;
    }

    // underlying data
    T* data;

    // partition dimensions
    int shape[2];
};

template <typename Partition>
inline int partition_size(const Partition& part)
{
    return part.shape[0]*part.shape[1];
}

// returns the x, y coordinates of a tile given a linear index
template <typename Partition>
inline void partition_coord(const Partition& part, const int t, int& i, int& j)
{
    i = t/part.shape[1];
    j = t%part.shape[1];
}

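partition_coord is the usual row-major unflattening of a linear sub-tile index: with shape = {2, 4}, for example, t = 5 maps to i = 5/4 = 1, j = 5%4 = 1. A standalone sketch of the same mapping (illustrative helper name):

#include <cstdio>

// row-major linear index -> (i, j), mirroring partition_coord
void coord(int cols, int t, int& i, int& j)
{
    i = t / cols;
    j = t % cols;
}

int main()
{
    int i, j;
    coord(/*cols=*/4, /*t=*/5, i, j);
    printf("i=%d j=%d\n", i, j);  // i=1 j=1
    return 0;
}
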
template <typename Partition>
inline auto partition_load(const Partition& tile, int i, int j)
{
    mat_t<Partition::M, Partition::N, typename Partition::T> out;

    const int tile_i = i*Partition::M;
    const int tile_j = j*Partition::N;

    WP_PRAGMA_UNROLL
    for (int i=0; i < Partition::M; ++i)
    {
        WP_PRAGMA_UNROLL
        for (int j=0; j < Partition::N; ++j)
        {
            out.data[i][j] = index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
        }
    }

    return out;
}

template <typename Partition, typename Value>
inline void partition_store(const Partition& tile, int i, int j, const Value& value)
{
    const int tile_i = Partition::M*i;
    const int tile_j = Partition::N*j;

    WP_PRAGMA_UNROLL
    for (int i=0; i < Partition::M; ++i)
    {
        WP_PRAGMA_UNROLL
        for (int j=0; j < Partition::N; ++j)
        {
            index(tile.data, tile_i + i, tile_j + j, Partition::Stride) = value.data[i][j];
        }
    }
}

template <typename TileA, typename TileB, typename TileC>
inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
{
    const int TILE_M = 4;
    const int TILE_N = 4;
    const int TILE_K = 4;

    auto A_tile = partition_t<TILE_M, TILE_K, TileA>(A);
    auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
    auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);

    const int length = partition_size(C_tile);

    for (int t=threadIdx.x; t < length; t += blockDim.x)
    {
        int i, j;
        partition_coord(C_tile, t, i, j);

        // accumulator
        auto sum = partition_load(C_tile, i, j);

        WP_PRAGMA_UNROLL
        for (int k=0; k < A_tile.shape[1]; k++)
        {
            const auto a = partition_load(A_tile, i, k);
            const auto b = partition_load(B_tile, k, j);

            sum += mul(a, b);
        }

        partition_store(C_tile, i, j, sum);
    }
}

} // namespace partitioned_gemm

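In the matmul above, the block's threads cooperatively divide the output into 4x4 sub-tiles: each thread claims sub-tiles t, t+blockDim.x, ... of C, loads the matching 4x4 blocks of A and B, and accumulates sum += a*b over the K partitions. A minimal standalone CPU sketch of the same blocking scheme (sizes and names illustrative, single-threaded):

#include <cstdio>

const int N = 8, B = 4;                    // matrix size and block (partition) size

float A[N][N], Bm[N][N], C[N][N];

int main()
{
    // simple test data: A = 2*identity, B[i][j] = i + j
    for (int i = 0; i < N; ++i)
        for (int j = 0; j < N; ++j)
        {
            A[i][j]  = (i == j) ? 2.0f : 0.0f;
            Bm[i][j] = float(i + j);
            C[i][j]  = 0.0f;
        }

    // loop over 4x4 sub-tiles of C, then accumulate over the K partitions
    for (int bi = 0; bi < N/B; ++bi)
        for (int bj = 0; bj < N/B; ++bj)
            for (int bk = 0; bk < N/B; ++bk)
                for (int i = 0; i < B; ++i)
                    for (int j = 0; j < B; ++j)
                        for (int k = 0; k < B; ++k)
                            C[bi*B + i][bj*B + j] += A[bi*B + i][bk*B + k] * Bm[bk*B + k][bj*B + j];

    printf("C[3][5] = %g\n", C[3][5]);     // 2 * B[3][5] = 16
    return 0;
}
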
template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C)
{
    using T = typename TileA::Type;

#if WP_USE_ASYNC_PIPELINE
    __pipeline_wait_prior(0);
    WP_TILE_SYNC();
#endif

#if WP_USE_REGISTER_GEMM
    partitioned_gemm::matmul(A, B, C);
#else
    fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
#endif

    WP_TILE_SYNC();

    return C;
}

// backward for the wp.tile_matmul(a, b, out) syntax
template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
                     Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C)
{
    using T = typename TileA::Type;

    fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
    fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
    WP_TILE_SYNC();
}

// backward for the out = wp.tile_matmul(a, b) syntax
template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
                     Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C, TileC& adj_ret)
{
    using T = typename TileA::Type;

    fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
    fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
    WP_TILE_SYNC();
}

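The two backward overloads apply the standard matrix-product adjoints: for C = A*B, adj_A accumulates adj_C*B^T and adj_B accumulates A^T*adj_C, which is what the fun_backward_A/fun_backward_B GEMMs evaluate. A small standalone check of those identities (2x2 only, illustrative):

#include <cstdio>

// adj_A = adj_C * B^T and adj_B = A^T * adj_C for 2x2 matrices
int main()
{
    float A[2][2]     = {{1, 2}, {3, 4}};
    float B[2][2]     = {{5, 6}, {7, 8}};
    float adj_C[2][2] = {{1, 0}, {0, 1}};  // seed gradient

    float adj_A[2][2] = {}, adj_B[2][2] = {};

    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 2; ++k)
            {
                adj_A[i][j] += adj_C[i][k] * B[j][k];  // adj_C * B^T
                adj_B[i][j] += A[k][i] * adj_C[k][j];  // A^T * adj_C
            }

    // with adj_C = I this prints B^T and A^T
    printf("adj_A = [[%g, %g], [%g, %g]]\n", adj_A[0][0], adj_A[0][1], adj_A[1][0], adj_A[1][1]);
    printf("adj_B = [[%g, %g], [%g, %g]]\n", adj_B[0][0], adj_B[0][1], adj_B[1][0], adj_B[1][1]);
    return 0;
}
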
// TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
// TODO(lcambier): use dynamic smem
#define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
    do { \
        void function_name(dtype*, dtype*); \
        WP_TILE_SHARED __align__(16) char buffer[shared_memory_size]; \
        __align__(16) dtype data[ept]; \
        for(int b = 0; b < (int)batch_size; b++) { \
            dtype* inout = Xinout.data + (int)b * (int)ept; \
            memcpy(data, inout, sizeof(dtype) * ept); \
            function_name(data, (dtype*)buffer); \
            memcpy(inout, data, sizeof(dtype) * ept); \
            WP_TILE_SYNC(); \
        } \
    } while (0)

#define tile_ifft tile_fft

// adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept are all ignored

#define adj_tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
                     adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
                     adj_Xinout) \
    do { \
        tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
    } while (0)

#define adj_tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
                      adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
                      adj_Xinout) \
    do { \
        tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
    } while (0)


template <typename Tile>
inline CUDA_CALLABLE auto tile_transpose(Tile& t)
{
    // alias incoming tile
    return tile_shared_t<typename Tile::Type, Tile::N, Tile::M, Tile::StrideN, Tile::StrideM, false>(t.data.ptr, t.grad.ptr);
}

template <typename Tile, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    auto a = tile_transpose(adj_ret);
    auto b = adj_t;

    adj_t.assign(tile_add(a,b));
}

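tile_transpose above and tile_broadcast/tile_view below never copy data: they re-wrap the same shared-memory (and gradient) pointers with swapped or modified strides, so a 2x3 row-major tile with strides (3, 1) can be viewed as a 3x2 tile with strides (1, 3). A standalone sketch of stride-based aliasing over a flat array (illustrative view_t struct, not Warp's tile_shared_t):

#include <cstdio>

// a non-owning strided 2D view over a flat buffer
struct view_t
{
    float* data;
    int rows, cols, stride_r, stride_c;
    float& at(int i, int j) { return data[i*stride_r + j*stride_c]; }
};

int main()
{
    float buf[6] = {0, 1, 2, 3, 4, 5};                        // 2x3 row-major
    view_t m  = {buf, 2, 3, /*stride_r=*/3, /*stride_c=*/1};
    view_t mT = {buf, 3, 2, /*stride_r=*/1, /*stride_c=*/3};  // transpose by swapping strides

    printf("%g %g\n", m.at(1, 2), mT.at(2, 1));               // both read buf[5] = 5
    return 0;
}
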
template <int M, int N, int StrideM, int StrideN, typename Tile>
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
{
    // alias incoming tile with new strides
    return tile_shared_t<typename Tile::Type, M, N, StrideM, StrideN, false>(t.data.ptr, t.grad.ptr);
}

template <typename Tile, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    // nop: since memory is aliased, grads have already been accumulated

}

template <int M, int N, typename Tile>
inline CUDA_CALLABLE auto tile_view(Tile& t, int i, int j)
{
    // alias incoming tile with new strides
    return tile_shared_t<typename Tile::Type, M, N, Tile::StrideM, Tile::StrideN, false>(&t.data(i, j), &t.grad(i, j));
}

template <typename Tile, typename AdjTile>
inline CUDA_CALLABLE void adj_tile_view(Tile& t, int i, int j, Tile& adj_t, int adj_i, int adj_j, AdjTile& adj_ret)
{
    // nop: since memory is aliased, grads have already been accumulated

}

template <typename TileA, typename TileB>
inline CUDA_CALLABLE void tile_assign(TileA& dest, int i, int j, TileB& src)
{
    for (int t=threadIdx.x; t < src.Size; t += WP_TILE_BLOCK_DIM)
    {
        coord_t c = src.coord(t);
        dest.data(i + c.i, j + c.j) = src.data(c.i, c.j);
    }

    WP_TILE_SYNC();
}

template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, int i, int j, TileB& src,
                                          AdjTileA& adj_dest, int adj_i, int adj_j, AdjTileB& adj_src)
{
    for (int t=threadIdx.x; t < src.Size; t += WP_TILE_BLOCK_DIM)
    {
        coord_t c = src.coord(t);
        src.grad(c.i, c.j) += dest.grad(i + c.i, j + c.j);
    }

    WP_TILE_SYNC();
}



} // namespace wp