warp-lang 1.5.1__py3-none-macosx_10_13_universal2.whl → 1.6.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1077 -481
- warp/codegen.py +250 -122
- warp/config.py +65 -21
- warp/context.py +500 -149
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_marching_cubes.py +1 -1
- warp/examples/core/example_mesh.py +1 -1
- warp/examples/core/example_torch.py +18 -34
- warp/examples/core/example_wave.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth_self_contact.py +314 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +191 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +6 -2
- warp/native/crt.h +1 -0
- warp/native/cuda_util.cpp +35 -0
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +57 -3
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1189 -664
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +132 -59
- warp/render/render_usd.py +10 -2
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +289 -32
- warp/sim/import_urdf.py +20 -5
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +147 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +173 -112
- warp/sim/render.py +2 -2
- warp/stubs.py +249 -116
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +100 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +8 -8
- warp/tests/test_examples.py +16 -1
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_launch.py +77 -26
- warp/tests/test_mat.py +213 -168
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +11 -7
- warp/tests/test_matmul_lite.py +4 -4
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +6 -5
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_static.py +16 -0
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +134 -191
- warp/tests/test_tile_load.py +399 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -11
- warp/thirdparty/unittest_parallel.py +5 -2
- warp/types.py +419 -111
- warp/utils.py +9 -5
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
|
@@ -35,10 +35,6 @@
|
|
|
35
35
|
#endif
|
|
36
36
|
|
|
37
37
|
#define WP_USE_ASYNC_PIPELINE 0
|
|
38
|
-
#if WP_USE_ASYNC_PIPELINE
|
|
39
|
-
#include "cuda_pipeline_primitives.h"
|
|
40
|
-
#endif // WP_USE_ASYNC_PIPELINE
|
|
41
|
-
|
|
42
38
|
#define WP_USE_REGISTER_GEMM 0
|
|
43
39
|
|
|
44
40
|
/* Tile Expressions
|
|
@@ -171,50 +167,300 @@ struct is_same<T, T> {
|
|
|
171
167
|
};
|
|
172
168
|
|
|
173
169
|
|
|
174
|
-
template <
|
|
175
|
-
|
|
170
|
+
template <int N>
|
|
171
|
+
struct tile_coord_t
|
|
172
|
+
{
|
|
173
|
+
int indices[N];
|
|
174
|
+
|
|
175
|
+
// read-only access to the i-th index; asserts that i lies in [0, N)
// (the lower-bound check previously read `0 <= 1`, which is always true)
CUDA_CALLABLE inline int operator[](int i) const { assert(0 <= i && i < N); return indices[i]; }
|
|
176
|
+
// mutable access to the i-th index; asserts that i lies in [0, N)
// (the lower-bound check previously read `0 <= 1`, which is always true)
CUDA_CALLABLE inline int& operator[](int i) { assert(0 <= i && i < N); return indices[i]; }
|
|
177
|
+
|
|
178
|
+
CUDA_CALLABLE inline tile_coord_t<N> operator + (const tile_coord_t<N>& c) const
|
|
179
|
+
{
|
|
180
|
+
tile_coord_t<N> out;
|
|
181
|
+
for (int i=0; i < N; ++i)
|
|
182
|
+
{
|
|
183
|
+
out.indices[i] = indices[i] + c.indices[i];
|
|
184
|
+
}
|
|
185
|
+
return out;
|
|
186
|
+
}
|
|
187
|
+
};
|
|
188
|
+
|
|
189
|
+
// This function deduces N = sizeof...(Ints)
|
|
190
|
+
template <typename... Ints>
|
|
191
|
+
constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
|
|
192
|
+
{
|
|
193
|
+
constexpr int N = sizeof...(Ints);
|
|
194
|
+
|
|
195
|
+
// Create the result
|
|
196
|
+
tile_coord_t<N> result{};
|
|
197
|
+
|
|
198
|
+
// Capture all arguments in a local array
|
|
199
|
+
int arr[] = { static_cast<int>(idxs)... };
|
|
200
|
+
|
|
201
|
+
// C++14 or later: 'for' is allowed in a constexpr context
|
|
202
|
+
for (int i = 0; i < N; ++i)
|
|
203
|
+
{
|
|
204
|
+
result.indices[i] = arr[i];
|
|
205
|
+
}
|
|
176
206
|
|
|
177
|
-
|
|
178
|
-
|
|
207
|
+
return result;
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
// helpers to construct a coord from a set of indices
|
|
211
|
+
auto tile_coord(int i)
|
|
212
|
+
{
|
|
213
|
+
auto c = tile_coord_t<1>();
|
|
214
|
+
c.indices[0] = i;
|
|
215
|
+
return c;
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
auto tile_coord(int i, int j)
|
|
219
|
+
{
|
|
220
|
+
auto c = tile_coord_t<2>();
|
|
221
|
+
c.indices[0] = i;
|
|
222
|
+
c.indices[1] = j;
|
|
223
|
+
return c;
|
|
179
224
|
}
|
|
180
225
|
|
|
181
|
-
|
|
226
|
+
auto tile_coord(int i, int j, int k)
|
|
227
|
+
{
|
|
228
|
+
auto c = tile_coord_t<3>();
|
|
229
|
+
c.indices[0] = i;
|
|
230
|
+
c.indices[1] = j;
|
|
231
|
+
c.indices[2] = k;
|
|
232
|
+
return c;
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
auto tile_coord(int i, int j, int k, int l)
|
|
182
236
|
{
|
|
183
|
-
|
|
184
|
-
|
|
237
|
+
auto c = tile_coord_t<4>();
|
|
238
|
+
c.indices[0] = i;
|
|
239
|
+
c.indices[1] = j;
|
|
240
|
+
c.indices[2] = k;
|
|
241
|
+
c.indices[3] = l;
|
|
242
|
+
return c;
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
// represents a compile time int tuple for strides/shapes/coords
|
|
246
|
+
template <int... V>
|
|
247
|
+
struct tile_tuple_t
|
|
248
|
+
{
|
|
249
|
+
static constexpr int N = sizeof...(V);
|
|
250
|
+
static_assert(N > 0);
|
|
251
|
+
|
|
252
|
+
static constexpr int data[N] = { V... };
|
|
253
|
+
|
|
254
|
+
// returns the i-th compile-time entry of the tuple; asserts that i lies in [0, N)
static constexpr int dim(int i) { assert(0 <= i && i < N); return data[i]; }
|
|
255
|
+
static constexpr int size()
|
|
256
|
+
{
|
|
257
|
+
int res = data[0];
|
|
258
|
+
for (int i=1; i < N; ++i)
|
|
259
|
+
res *= data[i];
|
|
260
|
+
|
|
261
|
+
return res;
|
|
262
|
+
}
|
|
185
263
|
};
|
|
186
264
|
|
|
265
|
+
// simple helper to compute strides from a shape up to 4d
|
|
266
|
+
template <typename Shape>
|
|
267
|
+
struct compute_strides;
|
|
268
|
+
|
|
269
|
+
// 1D
|
|
270
|
+
template <int D0>
|
|
271
|
+
struct compute_strides< tile_tuple_t<D0> > { using Stride = tile_tuple_t<1>; };
|
|
272
|
+
// 2D
|
|
273
|
+
template <int D0, int D1>
|
|
274
|
+
struct compute_strides< tile_tuple_t<D0, D1> > { using Stride = tile_tuple_t<D1, 1>; };
|
|
275
|
+
// 3D
|
|
276
|
+
template <int D0, int D1, int D2>
|
|
277
|
+
struct compute_strides< tile_tuple_t<D0, D1, D2> > { using Stride = tile_tuple_t<(D1 * D2), D2, 1>; };
|
|
278
|
+
// 4D
|
|
279
|
+
template <int D0, int D1, int D2, int D3>
|
|
280
|
+
struct compute_strides< tile_tuple_t<D0, D1, D2, D3> > { using Stride = tile_tuple_t<(D1 * D2 * D3), (D2 * D3), D3, 1>; };
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
// alias of tuple to represent shapes
|
|
284
|
+
template <int... V>
|
|
285
|
+
using tile_shape_t = tile_tuple_t<V...>;
|
|
286
|
+
|
|
287
|
+
// alias of tuple to represent stride
|
|
288
|
+
template <int... V>
|
|
289
|
+
using tile_stride_t = tile_tuple_t<V...>;
|
|
290
|
+
|
|
187
291
|
|
|
188
292
|
// represents a tile stored in global memory with dynamic strides
|
|
189
|
-
//
|
|
190
|
-
template <typename T>
|
|
191
|
-
struct tile_global_t
|
|
293
|
+
// used to represent the source and offset for tile loads to register/shared
|
|
294
|
+
template <typename T, typename Shape_>
|
|
295
|
+
struct tile_global_t
|
|
192
296
|
{
|
|
193
297
|
using Type = T;
|
|
298
|
+
using Shape = Shape_;
|
|
299
|
+
using Coord = tile_coord_t<Shape::N>;
|
|
194
300
|
|
|
195
301
|
array_t<T> data;
|
|
196
|
-
|
|
197
|
-
|
|
302
|
+
Coord offset;
|
|
303
|
+
|
|
304
|
+
tile_global_t(array_t<T>& a, const Coord& c) : data(a), offset(c)
|
|
305
|
+
{
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
inline CUDA_CALLABLE int index_from_coord(const Coord& coord) const
|
|
309
|
+
{
|
|
310
|
+
// element index
|
|
311
|
+
int index = 0;
|
|
312
|
+
|
|
313
|
+
WP_PRAGMA_UNROLL
|
|
314
|
+
for (int i=0; i < Shape::N; ++i)
|
|
315
|
+
{
|
|
316
|
+
// global = offset + coord
|
|
317
|
+
int c = offset[i] + coord[i];
|
|
318
|
+
index += data.strides[i]*c;
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
return index/sizeof(T);
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
|
|
325
|
+
{
|
|
326
|
+
// element index
|
|
327
|
+
int index = 0;
|
|
328
|
+
|
|
329
|
+
WP_PRAGMA_UNROLL
|
|
330
|
+
for (int i=0; i < Shape::N; ++i)
|
|
331
|
+
{
|
|
332
|
+
// global = offset + coord
|
|
333
|
+
int c = offset[i] + coord[i];
|
|
334
|
+
|
|
335
|
+
// handle out of bounds case
|
|
336
|
+
if (c >= data.shape[i])
|
|
337
|
+
return false;
|
|
338
|
+
else
|
|
339
|
+
index += data.strides[i]*c;
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
// array strides are in bytes so we convert to elements
|
|
343
|
+
out = index / sizeof(T);
|
|
344
|
+
return true;
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
inline CUDA_CALLABLE T load(const Coord& coord) const
|
|
348
|
+
{
|
|
349
|
+
int i;
|
|
350
|
+
if (index(coord, i))
|
|
351
|
+
return data.data[i];
|
|
352
|
+
else
|
|
353
|
+
return T(0);
|
|
354
|
+
}
|
|
198
355
|
|
|
199
|
-
|
|
356
|
+
inline CUDA_CALLABLE T load_grad(const Coord& coord) const
|
|
200
357
|
{
|
|
358
|
+
int i;
|
|
359
|
+
if (index(coord, i))
|
|
360
|
+
return data.grad[i];
|
|
361
|
+
else
|
|
362
|
+
return T(0);
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
inline CUDA_CALLABLE void store(const Coord& coord, const T& x) const
|
|
366
|
+
{
|
|
367
|
+
int i;
|
|
368
|
+
if (index(coord, i))
|
|
369
|
+
data.data[i] = x;
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
inline CUDA_CALLABLE T atomic_add(const Coord& coord, const T& value) const
|
|
373
|
+
{
|
|
374
|
+
int i;
|
|
375
|
+
if (index(coord, i))
|
|
376
|
+
return wp::atomic_add(&data.data[i], value);
|
|
377
|
+
else
|
|
378
|
+
return T(0);
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
inline CUDA_CALLABLE T atomic_add_grad(const Coord& coord, const T& grad) const
|
|
382
|
+
{
|
|
383
|
+
int i;
|
|
384
|
+
if (index(coord, i))
|
|
385
|
+
return wp::atomic_add(&data.grad[i], grad);
|
|
386
|
+
else
|
|
387
|
+
return T(0);
|
|
201
388
|
}
|
|
202
389
|
};
|
|
203
390
|
|
|
391
|
+
template <typename Shape_>
|
|
392
|
+
struct tile_layout_register_t
|
|
393
|
+
{
|
|
394
|
+
using Shape = Shape_;
|
|
395
|
+
using Coord = tile_coord_t<Shape::N>;
|
|
396
|
+
|
|
397
|
+
static constexpr int Size = Shape::size();
|
|
398
|
+
static constexpr int NumRegs = (Size + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
|
|
399
|
+
static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
|
|
400
|
+
|
|
401
|
+
static inline CUDA_CALLABLE int linear_from_register(int reg)
|
|
402
|
+
{
|
|
403
|
+
return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
static inline CUDA_CALLABLE int linear_from_coord(Coord c)
|
|
407
|
+
{
|
|
408
|
+
int linear = 0;
|
|
409
|
+
int stride = 1;
|
|
410
|
+
|
|
411
|
+
WP_PRAGMA_UNROLL
|
|
412
|
+
for (int i=Shape::N-1; i >= 0; --i)
|
|
413
|
+
{
|
|
414
|
+
linear += c[i] * stride;
|
|
415
|
+
stride *= Shape::dim(i);
|
|
416
|
+
}
|
|
417
|
+
return linear;
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
static inline CUDA_CALLABLE auto coord_from_linear(int linear)
|
|
421
|
+
{
|
|
422
|
+
Coord c;
|
|
423
|
+
|
|
424
|
+
WP_PRAGMA_UNROLL
|
|
425
|
+
for (int i=Shape::N-1; i >= 0; --i)
|
|
426
|
+
{
|
|
427
|
+
c[i] = linear%Shape::dim(i);
|
|
428
|
+
linear /= Shape::dim(i);
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
return c;
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
static inline CUDA_CALLABLE int thread_from_linear(int linear)
|
|
435
|
+
{
|
|
436
|
+
const int thread = linear%WP_TILE_BLOCK_DIM;
|
|
437
|
+
return thread;
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
static inline CUDA_CALLABLE int register_from_linear(int linear)
|
|
441
|
+
{
|
|
442
|
+
const int reg = linear/WP_TILE_BLOCK_DIM;
|
|
443
|
+
return reg;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
static inline CUDA_CALLABLE bool valid(int linear)
|
|
447
|
+
{
|
|
448
|
+
if (Aligned || linear < Size)
|
|
449
|
+
return true;
|
|
450
|
+
else
|
|
451
|
+
return false;
|
|
452
|
+
}
|
|
453
|
+
|
|
454
|
+
};
|
|
455
|
+
|
|
204
456
|
// represents a tile stored in registers across a block
|
|
205
|
-
template <typename T,
|
|
457
|
+
template <typename T, typename L>
|
|
206
458
|
struct tile_register_t
|
|
207
459
|
{
|
|
208
460
|
using Type = T;
|
|
209
|
-
|
|
210
|
-
static constexpr int N = N_;
|
|
211
|
-
static constexpr int Size = M*N;
|
|
461
|
+
using Layout = L;
|
|
212
462
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
|
|
216
|
-
|
|
217
|
-
T data[NumRegs];
|
|
463
|
+
T data[Layout::NumRegs];
|
|
218
464
|
|
|
219
465
|
inline CUDA_CALLABLE tile_register_t(T value=T(0.0))
|
|
220
466
|
{
|
|
@@ -224,52 +470,34 @@ struct tile_register_t
|
|
|
224
470
|
// in backwards pass and letting default constructor
|
|
225
471
|
// avoid initialization
|
|
226
472
|
|
|
227
|
-
for (int i=0; i < NumRegs; ++i)
|
|
473
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
228
474
|
data[i] = value;
|
|
229
475
|
}
|
|
230
476
|
|
|
231
|
-
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
|
|
477
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
|
|
232
478
|
{
|
|
233
|
-
|
|
234
|
-
copy_from_global(t.data, t.x); // 1d load
|
|
235
|
-
else
|
|
236
|
-
copy_from_global(t.data, t.x, t.y); // 2d load
|
|
237
|
-
|
|
479
|
+
copy_from_global(t);
|
|
238
480
|
return *this;
|
|
239
|
-
|
|
240
481
|
}
|
|
241
482
|
|
|
242
483
|
// define the += operator which is used during backward pass codegen
|
|
243
484
|
// when returning a register tile from a user defined function
|
|
244
|
-
inline CUDA_CALLABLE auto& operator += (tile_register_t<T,
|
|
485
|
+
inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
|
|
245
486
|
{
|
|
246
|
-
|
|
487
|
+
grad_add(rhs);
|
|
247
488
|
return *this;
|
|
248
489
|
}
|
|
249
490
|
|
|
250
|
-
inline CUDA_CALLABLE T& operator()(int
|
|
491
|
+
inline CUDA_CALLABLE T& operator()(int reg)
|
|
251
492
|
{
|
|
252
|
-
assert(
|
|
253
|
-
return data[
|
|
493
|
+
assert(reg < Layout::NumRegs);
|
|
494
|
+
return data[reg];
|
|
254
495
|
}
|
|
255
496
|
|
|
256
|
-
inline CUDA_CALLABLE const T& operator()(int
|
|
497
|
+
inline CUDA_CALLABLE const T& operator()(int reg) const
|
|
257
498
|
{
|
|
258
|
-
assert(
|
|
259
|
-
return data[
|
|
260
|
-
}
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
// compute linear tile index from a local register index
|
|
264
|
-
inline CUDA_CALLABLE int index(int reg) const
|
|
265
|
-
{
|
|
266
|
-
return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
|
|
267
|
-
}
|
|
268
|
-
|
|
269
|
-
// compute tile coordinate from linear index
|
|
270
|
-
inline CUDA_CALLABLE coord_t coord(int index) const
|
|
271
|
-
{
|
|
272
|
-
return {index/N, index%N};
|
|
499
|
+
assert(reg < Layout::NumRegs);
|
|
500
|
+
return data[reg];
|
|
273
501
|
}
|
|
274
502
|
|
|
275
503
|
// Returns the number of valid registers for this tile
|
|
@@ -278,29 +506,29 @@ struct tile_register_t
|
|
|
278
506
|
// some of the trailing registers may lie outside the valid range
|
|
279
507
|
inline CUDA_CALLABLE int valid() const
|
|
280
508
|
{
|
|
281
|
-
return (Size - threadIdx.x)/WP_TILE_BLOCK_DIM;
|
|
509
|
+
return (int)floor(float(Size - threadIdx.x - 1)/WP_TILE_BLOCK_DIM) + 1;
|
|
282
510
|
}
|
|
283
511
|
|
|
284
|
-
inline CUDA_CALLABLE void assign(const tile_register_t<T,
|
|
512
|
+
inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
|
|
285
513
|
{
|
|
286
|
-
for (int i=0; i < NumRegs; ++i)
|
|
514
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
287
515
|
data[i] = tile.data[i];
|
|
288
516
|
}
|
|
289
517
|
|
|
290
518
|
inline CUDA_CALLABLE void zero()
|
|
291
519
|
{
|
|
292
|
-
for (int i=0; i < NumRegs; ++i)
|
|
293
|
-
data[i] = T(0);
|
|
520
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
521
|
+
data[i] = T(0);
|
|
294
522
|
}
|
|
295
523
|
|
|
296
524
|
// extract a single tile element to a native type
|
|
297
|
-
|
|
525
|
+
template <typename Coord>
|
|
526
|
+
inline CUDA_CALLABLE Type extract(const Coord& c)
|
|
298
527
|
{
|
|
299
528
|
// map from logical coords (i, j) -> (thread, reg)
|
|
300
|
-
const int linear =
|
|
301
|
-
|
|
302
|
-
const int
|
|
303
|
-
const int reg = linear%NumRegs;
|
|
529
|
+
const int linear = Layout::linear_from_coord(c);
|
|
530
|
+
const int thread = Layout::thread_from_linear(linear);
|
|
531
|
+
const int reg = Layout::register_from_linear(linear);
|
|
304
532
|
|
|
305
533
|
WP_TILE_SHARED Type scratch;
|
|
306
534
|
|
|
@@ -320,13 +548,13 @@ struct tile_register_t
|
|
|
320
548
|
|
|
321
549
|
|
|
322
550
|
// backward version of scalar extract
|
|
323
|
-
|
|
551
|
+
template <typename Coord>
|
|
552
|
+
inline CUDA_CALLABLE void adj_extract(const Coord& c, Type adj_ret)
|
|
324
553
|
{
|
|
325
554
|
// map from logical coords (i, j) -> (thread, reg)
|
|
326
|
-
const int linear =
|
|
327
|
-
|
|
328
|
-
const int
|
|
329
|
-
const int reg = linear%NumRegs;
|
|
555
|
+
const int linear = Layout::linear_from_coord(c);
|
|
556
|
+
const int thread = Layout::thread_from_linear(linear);
|
|
557
|
+
const int reg = Layout::register_from_linear(linear);
|
|
330
558
|
|
|
331
559
|
if (threadIdx.x == thread)
|
|
332
560
|
{
|
|
@@ -348,6 +576,24 @@ struct tile_register_t
|
|
|
348
576
|
return *this;
|
|
349
577
|
}
|
|
350
578
|
|
|
579
|
+
// apply a lambda to all valid entries in the tile
|
|
580
|
+
// Op should be a functor that takes a register index and tile_coord_t as input
|
|
581
|
+
template <typename Op>
|
|
582
|
+
void apply(Op op)
|
|
583
|
+
{
|
|
584
|
+
WP_PRAGMA_UNROLL
|
|
585
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
586
|
+
{
|
|
587
|
+
int linear = Layout::linear_from_register(i);
|
|
588
|
+
if (!Layout::valid(linear))
|
|
589
|
+
break;
|
|
590
|
+
|
|
591
|
+
auto c = Layout::coord_from_linear(linear);
|
|
592
|
+
op(i, c);
|
|
593
|
+
}
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
|
|
351
597
|
// in-place gradient zero
|
|
352
598
|
inline CUDA_CALLABLE void grad_zero()
|
|
353
599
|
{
|
|
@@ -355,118 +601,77 @@ struct tile_register_t
|
|
|
355
601
|
}
|
|
356
602
|
|
|
357
603
|
// accumulate gradients onto this tile
|
|
358
|
-
inline CUDA_CALLABLE void grad_add(const tile_register_t<T,
|
|
604
|
+
inline CUDA_CALLABLE void grad_add(const tile_register_t<T, Layout>& tile)
|
|
359
605
|
{
|
|
360
|
-
for (int i=0; i < NumRegs; ++i)
|
|
606
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
361
607
|
data[i] += tile.data[i];
|
|
362
608
|
}
|
|
363
609
|
|
|
364
|
-
|
|
610
|
+
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
611
|
+
{
|
|
612
|
+
apply([&](int reg, auto c) {data[reg] = global.load_grad(c);});
|
|
613
|
+
|
|
614
|
+
}
|
|
615
|
+
|
|
365
616
|
inline CUDA_CALLABLE auto& grad_to_register()
|
|
366
617
|
{
|
|
618
|
+
// nop for register tiles
|
|
367
619
|
return *this;
|
|
368
620
|
}
|
|
369
621
|
|
|
370
|
-
|
|
622
|
+
template <typename Global>
|
|
623
|
+
inline CUDA_CALLABLE void copy_to_global(const Global& dest)
|
|
371
624
|
{
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
const int tile_i = x*N;
|
|
375
|
-
|
|
376
|
-
WP_PRAGMA_UNROLL
|
|
377
|
-
for (int i=0; i < NumRegs; ++i)
|
|
378
|
-
{
|
|
379
|
-
// handle case where tile size is not
|
|
380
|
-
// aligned to block dimensions
|
|
381
|
-
int linear = index(i);
|
|
382
|
-
if (!Aligned && linear >= Size)
|
|
383
|
-
break;
|
|
384
|
-
|
|
385
|
-
wp::index(dest, tile_i + linear) = data[i];
|
|
386
|
-
}
|
|
625
|
+
apply([&](int reg, auto c) { dest.store(c, data[reg]); });
|
|
387
626
|
}
|
|
388
627
|
|
|
389
|
-
|
|
628
|
+
template <typename Global>
|
|
629
|
+
inline CUDA_CALLABLE void copy_from_global(const Global& src)
|
|
390
630
|
{
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
const int tile_i = x*M;
|
|
394
|
-
const int tile_j = y*N;
|
|
395
|
-
|
|
396
|
-
// wp.array() indexing generates poor code due to char* casting
|
|
397
|
-
// here we unroll some of the ops, note this assumes byte strides are
|
|
398
|
-
// aligned to the element size
|
|
399
|
-
T* ptr = &wp::index(dest, tile_i, tile_j);
|
|
400
|
-
const int stride_i = dest.strides[0]/sizeof(T);
|
|
401
|
-
const int stride_j = dest.strides[1]/sizeof(T);
|
|
402
|
-
|
|
403
|
-
WP_PRAGMA_UNROLL
|
|
404
|
-
for (int i=0; i < NumRegs; ++i)
|
|
405
|
-
{
|
|
406
|
-
// handle case where tile size is not
|
|
407
|
-
// aligned to block dimensions
|
|
408
|
-
int linear = index(i);
|
|
409
|
-
if (!Aligned && linear >= Size)
|
|
410
|
-
break;
|
|
411
|
-
|
|
412
|
-
coord_t c = coord(linear);
|
|
413
|
-
ptr[c.i*stride_i + c.j*stride_j] = data[i];
|
|
414
|
-
}
|
|
631
|
+
apply([&](int reg, auto c) { data[reg] = src.load(c); });
|
|
415
632
|
}
|
|
416
633
|
|
|
417
|
-
|
|
634
|
+
// add a register tile to a global array
|
|
635
|
+
template <typename Global>
|
|
636
|
+
inline CUDA_CALLABLE auto atomic_add(const Global& dest)
|
|
418
637
|
{
|
|
419
|
-
//
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
WP_PRAGMA_UNROLL
|
|
423
|
-
for (int i=0; i < NumRegs; ++i)
|
|
424
|
-
{
|
|
425
|
-
int linear = index(i);
|
|
426
|
-
if (!Aligned && linear >= Size)
|
|
427
|
-
break;
|
|
638
|
+
// allocate a tile to hold previous dest value
|
|
639
|
+
auto previous = *this;
|
|
428
640
|
|
|
429
|
-
|
|
430
|
-
|
|
641
|
+
apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add(c, data[reg]); });
|
|
642
|
+
return previous;
|
|
431
643
|
}
|
|
432
644
|
|
|
433
|
-
|
|
645
|
+
// add a register tile to the gradient of a global array
|
|
646
|
+
template <typename Global>
|
|
647
|
+
inline CUDA_CALLABLE auto atomic_add_grad(const Global& dest)
|
|
434
648
|
{
|
|
435
|
-
//
|
|
436
|
-
|
|
437
|
-
const int tile_j = y*N;
|
|
438
|
-
|
|
439
|
-
// wp.array() indexing generates poor code due to char* casting
|
|
440
|
-
// here we unroll some of the ops, note this assumes array byte strides are
|
|
441
|
-
// aligned to the element size
|
|
442
|
-
const T* ptr = &wp::index(src, tile_i, tile_j);
|
|
443
|
-
|
|
444
|
-
assert(src.strides[0]%sizeof(T) == 0);
|
|
445
|
-
assert(src.strides[1]%sizeof(T) == 0);
|
|
446
|
-
|
|
447
|
-
const int stride_i = src.strides[0]/sizeof(T);
|
|
448
|
-
const int stride_j = src.strides[1]/sizeof(T);
|
|
649
|
+
// allocate a tile to hold previous dest value
|
|
650
|
+
auto previous = *this;
|
|
449
651
|
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
int linear = index(i);
|
|
454
|
-
if (!Aligned && linear >= Size)
|
|
455
|
-
break;
|
|
456
|
-
|
|
457
|
-
coord_t c = coord(linear);
|
|
458
|
-
data[i] = ptr[c.i*stride_i + c.j*stride_j];
|
|
459
|
-
}
|
|
460
|
-
}
|
|
652
|
+
apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add_grad(c, data[reg]); });
|
|
653
|
+
return previous;
|
|
654
|
+
}
|
|
461
655
|
};
|
|
462
656
|
|
|
657
|
+
|
|
463
658
|
// helper to allocate a register tile like another tile
|
|
659
|
+
// users can either specify a template explicitly or
|
|
660
|
+
// pass in another concrete instance
|
|
464
661
|
template<typename Tile>
|
|
465
|
-
auto tile_register_like()
|
|
662
|
+
auto tile_register_like(Tile* t=NULL)
|
|
466
663
|
{
|
|
467
664
|
using T = typename Tile::Type;
|
|
665
|
+
using L = typename Tile::Layout;
|
|
666
|
+
|
|
667
|
+
return tile_register_t<T, tile_layout_register_t<typename L::Shape>>(T(0.0));
|
|
668
|
+
}
|
|
468
669
|
|
|
469
|
-
|
|
670
|
+
// helper to construct a register tile from a type and a list of dims
|
|
671
|
+
template <typename T, int... Dims>
|
|
672
|
+
auto tile_register()
|
|
673
|
+
{
|
|
674
|
+
return tile_register_t<T, tile_layout_register_t<tile_shape_t<Dims...>>>();
|
|
470
675
|
}
|
|
471
676
|
|
|
472
677
|
inline CUDA_CALLABLE int tile_align(int num_bytes)
|
|
@@ -474,7 +679,10 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
|
|
|
474
679
|
// note this must match the value in Python types.py
|
|
475
680
|
const int alignment = 16;
|
|
476
681
|
|
|
477
|
-
|
|
682
|
+
const int num_bytes_abs = num_bytes < 0 ? - num_bytes : num_bytes;
|
|
683
|
+
const int sign = num_bytes < 0 ? - 1 : 1;
|
|
684
|
+
|
|
685
|
+
return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
|
|
478
686
|
}
|
|
479
687
|
|
|
480
688
|
inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
|
|
@@ -502,20 +710,78 @@ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
|
|
|
502
710
|
}
|
|
503
711
|
|
|
504
712
|
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
struct tile_shared_t
|
|
713
|
+
template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
|
|
714
|
+
struct tile_layout_strided_t
|
|
508
715
|
{
|
|
509
|
-
using
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
static constexpr int Size = M*N;
|
|
716
|
+
using Shape = Shape_;
|
|
717
|
+
using Stride = Stride_;
|
|
718
|
+
using Coord = tile_coord_t<Shape::N>;
|
|
513
719
|
|
|
514
|
-
static constexpr int
|
|
515
|
-
static constexpr int StrideN = StrideN_;
|
|
516
|
-
|
|
720
|
+
static constexpr int Size = Shape::size();
|
|
517
721
|
static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
|
|
518
|
-
|
|
722
|
+
|
|
723
|
+
static inline CUDA_CALLABLE auto coord_from_linear(int linear)
|
|
724
|
+
{
|
|
725
|
+
assert(linear < Size);
|
|
726
|
+
|
|
727
|
+
Coord c;
|
|
728
|
+
|
|
729
|
+
WP_PRAGMA_UNROLL
|
|
730
|
+
for (int d=Shape::N-1; d >= 0; --d)
|
|
731
|
+
{
|
|
732
|
+
c[d] = linear%Shape::dim(d);
|
|
733
|
+
linear /= Shape::dim(d);
|
|
734
|
+
}
|
|
735
|
+
|
|
736
|
+
return c;
|
|
737
|
+
}
|
|
738
|
+
|
|
739
|
+
static inline CUDA_CALLABLE int index_from_coord(Coord c)
|
|
740
|
+
{
|
|
741
|
+
int index = 0;
|
|
742
|
+
|
|
743
|
+
WP_PRAGMA_UNROLL
|
|
744
|
+
for (int d=0; d < Shape::N; ++d)
|
|
745
|
+
{
|
|
746
|
+
assert(c[d] < Shape::dim(d));
|
|
747
|
+
|
|
748
|
+
index += c[d]*Stride::dim(d);
|
|
749
|
+
}
|
|
750
|
+
|
|
751
|
+
return index;
|
|
752
|
+
}
|
|
753
|
+
|
|
754
|
+
// checks whether a strided layout is unique, i.e.: if memory locations are only
|
|
755
|
+
// ever referred to by one element in the tile, this is a basic test that only
|
|
756
|
+
// checks for broadcast dimensions, it would be possible to do the full check
|
|
757
|
+
// using sorted shape/strides in Python and add it as a template parameter to the type
|
|
758
|
+
static constexpr bool is_unique()
|
|
759
|
+
{
|
|
760
|
+
constexpr int N = Shape::N;
|
|
761
|
+
|
|
762
|
+
// check for any broadcast dimensions
|
|
763
|
+
for (int i=0; i < N; ++i)
|
|
764
|
+
if (Stride::dim(i) == 0)
|
|
765
|
+
return false;
|
|
766
|
+
|
|
767
|
+
return true;
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
static constexpr bool Unique = is_unique();
|
|
771
|
+
|
|
772
|
+
static inline CUDA_CALLABLE bool valid(int linear)
|
|
773
|
+
{
|
|
774
|
+
return linear < Size;
|
|
775
|
+
}
|
|
776
|
+
|
|
777
|
+
};
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
template <typename T, typename L, bool Owner_=true>
|
|
781
|
+
struct tile_shared_t
|
|
782
|
+
{
|
|
783
|
+
using Type = T;
|
|
784
|
+
using Layout = L;
|
|
519
785
|
static constexpr bool Owner = Owner_;
|
|
520
786
|
|
|
521
787
|
struct Storage
|
|
@@ -524,55 +790,60 @@ struct tile_shared_t
|
|
|
524
790
|
|
|
525
791
|
// construct from a raw pointer; the pointer is stored in `ptr` as-is
Storage(T* p) : ptr(p) {}
|
|
526
792
|
|
|
527
|
-
inline CUDA_CALLABLE T& operator()(
|
|
793
|
+
inline CUDA_CALLABLE T& operator()(typename Layout::Coord c)
|
|
528
794
|
{
|
|
529
|
-
assert(
|
|
530
|
-
assert(j < N);
|
|
795
|
+
assert(ptr);
|
|
531
796
|
|
|
532
|
-
|
|
797
|
+
int index = Layout::index_from_coord(c);
|
|
798
|
+
return ptr[index];
|
|
533
799
|
}
|
|
534
800
|
|
|
535
|
-
inline CUDA_CALLABLE const T& operator()(
|
|
536
|
-
{
|
|
537
|
-
assert(
|
|
538
|
-
assert(j < N);
|
|
801
|
+
inline CUDA_CALLABLE const T& operator()(typename Layout::Coord c) const
|
|
802
|
+
{
|
|
803
|
+
assert(ptr);
|
|
539
804
|
|
|
540
|
-
|
|
805
|
+
int index = Layout::index_from_coord(c);
|
|
806
|
+
return ptr[index];
|
|
541
807
|
}
|
|
542
808
|
|
|
543
|
-
inline CUDA_CALLABLE T& operator()(int
|
|
809
|
+
inline CUDA_CALLABLE T& operator()(int linear)
|
|
544
810
|
{
|
|
545
|
-
assert(
|
|
546
|
-
|
|
547
|
-
// unravel
|
|
548
|
-
int i = index/N;
|
|
549
|
-
int j = index%N;
|
|
811
|
+
assert(ptr);
|
|
812
|
+
assert(Layout::valid(linear));
|
|
550
813
|
|
|
551
|
-
|
|
814
|
+
auto c = Layout::coord_from_linear(linear);
|
|
815
|
+
return (*this)(c);
|
|
552
816
|
}
|
|
553
817
|
|
|
554
|
-
inline CUDA_CALLABLE const T& operator()(int
|
|
818
|
+
inline CUDA_CALLABLE const T& operator()(int linear) const
|
|
555
819
|
{
|
|
556
|
-
assert(
|
|
557
|
-
|
|
558
|
-
// unravel
|
|
559
|
-
int i = index/N;
|
|
560
|
-
int j = index%N;
|
|
820
|
+
assert(ptr);
|
|
821
|
+
assert(Layout::valid(linear));
|
|
561
822
|
|
|
562
|
-
|
|
823
|
+
auto c = Layout::coord_from_linear(linear);
|
|
824
|
+
return (*this)(c);
|
|
563
825
|
}
|
|
564
826
|
};
|
|
565
827
|
|
|
566
828
|
Storage data;
|
|
567
829
|
Storage grad;
|
|
568
830
|
|
|
831
|
+
// we need to track whether or not this tile's data has been initialized.
|
|
832
|
+
// once true, any re-initialization of data that follows needs a WP_TILE_SYNC()
|
|
833
|
+
// call to precede it, to allow threads that are still reading from this tile
|
|
834
|
+
// to complete their work, e.g., in a dynamic loop:
|
|
835
|
+
// for i in range(x):
|
|
836
|
+
// tile = wp.tile_load(arr, i, TILE_SIZE, storage="shared")
|
|
837
|
+
// # read from tile...
|
|
838
|
+
bool initialized;
|
|
839
|
+
|
|
569
840
|
// default initialization (non-initialized)
|
|
570
|
-
inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL)
|
|
841
|
+
inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL), initialized(false)
|
|
571
842
|
{
|
|
572
843
|
}
|
|
573
844
|
|
|
574
845
|
// initialize from an existing tile's memory
|
|
575
|
-
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL) : data(data), grad(grad)
|
|
846
|
+
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL, bool initialized=true) : data(data), grad(grad), initialized(initialized)
|
|
576
847
|
{
|
|
577
848
|
}
|
|
578
849
|
|
|
@@ -582,10 +853,10 @@ struct tile_shared_t
|
|
|
582
853
|
{
|
|
583
854
|
// update our per-thread shared memory allocator
|
|
584
855
|
if (data.ptr)
|
|
585
|
-
tile_alloc_shared(-
|
|
856
|
+
tile_alloc_shared(-Layout::Size*int(sizeof(T)));
|
|
586
857
|
|
|
587
858
|
if (grad.ptr)
|
|
588
|
-
tile_alloc_shared(-
|
|
859
|
+
tile_alloc_shared(-Layout::Size*int(sizeof(T)));
|
|
589
860
|
}
|
|
590
861
|
}
|
|
591
862
|
|
|
@@ -597,12 +868,13 @@ struct tile_shared_t
|
|
|
597
868
|
return *this;
|
|
598
869
|
}
|
|
599
870
|
|
|
871
|
+
|
|
600
872
|
// construct from another shared tile, this constructor
|
|
601
873
|
// is invoked for reshape operations like `wp.tile_transpose()`
|
|
602
|
-
template <typename OtherT,
|
|
603
|
-
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT,
|
|
874
|
+
template <typename OtherT, typename OtherLayout>
|
|
875
|
+
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout>& rhs)
|
|
604
876
|
{
|
|
605
|
-
using OtherTile = tile_shared_t<OtherT,
|
|
877
|
+
using OtherTile = tile_shared_t<OtherT, OtherLayout>;
|
|
606
878
|
|
|
607
879
|
// check dimensions are compatible
|
|
608
880
|
static_assert(Size == OtherTile::Size);
|
|
@@ -610,89 +882,89 @@ struct tile_shared_t
|
|
|
610
882
|
// alias tile directly
|
|
611
883
|
data = rhs.data;
|
|
612
884
|
grad = rhs.grad;
|
|
885
|
+
initialized = rhs.initialized;
|
|
613
886
|
|
|
614
887
|
return *this;
|
|
615
888
|
}
|
|
616
889
|
|
|
617
890
|
// assign from a global tile (load)
|
|
618
|
-
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
|
|
891
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
|
|
619
892
|
{
|
|
620
|
-
|
|
621
|
-
copy_from_global(t.data, t.x); // 1d load
|
|
622
|
-
else
|
|
623
|
-
copy_from_global(t.data, t.x, t.y); // 2d load
|
|
624
|
-
|
|
625
|
-
// synchronization happens in copy functions above
|
|
626
|
-
|
|
893
|
+
copy_from_global(t);
|
|
627
894
|
return *this;
|
|
628
895
|
}
|
|
629
896
|
|
|
630
897
|
// assign from a constant value
|
|
631
898
|
inline CUDA_CALLABLE auto& operator=(const T& x)
|
|
632
899
|
{
|
|
633
|
-
|
|
900
|
+
// sync if we are re-initializing data so that any threads that are still
|
|
901
|
+
// reading from this tile can complete their work, e.g.: if re-assigning
|
|
902
|
+
// to a tile during a dynamic loop
|
|
903
|
+
if (initialized)
|
|
904
|
+
WP_TILE_SYNC();
|
|
905
|
+
|
|
906
|
+
for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
|
|
634
907
|
data(i) = x;
|
|
635
908
|
|
|
909
|
+
initialized = true;
|
|
636
910
|
WP_TILE_SYNC();
|
|
637
911
|
return *this;
|
|
638
912
|
}
|
|
639
913
|
|
|
640
|
-
|
|
641
|
-
// compute tile coordinate from linear index
|
|
642
|
-
inline CUDA_CALLABLE coord_t coord(int index) const
|
|
643
|
-
{
|
|
644
|
-
return {index/N, index%N};
|
|
645
|
-
}
|
|
646
|
-
|
|
647
914
|
// in-place zero
|
|
648
915
|
inline CUDA_CALLABLE void zero()
|
|
649
916
|
{
|
|
650
|
-
for (int i=threadIdx.x; i <
|
|
917
|
+
for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
|
|
651
918
|
data(i) = T(0);
|
|
652
919
|
|
|
653
920
|
WP_TILE_SYNC();
|
|
654
921
|
}
|
|
655
922
|
|
|
656
923
|
// extract a single tile element to a native type
|
|
657
|
-
inline CUDA_CALLABLE Type extract(
|
|
924
|
+
inline CUDA_CALLABLE Type extract(const typename Layout::Coord& c)
|
|
658
925
|
{
|
|
659
|
-
return data(
|
|
926
|
+
return data(c);
|
|
660
927
|
}
|
|
661
928
|
|
|
662
929
|
// backward of scalar extraction
|
|
663
|
-
inline CUDA_CALLABLE void adj_extract(
|
|
930
|
+
inline CUDA_CALLABLE void adj_extract(const typename Layout::Coord& c, Type adj_ret)
|
|
664
931
|
{
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
932
|
+
// since multiple threads may extract the same element
|
|
933
|
+
// we need to accumulate using atomic operations
|
|
934
|
+
wp::atomic_add(&grad(c), adj_ret);
|
|
935
|
+
|
|
936
|
+
WP_TILE_SYNC();
|
|
669
937
|
}
|
|
670
938
|
|
|
671
939
|
|
|
672
940
|
// copy register tile to shared
|
|
673
|
-
|
|
941
|
+
template <typename Tile>
|
|
942
|
+
inline CUDA_CALLABLE void assign(const Tile& tile)
|
|
674
943
|
{
|
|
944
|
+
if (initialized)
|
|
945
|
+
WP_TILE_SYNC();
|
|
946
|
+
|
|
675
947
|
WP_PRAGMA_UNROLL
|
|
676
|
-
for (int i=0; i <
|
|
948
|
+
for (int i=0; i < Tile::Layout::NumRegs; ++i)
|
|
677
949
|
{
|
|
678
|
-
const int linear =
|
|
950
|
+
const int linear = Tile::Layout::linear_from_register(i);
|
|
679
951
|
|
|
680
952
|
// handle case where tile size is not
|
|
681
953
|
// aligned to block dimensions
|
|
682
|
-
if (!
|
|
683
|
-
break;
|
|
954
|
+
if (!Tile::Layout::valid(linear))
|
|
955
|
+
break;
|
|
684
956
|
|
|
685
957
|
data(linear) = tile.data[i];
|
|
686
958
|
}
|
|
687
959
|
|
|
960
|
+
initialized = true;
|
|
688
961
|
WP_TILE_SYNC();
|
|
689
962
|
}
|
|
690
963
|
|
|
691
964
|
// in-place gradient zero
|
|
692
965
|
inline CUDA_CALLABLE void grad_zero()
|
|
693
966
|
{
|
|
694
|
-
|
|
695
|
-
for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
|
|
967
|
+
for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
|
|
696
968
|
grad(i) = T(0);
|
|
697
969
|
|
|
698
970
|
WP_TILE_SYNC();
|
|
@@ -700,44 +972,73 @@ struct tile_shared_t
|
|
|
700
972
|
|
|
701
973
|
|
|
702
974
|
// accumulate gradients onto this tile
|
|
703
|
-
|
|
975
|
+
template <typename Tile>
|
|
976
|
+
inline CUDA_CALLABLE void grad_add(const Tile& tile)
|
|
704
977
|
{
|
|
705
978
|
WP_PRAGMA_UNROLL
|
|
706
|
-
for (int i=0; i <
|
|
979
|
+
for (int i=0; i < Tile::Layout::NumRegs; ++i)
|
|
707
980
|
{
|
|
708
|
-
const int linear =
|
|
981
|
+
const int linear = Tile::Layout::linear_from_register(i);
|
|
709
982
|
|
|
710
983
|
// handle case where tile size is not
|
|
711
984
|
// aligned to block dimensions
|
|
712
|
-
if (!
|
|
985
|
+
if (!Tile::Layout::valid(linear))
|
|
713
986
|
break;
|
|
714
987
|
|
|
715
|
-
if (
|
|
988
|
+
// if the destination layout is unique (no broadcast dimensions)
|
|
989
|
+
// then we can use regular non-atomic accumulation
|
|
990
|
+
if (Layout::Unique)
|
|
716
991
|
grad(linear) += tile.data[i];
|
|
717
992
|
else
|
|
718
993
|
// use shared memory atomics to accumulate gradients
|
|
719
994
|
// since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
|
|
720
995
|
// may map to a single location in shared memory
|
|
721
|
-
atomic_add(&grad(linear), tile.data[i]);
|
|
996
|
+
wp::atomic_add(&grad(linear), tile.data[i]);
|
|
722
997
|
|
|
723
998
|
}
|
|
724
999
|
|
|
725
1000
|
WP_TILE_SYNC();
|
|
726
1001
|
}
|
|
727
1002
|
|
|
1003
|
+
// accumulate gradient onto this tile from a global array
|
|
1004
|
+
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
1005
|
+
{
|
|
1006
|
+
WP_PRAGMA_UNROLL
|
|
1007
|
+
for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1008
|
+
{
|
|
1009
|
+
auto c = Layout::coord_from_linear(i);
|
|
1010
|
+
T g = global.load_grad(c);
|
|
1011
|
+
|
|
1012
|
+
if (Layout::Unique)
|
|
1013
|
+
{
|
|
1014
|
+
// if the destination layout is unique (no broadcast dimensions)
|
|
1015
|
+
// then we can use regular non-atomic accumulation
|
|
1016
|
+
grad(c) += g;
|
|
1017
|
+
}
|
|
1018
|
+
else
|
|
1019
|
+
{
|
|
1020
|
+
// use shared memory atomics to accumulate gradients
|
|
1021
|
+
// since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
|
|
1022
|
+
// may map to a single location in shared memory
|
|
1023
|
+
wp::atomic_add(&grad(c), g);
|
|
1024
|
+
}
|
|
1025
|
+
}
|
|
1026
|
+
|
|
1027
|
+
WP_TILE_SYNC();
|
|
1028
|
+
}
|
|
1029
|
+
|
|
728
1030
|
// copy shared tile gradient to register
|
|
729
|
-
inline CUDA_CALLABLE
|
|
1031
|
+
inline CUDA_CALLABLE auto grad_to_register()
|
|
730
1032
|
{
|
|
731
|
-
tile_register_t<T,
|
|
1033
|
+
using Tile = tile_register_t<T, tile_layout_register_t<typename Layout::Shape>>;
|
|
1034
|
+
Tile out;
|
|
732
1035
|
|
|
733
1036
|
WP_PRAGMA_UNROLL
|
|
734
|
-
for (int i=0; i <
|
|
1037
|
+
for (int i=0; i < Tile::Layout::NumRegs; ++i)
|
|
735
1038
|
{
|
|
736
|
-
const int linear =
|
|
1039
|
+
const int linear = Tile::Layout::linear_from_register(i);
|
|
737
1040
|
|
|
738
|
-
|
|
739
|
-
// aligned to block dimensions
|
|
740
|
-
if (!Aligned && linear >= Size)
|
|
1041
|
+
if (!Tile::Layout::valid(linear))
|
|
741
1042
|
break;
|
|
742
1043
|
|
|
743
1044
|
out(i) = grad(linear);
|
|
@@ -746,40 +1047,20 @@ struct tile_shared_t
|
|
|
746
1047
|
return out;
|
|
747
1048
|
}
|
|
748
1049
|
|
|
749
|
-
inline CUDA_CALLABLE void print() const
|
|
750
|
-
{
|
|
751
|
-
if (threadIdx.x == 0)
|
|
752
|
-
{
|
|
753
|
-
printf("tile(m=%d, n=%d, storage=shared) = [", M, N);
|
|
754
|
-
for (int i=0; i < M; ++i)
|
|
755
|
-
{
|
|
756
|
-
printf("%*s[", i>0, "");
|
|
757
|
-
for (int j=0; j < N; ++j)
|
|
758
|
-
{
|
|
759
|
-
printf("%g ", double(data(i, j)));
|
|
760
|
-
}
|
|
761
|
-
|
|
762
|
-
if (i == M-1)
|
|
763
|
-
printf("]]\n");
|
|
764
|
-
else
|
|
765
|
-
printf("]\n");
|
|
766
|
-
}
|
|
767
|
-
}
|
|
768
|
-
}
|
|
769
|
-
|
|
770
1050
|
// copy shared tile to register
|
|
771
|
-
inline CUDA_CALLABLE
|
|
1051
|
+
inline CUDA_CALLABLE auto copy_to_register() const
|
|
772
1052
|
{
|
|
773
|
-
|
|
1053
|
+
|
|
1054
|
+
auto out = tile_register_like(this);
|
|
1055
|
+
|
|
1056
|
+
using Layout = typename decltype(out)::Layout;
|
|
774
1057
|
|
|
775
1058
|
WP_PRAGMA_UNROLL
|
|
776
|
-
for (int i=0; i <
|
|
1059
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
777
1060
|
{
|
|
778
|
-
const int linear =
|
|
1061
|
+
const int linear = Layout::linear_from_register(i);
|
|
779
1062
|
|
|
780
|
-
|
|
781
|
-
// aligned to block dimensions
|
|
782
|
-
if (!Aligned && linear >= Size)
|
|
1063
|
+
if (!Layout::valid(linear))
|
|
783
1064
|
break;
|
|
784
1065
|
|
|
785
1066
|
out(i) = data(linear);
|
|
@@ -788,220 +1069,358 @@ struct tile_shared_t
|
|
|
788
1069
|
return out;
|
|
789
1070
|
}
|
|
790
1071
|
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
1072
|
+
template <typename Global>
|
|
1073
|
+
inline CUDA_CALLABLE void copy_to_global(const Global& dest)
|
|
1074
|
+
{
|
|
1075
|
+
// vectorized stores for specific input/output shapes
|
|
1076
|
+
if constexpr (Layout::Shape::N == 2)
|
|
1077
|
+
{
|
|
1078
|
+
constexpr int lastdim = Layout::Shape::N-1;
|
|
1079
|
+
constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
|
|
1080
|
+
const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
|
|
1081
|
+
const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]);
|
|
1082
|
+
const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
|
|
1083
|
+
|
|
1084
|
+
float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
|
|
1085
|
+
const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
|
|
1086
|
+
|
|
1087
|
+
if (contiguous_dest && contiguous_src && aligned_size && aligned_dst)
|
|
1088
|
+
{
|
|
1089
|
+
constexpr int M = Layout::Shape::dim(0);
|
|
1090
|
+
constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
|
|
1091
|
+
|
|
1092
|
+
// alias of shared tile with 128bit type
|
|
1093
|
+
using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1094
|
+
tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
|
|
1095
|
+
|
|
1096
|
+
assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
|
|
1097
|
+
assert(((uint64_t)(dest128))%sizeof(float4) == 0);
|
|
1098
|
+
|
|
1099
|
+
const int stride_i = dest.data.strides[0]/sizeof(float4);
|
|
1100
|
+
const int stride_j = 1;
|
|
1101
|
+
|
|
1102
|
+
WP_PRAGMA_UNROLL
|
|
1103
|
+
for (int i=threadIdx.x; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1104
|
+
{
|
|
1105
|
+
auto c = SrcLayout::coord_from_linear(i);
|
|
1106
|
+
|
|
1107
|
+
dest128[stride_i*c[0] + stride_j*c[1]] = src128.data(i);
|
|
1108
|
+
}
|
|
794
1109
|
|
|
795
|
-
|
|
796
|
-
|
|
1110
|
+
return;
|
|
1111
|
+
}
|
|
1112
|
+
}
|
|
797
1113
|
|
|
1114
|
+
// scalar bounds checked path
|
|
798
1115
|
WP_PRAGMA_UNROLL
|
|
799
|
-
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
1116
|
+
for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
800
1117
|
{
|
|
801
|
-
|
|
1118
|
+
auto c = Layout::coord_from_linear(i);
|
|
1119
|
+
dest.store(c, data(i));
|
|
802
1120
|
}
|
|
803
1121
|
}
|
|
804
1122
|
|
|
805
|
-
|
|
1123
|
+
__device__ __forceinline__
|
|
1124
|
+
void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
|
|
806
1125
|
{
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
1126
|
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
|
1127
|
+
|
|
1128
|
+
unsigned long long saddr = 0ULL;
|
|
1129
|
+
unsigned long long gaddr = 0ULL;
|
|
1130
|
+
|
|
1131
|
+
asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(saddr) : "l"(shared_dest));
|
|
1132
|
+
asm volatile("cvta.to.global.u64 %0, %1;" : "=l"(gaddr) : "l"(global_src));
|
|
1133
|
+
|
|
1134
|
+
// Use cp.async on newer architectures
|
|
1135
|
+
asm volatile(
|
|
1136
|
+
"cp.async.ca.shared.global [%0], [%1], 16;\n"
|
|
1137
|
+
:
|
|
1138
|
+
: "l"(saddr), "l"(gaddr)
|
|
1139
|
+
);
|
|
1140
|
+
#else
|
|
1141
|
+
// use regular load/store through register on older arches
|
|
1142
|
+
*shared_dest = *global_src;
|
|
1143
|
+
#endif
|
|
1144
|
+
}
|
|
821
1145
|
|
|
822
|
-
|
|
823
|
-
|
|
1146
|
+
__device__ __forceinline__
|
|
1147
|
+
void cp_async_commit_and_wait_all_128()
|
|
1148
|
+
{
|
|
1149
|
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
|
1150
|
+
asm volatile(
|
|
1151
|
+
"cp.async.commit_group;\n"
|
|
1152
|
+
"cp.async.wait_group 0;\n" ::);
|
|
1153
|
+
#endif
|
|
1154
|
+
}
|
|
1155
|
+
|
|
1156
|
+
template <typename Global>
|
|
1157
|
+
inline CUDA_CALLABLE void copy_from_global(const Global& src)
|
|
1158
|
+
{
|
|
1159
|
+
if (initialized)
|
|
1160
|
+
WP_TILE_SYNC();
|
|
1161
|
+
|
|
1162
|
+
// vectorized loads for specific input/output shapes
|
|
1163
|
+
if constexpr (Layout::Shape::N == 2)
|
|
1164
|
+
{
|
|
1165
|
+
constexpr int lastdim = Layout::Shape::N-1;
|
|
1166
|
+
constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
|
|
1167
|
+
const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
|
|
1168
|
+
const int elements = (src.data.shape[lastdim] - src.offset[lastdim]);
|
|
1169
|
+
const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
|
|
1170
|
+
|
|
1171
|
+
float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
|
|
1172
|
+
const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
|
|
1173
|
+
|
|
1174
|
+
if (contiguous_dest && contiguous_src && aligned_size && aligned_src)
|
|
1175
|
+
{
|
|
1176
|
+
constexpr int M = Layout::Shape::dim(0);
|
|
1177
|
+
constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
|
|
1178
|
+
|
|
1179
|
+
// alias of shared tile with 128bit type
|
|
1180
|
+
using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1181
|
+
tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
|
|
1182
|
+
|
|
1183
|
+
assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
|
|
1184
|
+
assert(((uint64_t)(src128))%sizeof(float4) == 0);
|
|
1185
|
+
|
|
1186
|
+
const int stride_i = src.data.strides[0]/sizeof(float4);
|
|
1187
|
+
const int stride_j = 1;
|
|
1188
|
+
|
|
1189
|
+
WP_PRAGMA_UNROLL
|
|
1190
|
+
for (int i=threadIdx.x; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1191
|
+
{
|
|
1192
|
+
auto c = DestLayout::coord_from_linear(i);
|
|
1193
|
+
|
|
1194
|
+
#if WP_USE_ASYNC_PIPELINE
|
|
1195
|
+
cp_async_global_to_shared_128(&dest128.data(i), &src128[stride_i*c[0] + stride_j*c[1]]);
|
|
1196
|
+
#else
|
|
1197
|
+
dest128.data(i) = src128[stride_i*c[0] + stride_j*c[1]];
|
|
1198
|
+
#endif // WP_USE_ASYNC_PIPELINE
|
|
1199
|
+
}
|
|
824
1200
|
|
|
825
|
-
|
|
826
|
-
|
|
1201
|
+
#if WP_USE_ASYNC_PIPELINE
|
|
1202
|
+
cp_async_commit_and_wait_all_128();
|
|
1203
|
+
#endif // WP_USE_ASYNC_PIPELINE
|
|
827
1204
|
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
coord_t c = src128.coord(i);
|
|
832
|
-
ptr[c.i*stride_i + c.j*stride_j] = src128.data(i);
|
|
1205
|
+
initialized = true;
|
|
1206
|
+
WP_TILE_SYNC();
|
|
1207
|
+
return;
|
|
833
1208
|
}
|
|
834
1209
|
}
|
|
835
|
-
else
|
|
836
|
-
{
|
|
837
|
-
// wp.array() indexing generates poor code due to char* casting
|
|
838
|
-
// here we unroll some of the ops, note this assumes byte strides are
|
|
839
|
-
// aligned to the element size
|
|
840
|
-
T* ptr = &wp::index(dest, tile_i, tile_j);
|
|
841
|
-
const int stride_i = dest.strides[0]/sizeof(T);
|
|
842
|
-
const int stride_j = dest.strides[1]/sizeof(T);
|
|
843
|
-
|
|
844
|
-
WP_PRAGMA_UNROLL
|
|
845
|
-
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
846
|
-
{
|
|
847
|
-
coord_t c = coord(i);
|
|
848
|
-
ptr[c.i*stride_i + c.j*stride_j] = data(c.i, c.j);
|
|
849
|
-
}
|
|
850
|
-
}
|
|
851
|
-
}
|
|
852
|
-
|
|
853
|
-
inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x)
|
|
854
|
-
{
|
|
855
|
-
// todo: use async pipelines or TMA here
|
|
856
|
-
const int tile_i = x*N;
|
|
857
1210
|
|
|
1211
|
+
// scalar bounds checked path
|
|
858
1212
|
WP_PRAGMA_UNROLL
|
|
859
|
-
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
1213
|
+
for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
860
1214
|
{
|
|
861
|
-
|
|
1215
|
+
auto c = Layout::coord_from_linear(i);
|
|
1216
|
+
data(i) = src.load(c);
|
|
862
1217
|
}
|
|
863
1218
|
|
|
1219
|
+
initialized = true;
|
|
864
1220
|
WP_TILE_SYNC();
|
|
865
1221
|
}
|
|
866
1222
|
|
|
867
|
-
|
|
1223
|
+
template <typename Global>
|
|
1224
|
+
inline CUDA_CALLABLE auto atomic_add(Global& dest)
|
|
868
1225
|
{
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
const int tile_j = y*N;
|
|
1226
|
+
copy_to_register().atomic_add(dest);
|
|
1227
|
+
}
|
|
872
1228
|
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
1229
|
+
template <typename Global>
|
|
1230
|
+
inline CUDA_CALLABLE auto atomic_add_grad(Global& dest)
|
|
1231
|
+
{
|
|
1232
|
+
grad_to_register().atomic_add_grad(dest);
|
|
1233
|
+
}
|
|
878
1234
|
|
|
879
|
-
|
|
880
|
-
|
|
1235
|
+
// overload for integral types
|
|
1236
|
+
inline CUDA_CALLABLE void print_value(int x) const
|
|
1237
|
+
{
|
|
1238
|
+
printf("%d", x);
|
|
1239
|
+
}
|
|
881
1240
|
|
|
882
|
-
|
|
1241
|
+
// overload for floating point types
|
|
1242
|
+
template <typename ValueType>
|
|
1243
|
+
inline CUDA_CALLABLE void print_value(ValueType x) const
|
|
1244
|
+
{
|
|
1245
|
+
printf("%g", x);
|
|
1246
|
+
}
|
|
883
1247
|
|
|
884
|
-
|
|
885
|
-
|
|
1248
|
+
template <int Level = 0>
|
|
1249
|
+
inline CUDA_CALLABLE void print_values(const Storage& storage, int index=0) const
|
|
1250
|
+
{
|
|
1251
|
+
using Shape = typename Layout::Shape;
|
|
886
1252
|
|
|
887
|
-
|
|
888
|
-
|
|
1253
|
+
if constexpr (Level < Shape::N)
|
|
1254
|
+
{
|
|
1255
|
+
if constexpr (Level == Shape::N - 1)
|
|
1256
|
+
{
|
|
1257
|
+
// Special handling for 1D case
|
|
1258
|
+
printf("[");
|
|
1259
|
+
for (int i = 0; i < Shape::dim(Level); ++i)
|
|
1260
|
+
{
|
|
1261
|
+
print_value(storage(index + i));
|
|
889
1262
|
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
1263
|
+
if (i < Shape::dim(Level) - 1)
|
|
1264
|
+
{
|
|
1265
|
+
printf(" ");
|
|
1266
|
+
}
|
|
1267
|
+
}
|
|
1268
|
+
printf("]");
|
|
1269
|
+
}
|
|
1270
|
+
else if constexpr (Level == Shape::N - 2)
|
|
1271
|
+
{
|
|
1272
|
+
// Special handling for 2D case
|
|
1273
|
+
printf("[");
|
|
1274
|
+
for (int i = 0; i < Shape::dim(Level); ++i)
|
|
1275
|
+
{
|
|
1276
|
+
printf("[");
|
|
1277
|
+
for (int j=0; j < Shape::dim(Level+1); ++j)
|
|
1278
|
+
{
|
|
1279
|
+
print_value(storage(index));
|
|
1280
|
+
|
|
1281
|
+
if (j < Shape::dim(Level+1) - 1)
|
|
1282
|
+
{
|
|
1283
|
+
printf(" ");
|
|
1284
|
+
}
|
|
1285
|
+
|
|
1286
|
+
++index;
|
|
1287
|
+
}
|
|
1288
|
+
|
|
1289
|
+
printf("]");
|
|
1290
|
+
|
|
1291
|
+
// next row
|
|
1292
|
+
if (i < Shape::dim(Level)-1)
|
|
1293
|
+
{
|
|
1294
|
+
printf("\n");
|
|
1295
|
+
|
|
1296
|
+
// indent next row
|
|
1297
|
+
for (int i=0; i <= Shape::N-2; ++i)
|
|
1298
|
+
printf(" ");
|
|
1299
|
+
|
|
1300
|
+
}
|
|
1301
|
+
}
|
|
1302
|
+
printf("]");
|
|
1303
|
+
}
|
|
1304
|
+
else
|
|
1305
|
+
{
|
|
1306
|
+
printf("[");
|
|
1307
|
+
for (int i = 0; i < Shape::dim(Level); ++i)
|
|
1308
|
+
{
|
|
1309
|
+
print_values<Level + 1>(storage, index + i * Shape::dim(Level));
|
|
1310
|
+
if (i < Shape::dim(Level) - 1)
|
|
1311
|
+
{
|
|
1312
|
+
printf("\n\n");
|
|
1313
|
+
|
|
1314
|
+
// indent next row
|
|
1315
|
+
for (int i=0; i <= Level; ++i)
|
|
1316
|
+
printf(" ");
|
|
1317
|
+
}
|
|
1318
|
+
}
|
|
1319
|
+
printf("]");
|
|
902
1320
|
}
|
|
1321
|
+
}
|
|
1322
|
+
}
|
|
903
1323
|
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
1324
|
+
inline CUDA_CALLABLE void print(bool reverse=false) const
|
|
1325
|
+
{
|
|
1326
|
+
if (threadIdx.x != 0)
|
|
1327
|
+
return;
|
|
907
1328
|
|
|
908
|
-
|
|
1329
|
+
if (reverse)
|
|
1330
|
+
print_values(grad);
|
|
909
1331
|
else
|
|
1332
|
+
print_values(data);
|
|
1333
|
+
|
|
1334
|
+
printf(" = tile(shape=(");
|
|
1335
|
+
for (int i=0; i < Layout::Shape::N; ++i)
|
|
910
1336
|
{
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
const T* ptr = &wp::index(src, tile_i, tile_j);
|
|
915
|
-
|
|
916
|
-
assert(src.strides[0]%sizeof(T) == 0);
|
|
917
|
-
assert(src.strides[1]%sizeof(T) == 0);
|
|
918
|
-
|
|
919
|
-
const int stride_i = src.strides[0]/sizeof(T);
|
|
920
|
-
const int stride_j = src.strides[1]/sizeof(T);
|
|
921
|
-
|
|
922
|
-
WP_PRAGMA_UNROLL
|
|
923
|
-
for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
|
|
924
|
-
{
|
|
925
|
-
coord_t c = coord(i);
|
|
926
|
-
data(c.i, c.j) = ptr[c.i*stride_i + c.j*stride_j];
|
|
927
|
-
}
|
|
1337
|
+
printf("%d", Layout::Shape::dim(i));
|
|
1338
|
+
if (i != Layout::Shape::N-1)
|
|
1339
|
+
printf(",");
|
|
928
1340
|
}
|
|
929
1341
|
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
#endif
|
|
933
|
-
|
|
934
|
-
}
|
|
1342
|
+
printf("), storage=shared)\n");
|
|
1343
|
+
}
|
|
935
1344
|
};
|
|
936
1345
|
|
|
937
|
-
|
|
938
|
-
|
|
1346
|
+
|
|
1347
|
+
template <typename T, typename L>
|
|
1348
|
+
void tile_register_t<T, L>::print() const
|
|
939
1349
|
{
|
|
940
1350
|
// create a temporary shared tile so that
|
|
941
1351
|
// we can print it deterministically
|
|
942
|
-
WP_TILE_SHARED T smem[
|
|
943
|
-
|
|
944
|
-
|
|
1352
|
+
WP_TILE_SHARED T smem[L::Size];
|
|
1353
|
+
tile_shared_t<T, tile_layout_strided_t<typename L::Shape>> scratch(smem, NULL);
|
|
1354
|
+
|
|
945
1355
|
scratch.assign(*this);
|
|
946
1356
|
|
|
947
1357
|
WP_TILE_SYNC();
|
|
948
1358
|
|
|
949
1359
|
if (threadIdx.x == 0)
|
|
950
1360
|
{
|
|
951
|
-
|
|
952
|
-
for (int i=0; i < M; ++i)
|
|
953
|
-
{
|
|
954
|
-
printf("%*s[", i>0, "");
|
|
955
|
-
for (int j=0; j < N; ++j)
|
|
956
|
-
{
|
|
957
|
-
printf("%g ", double(scratch.data(i, j)));
|
|
958
|
-
}
|
|
1361
|
+
scratch.print_values(scratch.data, 0);
|
|
959
1362
|
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
1363
|
+
printf(" = tile(shape=(");
|
|
1364
|
+
for (int i=0; i < L::Shape::N; ++i)
|
|
1365
|
+
{
|
|
1366
|
+
printf("%d", L::Shape::dim(i));
|
|
1367
|
+
if (i != L::Shape::N-1)
|
|
1368
|
+
printf(",");
|
|
964
1369
|
}
|
|
1370
|
+
|
|
1371
|
+
printf("), storage=register)\n");
|
|
965
1372
|
}
|
|
966
1373
|
|
|
967
1374
|
WP_TILE_SYNC();
|
|
968
1375
|
}
|
|
969
1376
|
|
|
970
|
-
|
|
971
|
-
|
|
1377
|
+
// print entry points
|
|
1378
|
+
template <typename T, typename L>
|
|
1379
|
+
inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
|
|
1380
|
+
template <typename T, typename L, bool Owner>
|
|
1381
|
+
inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
|
|
1382
|
+
|
|
1383
|
+
template <typename T, typename L, bool O>
|
|
1384
|
+
inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
|
|
972
1385
|
{
|
|
973
|
-
|
|
1386
|
+
return Tile::Layout::Shape::dim(0);
|
|
974
1387
|
}
|
|
975
1388
|
|
|
976
|
-
template <typename T,
|
|
977
|
-
inline CUDA_CALLABLE void
|
|
1389
|
+
template <typename T, typename L, bool O, typename AdjTile>
|
|
1390
|
+
inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile& a, int& adj_ret)
|
|
978
1391
|
{
|
|
979
|
-
a.print();
|
|
980
1392
|
}
|
|
981
1393
|
|
|
982
|
-
template <typename T,
|
|
983
|
-
inline CUDA_CALLABLE
|
|
1394
|
+
template <typename T, typename L>
|
|
1395
|
+
inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
|
|
984
1396
|
{
|
|
985
|
-
|
|
1397
|
+
return Tile::Layout::Shape::dim(0);
|
|
986
1398
|
}
|
|
987
1399
|
|
|
988
|
-
template <typename T,
|
|
989
|
-
inline CUDA_CALLABLE void
|
|
1400
|
+
template <typename T, typename L, typename AdjTile>
|
|
1401
|
+
inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile& a, int& adj_ret)
|
|
990
1402
|
{
|
|
991
|
-
a.print();
|
|
992
1403
|
}
|
|
993
1404
|
|
|
1405
|
+
|
|
1406
|
+
template <typename T, typename L>
|
|
1407
|
+
inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
|
|
1408
|
+
template <typename T, typename L, bool Owner>
|
|
1409
|
+
inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
|
|
1410
|
+
|
|
1411
|
+
|
|
1412
|
+
|
|
994
1413
|
// helpers to allocate shared tiles
|
|
995
|
-
template <typename T,
|
|
1414
|
+
template <typename T, typename Shape, bool RequiresGrad>
|
|
996
1415
|
inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
997
1416
|
|
|
998
|
-
{ constexpr int
|
|
999
|
-
T* data = (T*)tile_alloc_shared(
|
|
1417
|
+
{ constexpr int size = Shape::size();
|
|
1418
|
+
T* data = (T*)tile_alloc_shared(size*sizeof(T));
|
|
1000
1419
|
T* grad = NULL;
|
|
1001
1420
|
|
|
1002
1421
|
#if FP_CHECK
|
|
1003
1422
|
|
|
1004
|
-
for (int i=threadIdx.x; i <
|
|
1423
|
+
for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1005
1424
|
data[i] = T(nanf(""));
|
|
1006
1425
|
|
|
1007
1426
|
WP_TILE_SYNC();
|
|
@@ -1011,15 +1430,15 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
|
1011
1430
|
|
|
1012
1431
|
if (RequiresGrad)
|
|
1013
1432
|
{
|
|
1014
|
-
grad = (T*)tile_alloc_shared(
|
|
1433
|
+
grad = (T*)tile_alloc_shared(size*sizeof(T));
|
|
1015
1434
|
|
|
1016
|
-
for (int i=threadIdx.x; i <
|
|
1435
|
+
for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
|
|
1017
1436
|
grad[i] = T(0);
|
|
1018
1437
|
|
|
1019
1438
|
WP_TILE_SYNC();
|
|
1020
1439
|
}
|
|
1021
1440
|
|
|
1022
|
-
return tile_shared_t<T,
|
|
1441
|
+
return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
|
|
1023
1442
|
}
|
|
1024
1443
|
|
|
1025
1444
|
template <typename T, int M, int N, bool RequiresGrad>
|
|
@@ -1043,7 +1462,7 @@ inline CUDA_CALLABLE auto tile_alloc_zeros()
|
|
|
1043
1462
|
|
|
1044
1463
|
WP_TILE_SYNC();
|
|
1045
1464
|
|
|
1046
|
-
return tile_shared_t<T, M, N
|
|
1465
|
+
return tile_shared_t<T, tile_layout_strided_t<tile_shape_t<M, N>>(data, grad);
|
|
1047
1466
|
}
|
|
1048
1467
|
|
|
1049
1468
|
|
|
@@ -1054,9 +1473,10 @@ inline CUDA_CALLABLE auto tile_alloc_zeros()
|
|
|
1054
1473
|
template <typename T>
|
|
1055
1474
|
inline CUDA_CALLABLE auto tile(const T& x)
|
|
1056
1475
|
{
|
|
1057
|
-
tile_register_t<T,
|
|
1476
|
+
tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
|
|
1058
1477
|
|
|
1059
|
-
|
|
1478
|
+
using Layout = typename decltype(result)::Layout;
|
|
1479
|
+
static_assert(Layout::NumRegs == 1);
|
|
1060
1480
|
|
|
1061
1481
|
result.data[0] = x;
|
|
1062
1482
|
return result;
|
|
@@ -1066,9 +1486,10 @@ inline CUDA_CALLABLE auto tile(const T& x)
|
|
|
1066
1486
|
template <typename T, unsigned Length>
|
|
1067
1487
|
inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
|
|
1068
1488
|
{
|
|
1069
|
-
tile_register_t<T, Length, WP_TILE_BLOCK_DIM
|
|
1489
|
+
tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
|
|
1070
1490
|
|
|
1071
|
-
|
|
1491
|
+
using Layout = typename decltype(result)::Layout;
|
|
1492
|
+
static_assert(Layout::NumRegs == Length);
|
|
1072
1493
|
|
|
1073
1494
|
for (int i=0; i < Length; ++i)
|
|
1074
1495
|
result.data[i] = x[i];
|
|
@@ -1080,8 +1501,8 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
|
|
|
1080
1501
|
template <typename T, typename AdjTile>
|
|
1081
1502
|
inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
|
|
1082
1503
|
{
|
|
1083
|
-
static_assert(AdjTile::
|
|
1084
|
-
static_assert(AdjTile::
|
|
1504
|
+
static_assert(AdjTile::Layout::Shape::N == 1);
|
|
1505
|
+
static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM);
|
|
1085
1506
|
|
|
1086
1507
|
auto adj_reg = adj_ret.copy_to_register();
|
|
1087
1508
|
|
|
@@ -1091,8 +1512,9 @@ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
|
|
|
1091
1512
|
template <typename T, unsigned Length, typename AdjTile>
|
|
1092
1513
|
inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
|
|
1093
1514
|
{
|
|
1094
|
-
static_assert(AdjTile::
|
|
1095
|
-
static_assert(AdjTile::
|
|
1515
|
+
static_assert(AdjTile::Layout::Shape::N == 2);
|
|
1516
|
+
static_assert(AdjTile::Layout::Shape::dim(0) == Length);
|
|
1517
|
+
static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM);
|
|
1096
1518
|
|
|
1097
1519
|
auto adj_reg = adj_ret.copy_to_register();
|
|
1098
1520
|
|
|
@@ -1108,17 +1530,20 @@ inline CUDA_CALLABLE auto untile(Tile& tile)
|
|
|
1108
1530
|
// there is exactly one value per-thread
|
|
1109
1531
|
auto reg = tile.copy_to_register();
|
|
1110
1532
|
|
|
1533
|
+
constexpr int N = Tile::Layout::Shape::N;
|
|
1534
|
+
|
|
1111
1535
|
// scalar case
|
|
1112
|
-
if constexpr(
|
|
1536
|
+
if constexpr(N == 1)
|
|
1113
1537
|
{
|
|
1114
1538
|
return reg.data[0];
|
|
1115
1539
|
}
|
|
1116
1540
|
|
|
1117
1541
|
// vector case
|
|
1118
|
-
if constexpr(
|
|
1542
|
+
if constexpr(N == 2)
|
|
1119
1543
|
{
|
|
1120
|
-
|
|
1121
|
-
|
|
1544
|
+
constexpr int Length = Tile::Layout::Shape::dim(0);
|
|
1545
|
+
wp::vec_t<Length, typename Tile::Type> v;
|
|
1546
|
+
for (int i=0; i < Length; ++i)
|
|
1122
1547
|
v[i] = reg.data[i];
|
|
1123
1548
|
|
|
1124
1549
|
return v;
|
|
@@ -1130,24 +1555,27 @@ inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
|
|
|
1130
1555
|
{
|
|
1131
1556
|
auto adj = adj_tile.copy_to_register();
|
|
1132
1557
|
|
|
1558
|
+
constexpr int N = Tile::Layout::Shape::N;
|
|
1559
|
+
|
|
1133
1560
|
// scalar case
|
|
1134
|
-
if constexpr(
|
|
1561
|
+
if constexpr(N == 1)
|
|
1135
1562
|
{
|
|
1136
1563
|
adj.data[0] += adj_ret;
|
|
1137
1564
|
}
|
|
1138
1565
|
|
|
1139
1566
|
// vector case
|
|
1140
|
-
if constexpr(
|
|
1567
|
+
if constexpr(N == 2)
|
|
1141
1568
|
{
|
|
1142
|
-
|
|
1143
|
-
|
|
1569
|
+
constexpr int Length = Tile::Layout::Shape::dim(0);
|
|
1570
|
+
for (int i=0; i < Length; ++i)
|
|
1571
|
+
adj.data[i] += adj_ret[i];
|
|
1144
1572
|
}
|
|
1145
1573
|
|
|
1146
1574
|
adj_tile.assign(adj);
|
|
1147
1575
|
}
|
|
1148
1576
|
|
|
1149
1577
|
// zero initialized tile
|
|
1150
|
-
template <typename T,
|
|
1578
|
+
template <typename T, unsigned... Shape>
|
|
1151
1579
|
inline CUDA_CALLABLE auto tile_zeros()
|
|
1152
1580
|
{
|
|
1153
1581
|
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
@@ -1155,7 +1583,7 @@ inline CUDA_CALLABLE auto tile_zeros()
|
|
|
1155
1583
|
}
|
|
1156
1584
|
|
|
1157
1585
|
// one-initialized tile
|
|
1158
|
-
template <typename T,
|
|
1586
|
+
template <typename T, unsigned... Shape>
|
|
1159
1587
|
inline CUDA_CALLABLE auto tile_ones()
|
|
1160
1588
|
{
|
|
1161
1589
|
// tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
|
|
@@ -1163,19 +1591,21 @@ inline CUDA_CALLABLE auto tile_ones()
|
|
|
1163
1591
|
}
|
|
1164
1592
|
|
|
1165
1593
|
// tile with evenly spaced values
|
|
1166
|
-
template <typename T, int
|
|
1594
|
+
template <typename T, int Len>
|
|
1167
1595
|
inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
|
|
1168
1596
|
{
|
|
1169
|
-
|
|
1597
|
+
auto out = tile_register<T, Len>();
|
|
1598
|
+
|
|
1599
|
+
using Layout = typename decltype(out)::Layout;
|
|
1170
1600
|
|
|
1171
1601
|
WP_PRAGMA_UNROLL
|
|
1172
|
-
for (int i=0; i <
|
|
1602
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1173
1603
|
{
|
|
1174
|
-
const int linear =
|
|
1604
|
+
const int linear = Layout::linear_from_register(i);
|
|
1175
1605
|
|
|
1176
1606
|
// handle case where tile size is not
|
|
1177
1607
|
// aligned to block dimensions
|
|
1178
|
-
if (!
|
|
1608
|
+
if (!Layout::valid(linear))
|
|
1179
1609
|
break;
|
|
1180
1610
|
|
|
1181
1611
|
out.data[i] = start + linear*step;
|
|
@@ -1188,191 +1618,106 @@ template <typename T, typename AdjTile>
|
|
|
1188
1618
|
inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
|
|
1189
1619
|
T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
|
|
1190
1620
|
|
|
1191
|
-
// entry point for
|
|
1192
|
-
template <typename
|
|
1193
|
-
inline CUDA_CALLABLE auto tile_load(array_t<T>& src,
|
|
1621
|
+
// entry point for load operations, these just return a reference to a global memory array + coordinate
|
|
1622
|
+
template <unsigned... Shape, typename... Indices, typename T>
|
|
1623
|
+
inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
|
|
1194
1624
|
{
|
|
1195
|
-
return tile_global_t<T
|
|
1625
|
+
return tile_global_t<T, tile_shape_t<Shape...>>(src, tile_coord(offset...));
|
|
1196
1626
|
}
|
|
1197
1627
|
|
|
1198
|
-
// entry point for
|
|
1199
|
-
template <typename
|
|
1200
|
-
inline CUDA_CALLABLE
|
|
1201
|
-
{
|
|
1202
|
-
|
|
1203
|
-
}
|
|
1628
|
+
// // entry point for tile store operations
|
|
1629
|
+
// template <typename... Indices, typename T, typename Tile>
|
|
1630
|
+
// inline CUDA_CALLABLE void tile_store(array_t<T>& dest, Tile& src, Indices... x)
|
|
1631
|
+
// {
|
|
1632
|
+
// src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x)));
|
|
1633
|
+
// }
|
|
1204
1634
|
|
|
1205
|
-
// entry point for
|
|
1635
|
+
// entry point for tile store operations
|
|
1206
1636
|
template <typename T, typename Tile>
|
|
1207
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src)
|
|
1208
|
-
{
|
|
1209
|
-
// dispatch to tile type
|
|
1210
|
-
src.copy_to_global(dest, x);
|
|
1211
|
-
}
|
|
1212
|
-
|
|
1213
|
-
// entry point for 2d store
|
|
1637
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
|
|
1214
1638
|
template <typename T, typename Tile>
|
|
1215
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src)
|
|
1216
|
-
{
|
|
1217
|
-
// dispatch to tile type
|
|
1218
|
-
src.copy_to_global(dest, x, y);
|
|
1219
|
-
}
|
|
1220
|
-
|
|
1639
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
|
|
1221
1640
|
template <typename T, typename Tile>
|
|
1222
|
-
inline CUDA_CALLABLE
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
const int tile_i = x*src_reg.M;
|
|
1227
|
-
const int tile_j = y*src_reg.N;
|
|
1228
|
-
|
|
1229
|
-
tile_register_t<T, src_reg.M, src_reg.N> previous;
|
|
1230
|
-
|
|
1231
|
-
WP_PRAGMA_UNROLL
|
|
1232
|
-
for (int i=0; i < src_reg.NumRegs; ++i)
|
|
1233
|
-
{
|
|
1234
|
-
// handle case where tile size is not
|
|
1235
|
-
// aligned to block dimensions
|
|
1236
|
-
int linear = src_reg.index(i);
|
|
1237
|
-
if (!src_reg.Aligned && linear >= src_reg.Size)
|
|
1238
|
-
break;
|
|
1641
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
|
|
1642
|
+
template <typename T, typename Tile>
|
|
1643
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
|
|
1239
1644
|
|
|
1240
|
-
coord_t c = src_reg.coord(linear);
|
|
1241
|
-
previous.data[i] = atomic_add(dest, tile_i + c.i, tile_j + c.j, src_reg.data[i]);
|
|
1242
|
-
}
|
|
1243
1645
|
|
|
1244
|
-
return previous;
|
|
1245
|
-
}
|
|
1246
1646
|
|
|
1647
|
+
template <typename T, typename Tile>
|
|
1648
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
|
|
1649
|
+
template <typename T, typename Tile>
|
|
1650
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y)));}
|
|
1651
|
+
template <typename T, typename Tile>
|
|
1652
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z)));}
|
|
1653
|
+
template <typename T, typename Tile>
|
|
1654
|
+
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w)));}
|
|
1247
1655
|
|
|
1248
1656
|
|
|
1249
1657
|
//-------------------------------------
|
|
1250
1658
|
// Adjoints
|
|
1251
1659
|
|
|
1252
|
-
template <typename T, typename AdjTile>
|
|
1253
|
-
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src,
|
|
1254
|
-
array_t<T>& adj_src,
|
|
1660
|
+
template <typename T, typename AdjTile, typename Coord>
|
|
1661
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
|
|
1662
|
+
array_t<T>& adj_src, Coord adj_c,
|
|
1255
1663
|
AdjTile& adj_ret)
|
|
1256
1664
|
{
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
//
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
const int tile_i = x*adj_reg.N;
|
|
1264
|
-
|
|
1265
|
-
// add gradients to src array
|
|
1266
|
-
WP_PRAGMA_UNROLL
|
|
1267
|
-
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1268
|
-
{
|
|
1269
|
-
int linear = adj_reg.index(i);
|
|
1270
|
-
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1271
|
-
break;
|
|
1272
|
-
|
|
1273
|
-
auto grad = adj_reg.data[i];
|
|
1665
|
+
tile_global_t<T, typename AdjTile::Layout::Shape> dest(src, c);
|
|
1666
|
+
|
|
1667
|
+
// we allow users to override grad of src
|
|
1668
|
+
if (adj_src.data)
|
|
1669
|
+
dest.data.grad = adj_src.data;
|
|
1274
1670
|
|
|
1275
|
-
|
|
1276
|
-
adj_atomic_add(&index(adj_src, tile_i + linear), grad);
|
|
1277
|
-
else if (src.grad)
|
|
1278
|
-
adj_atomic_add(&index_grad(src, tile_i + linear), grad);
|
|
1279
|
-
}
|
|
1671
|
+
adj_ret.atomic_add_grad(dest);
|
|
1280
1672
|
}
|
|
1281
1673
|
|
|
1282
|
-
template <typename T, typename AdjTile>
|
|
1283
|
-
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y,
|
|
1284
|
-
array_t<T>& adj_src, int adj_x, int adj_y,
|
|
1285
|
-
AdjTile& adj_ret)
|
|
1286
|
-
{
|
|
1287
|
-
// early out
|
|
1288
|
-
// if (!src.grad)
|
|
1289
|
-
// return;
|
|
1290
|
-
|
|
1291
|
-
auto adj_reg = adj_ret.grad_to_register();
|
|
1292
|
-
|
|
1293
|
-
const int tile_i = x*adj_reg.M;
|
|
1294
|
-
const int tile_j = y*adj_reg.N;
|
|
1295
|
-
|
|
1296
|
-
// add gradients to src array
|
|
1297
|
-
WP_PRAGMA_UNROLL
|
|
1298
|
-
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1299
|
-
{
|
|
1300
|
-
int linear = adj_reg.index(i);
|
|
1301
|
-
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1302
|
-
break;
|
|
1303
|
-
|
|
1304
|
-
coord_t coord = adj_reg.coord(linear);
|
|
1305
1674
|
|
|
1306
|
-
|
|
1675
|
+
template <typename T, typename AdjTile>
|
|
1676
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
|
|
1677
|
+
template <typename T, typename AdjTile>
|
|
1678
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, array_t<T>& adj_src, int adj_x, int adj_y, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y), adj_src, tile_coord(0,0), adj_ret); }
|
|
1679
|
+
template <typename T, typename AdjTile>
|
|
1680
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z), adj_src, tile_coord(0,0,0), adj_ret); }
|
|
1681
|
+
template <typename T, typename AdjTile>
|
|
1682
|
+
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
|
|
1307
1683
|
|
|
1308
|
-
if (adj_src.data)
|
|
1309
|
-
adj_atomic_add(&index(adj_src, tile_i + coord.i, tile_j + coord.j), grad);
|
|
1310
|
-
else if (src.grad)
|
|
1311
|
-
adj_atomic_add(&index_grad(src, tile_i + coord.i, tile_j + coord.j), grad);
|
|
1312
|
-
}
|
|
1313
|
-
}
|
|
1314
1684
|
|
|
1315
1685
|
|
|
1316
|
-
template <typename T, typename Tile, typename AdjTile>
|
|
1317
|
-
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest,
|
|
1686
|
+
template <typename T, typename Tile, typename AdjTile, typename Coord>
|
|
1687
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
|
|
1318
1688
|
{
|
|
1319
|
-
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
// load gradients from output
|
|
1325
|
-
WP_PRAGMA_UNROLL
|
|
1326
|
-
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1327
|
-
{
|
|
1328
|
-
int linear = adj_reg.index(i);
|
|
1329
|
-
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1330
|
-
break;
|
|
1689
|
+
tile_global_t<T, typename AdjTile::Layout::Shape> src(dest, c);
|
|
1690
|
+
|
|
1691
|
+
// we allow users to override grad of src
|
|
1692
|
+
if (adj_dest.data)
|
|
1693
|
+
src.data.grad = adj_dest.data;
|
|
1331
1694
|
|
|
1332
|
-
|
|
1333
|
-
|
|
1334
|
-
else if (dest.grad)
|
|
1335
|
-
adj_reg.data[i] = index_grad(dest, tile_i + linear);
|
|
1336
|
-
}
|
|
1695
|
+
if (src.data.grad == NULL)
|
|
1696
|
+
return;
|
|
1337
1697
|
|
|
1338
|
-
|
|
1339
|
-
adj_t.grad_add(adj_reg);
|
|
1698
|
+
adj_t.grad_add(src);
|
|
1340
1699
|
}
|
|
1341
1700
|
|
|
1342
1701
|
template <typename T, typename Tile, typename AdjTile>
|
|
1343
|
-
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x,
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
// load gradients from output
|
|
1352
|
-
WP_PRAGMA_UNROLL
|
|
1353
|
-
for (int i=0; i < adj_reg.NumRegs; ++i)
|
|
1354
|
-
{
|
|
1355
|
-
int linear = adj_reg.index(i);
|
|
1356
|
-
if (!adj_reg.Aligned && linear >= adj_reg.Size)
|
|
1357
|
-
break;
|
|
1358
|
-
|
|
1359
|
-
coord_t coord = adj_reg.coord(linear);
|
|
1702
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(0), adj_t); }
|
|
1703
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1704
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(0,0), adj_t); }
|
|
1705
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1706
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(0,0,0), adj_t); }
|
|
1707
|
+
template <typename T, typename Tile, typename AdjTile>
|
|
1708
|
+
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
|
|
1360
1709
|
|
|
1361
|
-
if (adj_dest.data)
|
|
1362
|
-
adj_reg.data[i] = index(adj_dest, tile_i + coord.i, tile_j + coord.j);
|
|
1363
|
-
else if (dest.grad)
|
|
1364
|
-
adj_reg.data[i] = index_grad(dest, tile_i + coord.i, tile_j + coord.j);
|
|
1365
|
-
}
|
|
1366
1710
|
|
|
1367
|
-
// store adjoint back to tile
|
|
1368
|
-
adj_t.grad_add(adj_reg);
|
|
1369
|
-
}
|
|
1370
1711
|
|
|
1712
|
+
// adj_tile_atomic_add is an alias for adj_tile_store
|
|
1371
1713
|
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1372
|
-
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x,
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1714
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(adj_x), adj_t); }
|
|
1715
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1716
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(adj_x, adj_y), adj_t); }
|
|
1717
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1718
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(adj_x, adj_y, adj_z), adj_t); }
|
|
1719
|
+
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1720
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
|
|
1376
1721
|
|
|
1377
1722
|
|
|
1378
1723
|
// unary map
|
|
@@ -1380,11 +1725,13 @@ template <typename Tile, typename Fwd>
|
|
|
1380
1725
|
inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
1381
1726
|
Tile &a)
|
|
1382
1727
|
{
|
|
1383
|
-
auto out =
|
|
1728
|
+
auto out = tile_register_like<Tile>();
|
|
1384
1729
|
auto a_reg = a.copy_to_register();
|
|
1730
|
+
|
|
1731
|
+
using Layout = typename decltype(out)::Layout;
|
|
1385
1732
|
|
|
1386
1733
|
WP_PRAGMA_UNROLL
|
|
1387
|
-
for (int i=0; i <
|
|
1734
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1388
1735
|
{
|
|
1389
1736
|
out.data[i] = op(a_reg.data[i]);
|
|
1390
1737
|
}
|
|
@@ -1404,8 +1751,10 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
|
1404
1751
|
auto adj_a_reg = tile_register_like<Tile>();
|
|
1405
1752
|
auto adj_ret_reg = adj_ret.grad_to_register();
|
|
1406
1753
|
|
|
1754
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
1755
|
+
|
|
1407
1756
|
WP_PRAGMA_UNROLL
|
|
1408
|
-
for (int i=0; i <
|
|
1757
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1409
1758
|
{
|
|
1410
1759
|
adj_op(a_reg.data[i], adj_a_reg.data[i], adj_ret_reg.data[i]);
|
|
1411
1760
|
}
|
|
@@ -1420,14 +1769,18 @@ inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
|
1420
1769
|
TileA& a,
|
|
1421
1770
|
TileB& b)
|
|
1422
1771
|
{
|
|
1423
|
-
auto out =
|
|
1772
|
+
auto out = tile_register_like<TileA>();
|
|
1424
1773
|
|
|
1425
1774
|
auto a_reg = a.copy_to_register();
|
|
1426
1775
|
auto b_reg = b.copy_to_register();
|
|
1427
1776
|
|
|
1777
|
+
using Layout = typename decltype(out)::Layout;
|
|
1778
|
+
|
|
1428
1779
|
WP_PRAGMA_UNROLL
|
|
1429
|
-
for (int i=0; i <
|
|
1780
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1781
|
+
{
|
|
1430
1782
|
out.data[i] = op(a_reg.data[i], b_reg.data[i]);
|
|
1783
|
+
}
|
|
1431
1784
|
|
|
1432
1785
|
return out;
|
|
1433
1786
|
}
|
|
@@ -1451,8 +1804,10 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
|
1451
1804
|
|
|
1452
1805
|
auto adj_ret_reg = adj_ret.grad_to_register();
|
|
1453
1806
|
|
|
1807
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
1808
|
+
|
|
1454
1809
|
WP_PRAGMA_UNROLL
|
|
1455
|
-
for (int i=0; i <
|
|
1810
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1456
1811
|
{
|
|
1457
1812
|
adj_op(a_reg.data[i], b_reg.data[i], adj_a_reg.data[i], adj_b_reg.data[i], adj_ret_reg.data[i]);
|
|
1458
1813
|
}
|
|
@@ -1485,26 +1840,32 @@ inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
|
|
|
1485
1840
|
return tile_binary_map(add, a, b);
|
|
1486
1841
|
}
|
|
1487
1842
|
|
|
1488
|
-
// // tile + tile, we implement this
|
|
1489
|
-
// template <typename TileA, typename TileB>
|
|
1490
|
-
// inline CUDA_CALLABLE auto add(TileA& a, TileB& b)
|
|
1491
|
-
// {
|
|
1492
|
-
// return tile_binary_map(add, a, b);
|
|
1493
|
-
// }
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
1843
|
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1497
1844
|
inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1498
1845
|
{
|
|
1499
1846
|
adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
|
|
1500
1847
|
}
|
|
1501
1848
|
|
|
1849
|
+
// tile - tile
|
|
1850
|
+
template <typename TileA, typename TileB>
|
|
1851
|
+
inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
|
|
1852
|
+
{
|
|
1853
|
+
return tile_binary_map(sub, a, b);
|
|
1854
|
+
}
|
|
1855
|
+
|
|
1856
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1857
|
+
inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1858
|
+
{
|
|
1859
|
+
adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
|
|
1860
|
+
}
|
|
1861
|
+
|
|
1862
|
+
|
|
1502
1863
|
// tile*scalar
|
|
1503
1864
|
template <typename Tile>
|
|
1504
1865
|
inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
|
|
1505
1866
|
{
|
|
1506
1867
|
// promote scalar to a constant tile
|
|
1507
|
-
auto s_tile = tile_register_t<typename Tile::Type,
|
|
1868
|
+
auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
|
|
1508
1869
|
|
|
1509
1870
|
return tile_binary_map(mul, a, s_tile);
|
|
1510
1871
|
}
|
|
@@ -1514,12 +1875,17 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
|
|
|
1514
1875
|
Tile& adj_a, typename Tile::Type& adj_s,
|
|
1515
1876
|
AdjTile& adj_c)
|
|
1516
1877
|
{
|
|
1517
|
-
auto s_tile =
|
|
1518
|
-
auto adj_s_tile =
|
|
1878
|
+
auto s_tile = tile_register_like<Tile>();
|
|
1879
|
+
auto adj_s_tile = tile_register_like<Tile>();
|
|
1880
|
+
|
|
1881
|
+
using Layout = typename decltype(adj_s_tile)::Layout;
|
|
1882
|
+
|
|
1883
|
+
// initialize to constant
|
|
1884
|
+
s_tile = s;
|
|
1519
1885
|
|
|
1520
1886
|
adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
|
|
1521
1887
|
|
|
1522
|
-
for (int i=0; i <
|
|
1888
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
1523
1889
|
{
|
|
1524
1890
|
adj_s += adj_s_tile.data[i];
|
|
1525
1891
|
}
|
|
@@ -1530,10 +1896,7 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
|
|
|
1530
1896
|
template <typename Tile>
|
|
1531
1897
|
inline CUDA_CALLABLE auto tile_mul(const typename Tile::Type& s, Tile& a)
|
|
1532
1898
|
{
|
|
1533
|
-
|
|
1534
|
-
auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
|
|
1535
|
-
|
|
1536
|
-
return tile_binary_map(mul, s_tile, a);
|
|
1899
|
+
return tile_mul(a, s);
|
|
1537
1900
|
}
|
|
1538
1901
|
|
|
1539
1902
|
template <typename Tile, typename AdjTile>
|
|
@@ -1541,36 +1904,30 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
|
|
|
1541
1904
|
typename Tile::Type& adj_s, Tile& adj_a,
|
|
1542
1905
|
AdjTile& adj_c)
|
|
1543
1906
|
{
|
|
1544
|
-
|
|
1545
|
-
auto adj_s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
|
|
1546
|
-
|
|
1547
|
-
adj_tile_binary_map(mul, s_tile, a, adj_mul, adj_s_tile, adj_a, adj_c);
|
|
1548
|
-
|
|
1549
|
-
for (int i=0; i < adj_s_tile.NumRegs; ++i)
|
|
1550
|
-
{
|
|
1551
|
-
adj_s += adj_s_tile.data[i];
|
|
1552
|
-
}
|
|
1907
|
+
adj_tile_mul(a, s, adj_a, adj_s, adj_c);
|
|
1553
1908
|
}
|
|
1554
1909
|
|
|
1555
1910
|
|
|
1556
|
-
|
|
1557
1911
|
template<typename Tile>
|
|
1558
|
-
typename Tile::Type tile_extract(Tile& t, int i
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1912
|
+
typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
|
|
1913
|
+
template<typename Tile>
|
|
1914
|
+
typename Tile::Type tile_extract(Tile& t, int i, int j) { return t.extract(tile_coord(i,j)); }
|
|
1915
|
+
template<typename Tile>
|
|
1916
|
+
typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extract(tile_coord(i,j,k)); }
|
|
1917
|
+
template<typename Tile>
|
|
1918
|
+
typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
|
|
1562
1919
|
|
|
1563
|
-
return t.extract(i, j);
|
|
1564
|
-
}
|
|
1565
1920
|
|
|
1566
1921
|
template<typename Tile, typename AdjTile>
|
|
1567
|
-
void adj_tile_extract(Tile& t, int i,
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1922
|
+
void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
|
|
1923
|
+
template<typename Tile, typename AdjTile>
|
|
1924
|
+
void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j), adj_ret); }
|
|
1925
|
+
template<typename Tile, typename AdjTile>
|
|
1926
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k), adj_ret); }
|
|
1927
|
+
template<typename Tile, typename AdjTile>
|
|
1928
|
+
void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
|
|
1571
1929
|
|
|
1572
|
-
|
|
1573
|
-
}
|
|
1930
|
+
#if WP_USE_REGISTER_GEMM
|
|
1574
1931
|
|
|
1575
1932
|
namespace partitioned_gemm
|
|
1576
1933
|
{
|
|
@@ -1592,7 +1949,7 @@ struct partition_t
|
|
|
1592
1949
|
{
|
|
1593
1950
|
static constexpr int M = PartitionM;
|
|
1594
1951
|
static constexpr int N = PartitionN;
|
|
1595
|
-
static constexpr int Stride = Tile::
|
|
1952
|
+
static constexpr int Stride = Tile::Layout::Shape::dim(1);
|
|
1596
1953
|
|
|
1597
1954
|
using T = typename Tile::Type;
|
|
1598
1955
|
|
|
@@ -1601,8 +1958,8 @@ struct partition_t
|
|
|
1601
1958
|
data = A.data.ptr;
|
|
1602
1959
|
|
|
1603
1960
|
// todo: do ceil div for non-multiples of M,N
|
|
1604
|
-
shape[0] = Tile::
|
|
1605
|
-
shape[1] = Tile::
|
|
1961
|
+
shape[0] = Tile::Layout::Shape::dim(0)/PartitionM;
|
|
1962
|
+
shape[1] = Tile::Layout::Shape::dim(1)/PartitionN;
|
|
1606
1963
|
}
|
|
1607
1964
|
|
|
1608
1965
|
// underlying data
|
|
@@ -1640,7 +1997,7 @@ inline auto partition_load(const Partition& tile, int i, int j)
|
|
|
1640
1997
|
WP_PRAGMA_UNROLL
|
|
1641
1998
|
for (int j=0; j < Partition::N; ++j)
|
|
1642
1999
|
{
|
|
1643
|
-
out.data[i][j] = index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
|
|
2000
|
+
out.data[i][j] = partitioned_gemm::index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
|
|
1644
2001
|
}
|
|
1645
2002
|
}
|
|
1646
2003
|
|
|
@@ -1664,6 +2021,7 @@ inline void partition_store(const Partition& tile, int i, int j, const Value& va
|
|
|
1664
2021
|
}
|
|
1665
2022
|
}
|
|
1666
2023
|
|
|
2024
|
+
|
|
1667
2025
|
template <typename TileA, typename TileB, typename TileC>
|
|
1668
2026
|
inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
|
|
1669
2027
|
{
|
|
@@ -1700,15 +2058,26 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
|
|
|
1700
2058
|
|
|
1701
2059
|
} // namespace partition_gemm
|
|
1702
2060
|
|
|
2061
|
+
#endif // WP_USE_REGISTER_GEMM
|
|
2062
|
+
|
|
2063
|
+
|
|
1703
2064
|
template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
|
|
1704
2065
|
TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C)
|
|
1705
2066
|
{
|
|
1706
|
-
using
|
|
2067
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2068
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2069
|
+
using ShapeC = typename TileC::Layout::Shape;
|
|
1707
2070
|
|
|
1708
|
-
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
2071
|
+
static_assert(ShapeA::N == 2);
|
|
2072
|
+
static_assert(ShapeB::N == 2);
|
|
2073
|
+
static_assert(ShapeC::N == 2);
|
|
2074
|
+
|
|
2075
|
+
static_assert(ShapeA::dim(1) == ShapeB::dim(0));
|
|
2076
|
+
static_assert(ShapeC::dim(0) == ShapeA::dim(0));
|
|
2077
|
+
static_assert(ShapeC::dim(1) == ShapeB::dim(1));
|
|
2078
|
+
|
|
2079
|
+
|
|
2080
|
+
using T = typename TileA::Type;
|
|
1712
2081
|
|
|
1713
2082
|
#if WP_USE_REGISTER_GEMM
|
|
1714
2083
|
partitioned_gemm::matmul(A, B, C);
|
|
@@ -1746,11 +2115,11 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
1746
2115
|
}
|
|
1747
2116
|
|
|
1748
2117
|
// TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
|
|
1749
|
-
//
|
|
2118
|
+
// and remove the need for __align__(16) dtypes data[...]
|
|
1750
2119
|
#define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
|
|
1751
2120
|
do { \
|
|
1752
2121
|
void function_name(dtype*, dtype*); \
|
|
1753
|
-
|
|
2122
|
+
char* buffer = (char*)wp::tile_alloc_shared(shared_memory_size); \
|
|
1754
2123
|
__align__(16) dtype data[ept]; \
|
|
1755
2124
|
for(int b = 0; b < (int)batch_size; b++) { \
|
|
1756
2125
|
dtype* inout = Xinout.data + (int)b * (int)ept; \
|
|
@@ -1759,6 +2128,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
1759
2128
|
memcpy(inout, data, sizeof(dtype) * ept); \
|
|
1760
2129
|
WP_TILE_SYNC(); \
|
|
1761
2130
|
} \
|
|
2131
|
+
wp::tile_alloc_shared(-shared_memory_size); \
|
|
1762
2132
|
} while (0)
|
|
1763
2133
|
|
|
1764
2134
|
#define tile_ifft tile_fft
|
|
@@ -1779,12 +2149,78 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
1779
2149
|
tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
|
|
1780
2150
|
} while (0)
|
|
1781
2151
|
|
|
2152
|
+
template <typename Fwd, typename TileA, typename TileL>
|
|
2153
|
+
TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
2154
|
+
{
|
|
2155
|
+
// Copy to L
|
|
2156
|
+
L = A;
|
|
2157
|
+
|
|
2158
|
+
// Call cholesky on L
|
|
2159
|
+
WP_TILE_SYNC();
|
|
2160
|
+
|
|
2161
|
+
fun_forward(L.data.ptr, TileL::Layout::Shape::dim(0));
|
|
2162
|
+
|
|
2163
|
+
WP_TILE_SYNC();
|
|
2164
|
+
|
|
2165
|
+
// Zero-out the upper triangular part of L
|
|
2166
|
+
|
|
2167
|
+
WP_PRAGMA_UNROLL
|
|
2168
|
+
for (int i=threadIdx.x; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
2169
|
+
{
|
|
2170
|
+
auto c = TileL::Layout::coord_from_linear(i);
|
|
2171
|
+
|
|
2172
|
+
if(c[0] < c[1])
|
|
2173
|
+
L.data(c) = 0.0;
|
|
2174
|
+
}
|
|
2175
|
+
|
|
2176
|
+
WP_TILE_SYNC();
|
|
2177
|
+
|
|
2178
|
+
return L;
|
|
2179
|
+
}
|
|
2180
|
+
|
|
2181
|
+
#define adj_tile_cholesky(function_name, A, L, \
|
|
2182
|
+
adj_function_name, adj_A, adj_L, adj_ret) \
|
|
2183
|
+
do { \
|
|
2184
|
+
assert(false); \
|
|
2185
|
+
} while (0)
|
|
2186
|
+
|
|
2187
|
+
template <typename Fwd, typename TileL, typename TileX, typename TileY>
|
|
2188
|
+
TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
|
|
2189
|
+
{
|
|
2190
|
+
// Copy x to y
|
|
2191
|
+
|
|
2192
|
+
Y = X;
|
|
2193
|
+
|
|
2194
|
+
// Call cholesky solve on L & y
|
|
2195
|
+
|
|
2196
|
+
WP_TILE_SYNC();
|
|
2197
|
+
|
|
2198
|
+
fun_forward(L.data.ptr, Y.data.ptr); \
|
|
2199
|
+
|
|
2200
|
+
WP_TILE_SYNC();
|
|
2201
|
+
|
|
2202
|
+
return Y;
|
|
2203
|
+
}
|
|
2204
|
+
|
|
2205
|
+
#define adj_tile_cholesky_solve(function_name, L, X, Y, \
|
|
2206
|
+
adj_function_name, adj_L, adj_X, adj_Y, adj_ret) \
|
|
2207
|
+
do { \
|
|
2208
|
+
assert(false); \
|
|
2209
|
+
} while (0)
|
|
1782
2210
|
|
|
1783
2211
|
template <typename Tile>
|
|
1784
2212
|
inline CUDA_CALLABLE auto tile_transpose(Tile& t)
|
|
1785
2213
|
{
|
|
2214
|
+
static_assert(Tile::Layout::Shape::N == 2);
|
|
2215
|
+
|
|
1786
2216
|
// alias incoming tile
|
|
1787
|
-
|
|
2217
|
+
constexpr int M = Tile::Layout::Shape::dim(0);
|
|
2218
|
+
constexpr int N = Tile::Layout::Shape::dim(1);
|
|
2219
|
+
|
|
2220
|
+
constexpr int StrideM = Tile::Layout::Stride::dim(0);
|
|
2221
|
+
constexpr int StrideN = Tile::Layout::Stride::dim(1);
|
|
2222
|
+
|
|
2223
|
+
return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N,M>, tile_stride_t<StrideN, StrideM>>, false>(t.data.ptr, t.grad.ptr);
|
|
1788
2224
|
}
|
|
1789
2225
|
|
|
1790
2226
|
template <typename Tile, typename AdjTile>
|
|
@@ -1800,55 +2236,144 @@ template <int M, int N, int StrideM, int StrideN, typename Tile>
|
|
|
1800
2236
|
inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
{
    // The result reuses the data/grad pointers of `t`; only the layout
    // changes, to shape (M, N) with strides (StrideM, StrideN).
    using BroadcastLayout = tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>;
    using BroadcastTile = tile_shared_t<typename Tile::Type, BroadcastLayout, false>;

    return BroadcastTile(t.data.ptr, t.grad.ptr);
}
|
|
1805
2241
|
|
|
1806
2242
|
template <typename Tile, typename AdjTile>
|
|
1807
2243
|
inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
1808
2244
|
{
|
|
1809
2245
|
// nop, since memory is aliased grads already accumulated
|
|
2246
|
+
}
|
|
2247
|
+
|
|
2248
|
+
template <typename ReturnType, typename Tile, typename... Indices>
|
|
2249
|
+
inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
|
|
2250
|
+
{
|
|
2251
|
+
auto c = tile_coord(indices...);
|
|
2252
|
+
|
|
2253
|
+
// return new tile with same strides
|
|
2254
|
+
typename Tile::Type* data_ptr = &t.data(c);
|
|
2255
|
+
typename Tile::Type* grad_ptr = NULL;
|
|
2256
|
+
|
|
2257
|
+
if (t.grad.ptr)
|
|
2258
|
+
grad_ptr = &t.grad(c);
|
|
1810
2259
|
|
|
2260
|
+
return ReturnType(data_ptr, grad_ptr);
|
|
1811
2261
|
}
|
|
1812
2262
|
|
|
1813
|
-
|
|
1814
|
-
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
2263
|
+
|
|
2264
|
+
template <typename TileA, typename Scalar>
|
|
2265
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
|
|
2266
|
+
{
|
|
2267
|
+
dest.data(tile_coord(i)) = src;
|
|
2268
|
+
WP_TILE_SYNC();
|
|
1818
2269
|
}
|
|
1819
2270
|
|
|
1820
|
-
template <typename
|
|
1821
|
-
inline CUDA_CALLABLE void
|
|
2271
|
+
template <typename TileA, typename Scalar>
|
|
2272
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
|
|
1822
2273
|
{
|
|
1823
|
-
|
|
2274
|
+
dest.data(tile_coord(i, j)) = src;
|
|
2275
|
+
WP_TILE_SYNC();
|
|
2276
|
+
}
|
|
1824
2277
|
|
|
2278
|
+
template <typename TileA, typename Scalar>
|
|
2279
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
|
|
2280
|
+
{
|
|
2281
|
+
dest.data(tile_coord(i, j, k)) = src;
|
|
2282
|
+
WP_TILE_SYNC();
|
|
1825
2283
|
}
|
|
1826
2284
|
|
|
1827
|
-
template <typename TileA, typename
|
|
1828
|
-
inline CUDA_CALLABLE void
|
|
2285
|
+
template <typename TileA, typename Scalar>
|
|
2286
|
+
inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
|
|
2287
|
+
{
|
|
2288
|
+
dest.data(tile_coord(i, j, k, l)) = src;
|
|
2289
|
+
WP_TILE_SYNC();
|
|
2290
|
+
}
|
|
2291
|
+
|
|
2292
|
+
|
|
2293
|
+
|
|
2294
|
+
|
|
2295
|
+
template <typename TileA, typename TileB, typename Coord>
|
|
2296
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offset)
|
|
1829
2297
|
{
|
|
1830
|
-
|
|
2298
|
+
using Layout = typename TileB::Layout;
|
|
2299
|
+
|
|
2300
|
+
for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
|
|
1831
2301
|
{
|
|
1832
|
-
|
|
1833
|
-
dest.data(
|
|
2302
|
+
auto c = Layout::coord_from_linear(t);
|
|
2303
|
+
dest.data(c + offset) = src.data(c);
|
|
1834
2304
|
}
|
|
1835
2305
|
|
|
1836
2306
|
WP_TILE_SYNC();
|
|
1837
2307
|
}
|
|
1838
2308
|
|
|
1839
|
-
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
1840
|
-
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest,
|
|
1841
|
-
AdjTileA& adj_dest,
|
|
2309
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename Coord, typename AdjCoord>
|
|
2310
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
|
|
2311
|
+
AdjTileA& adj_dest, AdjTileB& adj_src, AdjCoord adj_offset)
|
|
1842
2312
|
{
|
|
1843
|
-
|
|
2313
|
+
using Layout = typename TileB::Layout;
|
|
2314
|
+
|
|
2315
|
+
for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
|
|
1844
2316
|
{
|
|
1845
|
-
|
|
1846
|
-
src.grad(c
|
|
2317
|
+
auto c = Layout::coord_from_linear(t);
|
|
2318
|
+
src.grad(c) += dest.grad(c + offset);
|
|
1847
2319
|
}
|
|
1848
2320
|
|
|
1849
2321
|
WP_TILE_SYNC();
|
|
1850
2322
|
}
|
|
1851
2323
|
|
|
1852
2324
|
|
|
2325
|
+
// codegen entry points, which emit calls like `tile_assign(dest, src, i, j, k)`
|
|
2326
|
+
// a better approach here would be for codegen to just directly generate `tile_assign(dest, src, tile_coord(i, j, k))`
|
|
2327
|
+
// i.e.: call the above implementation methods directly, then we could remove these overloads
|
|
2328
|
+
template <typename TileA, typename TileB>
|
|
2329
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i) { tile_assign(dest, src, tile_coord(i)); }
|
|
2330
|
+
template <typename TileA, typename TileB>
|
|
2331
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j) { tile_assign(dest, src, tile_coord(i, j)); }
|
|
2332
|
+
template <typename TileA, typename TileB>
|
|
2333
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k) { tile_assign(dest, src, tile_coord(i, j, k)); }
|
|
2334
|
+
template <typename TileA, typename TileB>
|
|
2335
|
+
inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l) { tile_assign(dest, src, tile_coord(i, j, k, l)); }
|
|
2336
|
+
|
|
2337
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2338
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, AdjTileA& adj_dest, AdjTileB& adj_src, int) { adj_tile_assign(dest, src, tile_coord(i), adj_dest, adj_src, tile_coord(0)); }
|
|
2339
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2340
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, AdjTileA& adj_dest, AdjTileB& adj_src, int, int) { adj_tile_assign(dest, src, tile_coord(i,j), adj_dest, adj_src, tile_coord(0)); }
|
|
2341
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2342
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k), adj_dest, adj_src, tile_coord(0)); }
|
|
2343
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2344
|
+
inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k,l), adj_dest, adj_src, tile_coord(0)); }
|
|
2345
|
+
|
|
2346
|
+
|
|
2347
|
+
template <typename TileA, typename TileB, typename TileC>
|
|
2348
|
+
inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
|
|
2349
|
+
{
|
|
2350
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2351
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2352
|
+
using ShapeC = typename TileC::Layout::Shape;
|
|
2353
|
+
|
|
2354
|
+
static_assert(ShapeA::dim(0) == ShapeA::dim(1));
|
|
2355
|
+
static_assert(ShapeB::dim(0) == ShapeA::dim(0));
|
|
2356
|
+
static_assert(ShapeC::dim(0) == ShapeA::dim(0));
|
|
2357
|
+
static_assert(ShapeC::dim(0) == ShapeC::dim(1));
|
|
2358
|
+
|
|
2359
|
+
c = a;
|
|
2360
|
+
|
|
2361
|
+
for (int t=threadIdx.x; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
|
|
2362
|
+
{
|
|
2363
|
+
c.data(tile_coord(t, t)) += b.data(tile_coord(t));
|
|
2364
|
+
}
|
|
2365
|
+
|
|
2366
|
+
WP_TILE_SYNC();
|
|
2367
|
+
|
|
2368
|
+
return c;
|
|
2369
|
+
}
|
|
2370
|
+
|
|
2371
|
+
template <typename TileA, typename TileB, typename TileC, typename AdjTileA, typename AdjTileB, typename AdjTileC>
|
|
2372
|
+
inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c, AdjTileC& adj_ret)
|
|
2373
|
+
{
|
|
2374
|
+
assert(false);
|
|
2375
|
+
}
|
|
2376
|
+
|
|
1853
2377
|
|
|
1854
2378
|
} // namespace wp
|
|
2379
|
+
|