warp-lang 1.5.1__py3-none-macosx_10_13_universal2.whl → 1.6.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (131)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1077 -481
  8. warp/codegen.py +250 -122
  9. warp/config.py +65 -21
  10. warp/context.py +500 -149
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_marching_cubes.py +1 -1
  16. warp/examples/core/example_mesh.py +1 -1
  17. warp/examples/core/example_torch.py +18 -34
  18. warp/examples/core/example_wave.py +1 -1
  19. warp/examples/fem/example_apic_fluid.py +1 -0
  20. warp/examples/fem/example_mixed_elasticity.py +1 -1
  21. warp/examples/optim/example_bounce.py +1 -1
  22. warp/examples/optim/example_cloth_throw.py +1 -1
  23. warp/examples/optim/example_diffray.py +4 -15
  24. warp/examples/optim/example_drone.py +1 -1
  25. warp/examples/optim/example_softbody_properties.py +392 -0
  26. warp/examples/optim/example_trajectory.py +1 -3
  27. warp/examples/optim/example_walker.py +5 -0
  28. warp/examples/sim/example_cartpole.py +0 -2
  29. warp/examples/sim/example_cloth_self_contact.py +314 -0
  30. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  31. warp/examples/sim/example_jacobian_ik.py +0 -2
  32. warp/examples/sim/example_quadruped.py +5 -2
  33. warp/examples/tile/example_tile_cholesky.py +79 -0
  34. warp/examples/tile/example_tile_convolution.py +2 -2
  35. warp/examples/tile/example_tile_fft.py +2 -2
  36. warp/examples/tile/example_tile_filtering.py +3 -3
  37. warp/examples/tile/example_tile_matmul.py +4 -4
  38. warp/examples/tile/example_tile_mlp.py +12 -12
  39. warp/examples/tile/example_tile_nbody.py +191 -0
  40. warp/examples/tile/example_tile_walker.py +319 -0
  41. warp/math.py +147 -0
  42. warp/native/array.h +12 -0
  43. warp/native/builtin.h +0 -1
  44. warp/native/bvh.cpp +149 -70
  45. warp/native/bvh.cu +287 -68
  46. warp/native/bvh.h +195 -85
  47. warp/native/clang/clang.cpp +6 -2
  48. warp/native/crt.h +1 -0
  49. warp/native/cuda_util.cpp +35 -0
  50. warp/native/cuda_util.h +5 -0
  51. warp/native/exports.h +40 -40
  52. warp/native/intersect.h +17 -0
  53. warp/native/mat.h +57 -3
  54. warp/native/mathdx.cpp +19 -0
  55. warp/native/mesh.cpp +25 -8
  56. warp/native/mesh.cu +153 -101
  57. warp/native/mesh.h +482 -403
  58. warp/native/quat.h +40 -0
  59. warp/native/solid_angle.h +7 -0
  60. warp/native/sort.cpp +85 -0
  61. warp/native/sort.cu +34 -0
  62. warp/native/sort.h +3 -1
  63. warp/native/spatial.h +11 -0
  64. warp/native/tile.h +1189 -664
  65. warp/native/tile_reduce.h +8 -6
  66. warp/native/vec.h +41 -0
  67. warp/native/warp.cpp +8 -1
  68. warp/native/warp.cu +263 -40
  69. warp/native/warp.h +19 -5
  70. warp/optim/linear.py +22 -4
  71. warp/render/render_opengl.py +132 -59
  72. warp/render/render_usd.py +10 -2
  73. warp/sim/__init__.py +6 -1
  74. warp/sim/collide.py +289 -32
  75. warp/sim/import_urdf.py +20 -5
  76. warp/sim/integrator_euler.py +25 -7
  77. warp/sim/integrator_featherstone.py +147 -35
  78. warp/sim/integrator_vbd.py +842 -40
  79. warp/sim/model.py +173 -112
  80. warp/sim/render.py +2 -2
  81. warp/stubs.py +249 -116
  82. warp/tape.py +28 -30
  83. warp/tests/aux_test_module_unload.py +15 -0
  84. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  85. warp/tests/test_array.py +100 -0
  86. warp/tests/test_assert.py +242 -0
  87. warp/tests/test_codegen.py +14 -61
  88. warp/tests/test_collision.py +8 -8
  89. warp/tests/test_examples.py +16 -1
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_launch.py +77 -26
  94. warp/tests/test_mat.py +213 -168
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +11 -7
  97. warp/tests/test_matmul_lite.py +4 -4
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +6 -5
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_static.py +16 -0
  109. warp/tests/test_tape.py +25 -0
  110. warp/tests/test_tile.py +134 -191
  111. warp/tests/test_tile_load.py +399 -0
  112. warp/tests/test_tile_mathdx.py +61 -8
  113. warp/tests/test_tile_mlp.py +17 -17
  114. warp/tests/test_tile_reduce.py +24 -18
  115. warp/tests/test_tile_shared_memory.py +66 -17
  116. warp/tests/test_tile_view.py +165 -0
  117. warp/tests/test_torch.py +35 -0
  118. warp/tests/test_utils.py +36 -24
  119. warp/tests/test_vec.py +110 -0
  120. warp/tests/unittest_suites.py +29 -4
  121. warp/tests/unittest_utils.py +30 -11
  122. warp/thirdparty/unittest_parallel.py +5 -2
  123. warp/types.py +419 -111
  124. warp/utils.py +9 -5
  125. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
  126. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
  127. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
  128. warp/examples/benchmarks/benchmark_tile.py +0 -179
  129. warp/native/tile_gemm.h +0 -341
  130. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
  131. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED
@@ -35,10 +35,6 @@
 #endif
 
 #define WP_USE_ASYNC_PIPELINE 0
-#if WP_USE_ASYNC_PIPELINE
-#include "cuda_pipeline_primitives.h"
-#endif // WP_USE_ASYNC_PIPELINE
-
 #define WP_USE_REGISTER_GEMM 0
 
 /* Tile Expressions
@@ -171,50 +167,300 @@ struct is_same<T, T> {
 };
 
 
-template <typename Tile>
-constexpr int tile_size(Tile& t) { return Tile::M*Tile::N; }
+template <int N>
+struct tile_coord_t
+{
+    int indices[N];
+
+    CUDA_CALLABLE inline int operator[](int i) const { assert(0 <= 1 && i < N); return indices[i]; }
+    CUDA_CALLABLE inline int& operator[](int i) { assert(0 <= 1 && i < N); return indices[i]; }
+
+    CUDA_CALLABLE inline tile_coord_t<N> operator + (const tile_coord_t<N>& c) const
+    {
+        tile_coord_t<N> out;
+        for (int i=0; i < N; ++i)
+        {
+            out.indices[i] = indices[i] + c.indices[i];
+        }
+        return out;
+    }
+};
+
+// This function deduces N = sizeof...(Ints)
+template <typename... Ints>
+constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
+{
+    constexpr int N = sizeof...(Ints);
+
+    // Create the result
+    tile_coord_t<N> result{};
+
+    // Capture all arguments in a local array
+    int arr[] = { static_cast<int>(idxs)... };
+
+    // C++14 or later: 'for' is allowed in a constexpr context
+    for (int i = 0; i < N; ++i)
+    {
+        result.indices[i] = arr[i];
+    }
 
-constexpr int tile_regcount(int m, int n) {
-    return (m*n + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
+    return result;
+}
+
+// helpers to construct a coord from a set of indices
+auto tile_coord(int i)
+{
+    auto c = tile_coord_t<1>();
+    c.indices[0] = i;
+    return c;
+}
+
+auto tile_coord(int i, int j)
+{
+    auto c = tile_coord_t<2>();
+    c.indices[0] = i;
+    c.indices[1] = j;
+    return c;
 }
 
-struct coord_t
+auto tile_coord(int i, int j, int k)
+{
+    auto c = tile_coord_t<3>();
+    c.indices[0] = i;
+    c.indices[1] = j;
+    c.indices[2] = k;
+    return c;
+}
+
+auto tile_coord(int i, int j, int k, int l)
 {
-    int i;
-    int j;
+    auto c = tile_coord_t<4>();
+    c.indices[0] = i;
+    c.indices[1] = j;
+    c.indices[2] = k;
+    c.indices[3] = l;
+    return c;
+}
+
+// represents a compile time int tuple for strides/shapes/coords
+template <int... V>
+struct tile_tuple_t
+{
+    static constexpr int N = sizeof...(V);
+    static_assert(N > 0);
+
+    static constexpr int data[N] = { V... };
+
+    static constexpr int dim(int i) { assert(i < N); return data[i]; }
+    static constexpr int size()
+    {
+        int res = data[0];
+        for (int i=1; i < N; ++i)
+            res *= data[i];
+
+        return res;
+    }
 };
 
+// simple helper to compute strides from a shape up to 4d
+template <typename Shape>
+struct compute_strides;
+
+// 1D
+template <int D0>
+struct compute_strides< tile_tuple_t<D0> > { using Stride = tile_tuple_t<1>; };
+// 2D
+template <int D0, int D1>
+struct compute_strides< tile_tuple_t<D0, D1> > { using Stride = tile_tuple_t<D1, 1>; };
+// 3D
+template <int D0, int D1, int D2>
+struct compute_strides< tile_tuple_t<D0, D1, D2> > { using Stride = tile_tuple_t<(D1 * D2), D2, 1>; };
+// 4D
+template <int D0, int D1, int D2, int D3>
+struct compute_strides< tile_tuple_t<D0, D1, D2, D3> > { using Stride = tile_tuple_t<(D1 * D2 * D3), (D2 * D3), D3, 1>; };
+
+
+// alias of tuple to represent shapes
+template <int... V>
+using tile_shape_t = tile_tuple_t<V...>;
+
+// alias of tuple to represent stride
+template <int... V>
+using tile_stride_t = tile_tuple_t<V...>;
+
 
 // represents a tile stored in global memory with dynamic strides
-// only used to represent the source for tile loads to register/shared
-template <typename T>
-struct tile_global_t
+// used to represent the source and offset for tile loads to register/shared
+template <typename T, typename Shape_>
+struct tile_global_t
 {
     using Type = T;
+    using Shape = Shape_;
+    using Coord = tile_coord_t<Shape::N>;
 
     array_t<T> data;
-    int x;
-    int y;
+    Coord offset;
+
+    tile_global_t(array_t<T>& a, const Coord& c) : data(a), offset(c)
+    {
+    }
+
+    inline CUDA_CALLABLE int index_from_coord(const Coord& coord) const
+    {
+        // element index
+        int index = 0;
+
+        WP_PRAGMA_UNROLL
+        for (int i=0; i < Shape::N; ++i)
+        {
+            // global = offset + coord
+            int c = offset[i] + coord[i];
+            index += data.strides[i]*c;
+        }
+
+        return index/sizeof(T);
+    }
+
+    inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
+    {
+        // element index
+        int index = 0;
+
+        WP_PRAGMA_UNROLL
+        for (int i=0; i < Shape::N; ++i)
+        {
+            // global = offset + coord
+            int c = offset[i] + coord[i];
+
+            // handle out of bounds case
+            if (c >= data.shape[i])
+                return false;
+            else
+                index += data.strides[i]*c;
+        }
+
+        // array strides are in bytes so we convert to elements
+        out = index / sizeof(T);
+        return true;
+    }
+
+    inline CUDA_CALLABLE T load(const Coord& coord) const
+    {
+        int i;
+        if (index(coord, i))
+            return data.data[i];
+        else
+            return T(0);
+    }
 
-    tile_global_t(array_t<T>& a, int x, int y) : data(a), x(x), y(y)
+    inline CUDA_CALLABLE T load_grad(const Coord& coord) const
     {
+        int i;
+        if (index(coord, i))
+            return data.grad[i];
+        else
+            return T(0);
+    }
+
+    inline CUDA_CALLABLE void store(const Coord& coord, const T& x) const
+    {
+        int i;
+        if (index(coord, i))
+            data.data[i] = x;
+    }
+
+    inline CUDA_CALLABLE T atomic_add(const Coord& coord, const T& value) const
+    {
+        int i;
+        if (index(coord, i))
+            return wp::atomic_add(&data.data[i], value);
+        else
+            return T(0);
+    }
+
+    inline CUDA_CALLABLE T atomic_add_grad(const Coord& coord, const T& grad) const
+    {
+        int i;
+        if (index(coord, i))
+            return wp::atomic_add(&data.grad[i], grad);
+        else
+            return T(0);
     }
 };
 
+template <typename Shape_>
+struct tile_layout_register_t
+{
+    using Shape = Shape_;
+    using Coord = tile_coord_t<Shape::N>;
+
+    static constexpr int Size = Shape::size();
+    static constexpr int NumRegs = (Size + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
+    static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
+
+    static inline CUDA_CALLABLE int linear_from_register(int reg)
+    {
+        return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
+    }
+
+    static inline CUDA_CALLABLE int linear_from_coord(Coord c)
+    {
+        int linear = 0;
+        int stride = 1;
+
+        WP_PRAGMA_UNROLL
+        for (int i=Shape::N-1; i >= 0; --i)
+        {
+            linear += c[i] * stride;
+            stride *= Shape::dim(i);
+        }
+        return linear;
+    }
+
+    static inline CUDA_CALLABLE auto coord_from_linear(int linear)
+    {
+        Coord c;
+
+        WP_PRAGMA_UNROLL
+        for (int i=Shape::N-1; i >= 0; --i)
+        {
+            c[i] = linear%Shape::dim(i);
+            linear /= Shape::dim(i);
+        }
+
+        return c;
+    }
+
+    static inline CUDA_CALLABLE int thread_from_linear(int linear)
+    {
+        const int thread = linear%WP_TILE_BLOCK_DIM;
+        return thread;
+    }
+
+    static inline CUDA_CALLABLE int register_from_linear(int linear)
+    {
+        const int reg = linear/WP_TILE_BLOCK_DIM;
+        return reg;
+    }
+
+    static inline CUDA_CALLABLE bool valid(int linear)
+    {
+        if (Aligned || linear < Size)
+            return true;
+        else
+            return false;
+    }
+
+};
+
 
 // represents a tile stored in registers across a block
-template <typename T, int M_, int N_>
+template <typename T, typename L>
 struct tile_register_t
 {
     using Type = T;
-    static constexpr int M = M_;
-    static constexpr int N = N_;
-    static constexpr int Size = M*N;
+    using Layout = L;
 
-    static constexpr int NumRegs = tile_regcount(M, N);
-
-    static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
-
-    T data[NumRegs];
+    T data[Layout::NumRegs];
 
     inline CUDA_CALLABLE tile_register_t(T value=T(0.0))
     {
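
Note: the compute_strides specializations above are plain row-major layout, where each stride is the product of all dimensions to its right; tile_global_t then scales coordinates by array_t's byte strides and divides by sizeof(T) at the end. A minimal sketch of the same rule (illustration only, not code from the package):

    // row-major strides for shape (2, 3, 4) are (3*4, 4, 1) = (12, 4, 1),
    // so element (i, j, k) has linear index i*12 + j*4 + k
    constexpr int shape[3]  = {2, 3, 4};
    constexpr int stride[3] = {shape[1]*shape[2], shape[2], 1};
    static_assert(stride[0] == 12 && stride[1] == 4 && stride[2] == 1, "row-major");

Working in byte strides is what lets tile_global_t::index() bounds-check each axis against data.shape[i] before accumulating.
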
@@ -224,52 +470,34 @@ struct tile_register_t
         // in backwards pass and letting default constructor
         // avoid initialization
 
-        for (int i=0; i < NumRegs; ++i)
+        for (int i=0; i < Layout::NumRegs; ++i)
            data[i] = value;
     }
 
-    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
+    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
     {
-        if (t.data.ndim == 1)
-            copy_from_global(t.data, t.x); // 1d load
-        else
-            copy_from_global(t.data, t.x, t.y); // 2d load
-
+        copy_from_global(t);
         return *this;
-
     }
 
     // define the += operator which is used during backward pass codegen
     // when returning a register tile from a user defined function
-    inline CUDA_CALLABLE auto& operator += (tile_register_t<T, M, N>& rhs)
+    inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
     {
-        this->grad_add(rhs);
+        grad_add(rhs);
         return *this;
     }
 
-    inline CUDA_CALLABLE T& operator()(int index)
+    inline CUDA_CALLABLE T& operator()(int reg)
     {
-        assert(index < NumRegs);
-        return data[index];
+        assert(reg < Layout::NumRegs);
+        return data[reg];
     }
 
-    inline CUDA_CALLABLE const T& operator()(int index) const
+    inline CUDA_CALLABLE const T& operator()(int reg) const
     {
-        assert(index < NumRegs);
-        return data[index];
-    }
-
-
-    // compute linear tile index from a local register index
-    inline CUDA_CALLABLE int index(int reg) const
-    {
-        return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
-    }
-
-    // compute tile coordinate from linear index
-    inline CUDA_CALLABLE coord_t coord(int index) const
-    {
-        return {index/N, index%N};
+        assert(reg < Layout::NumRegs);
+        return data[reg];
     }
 
     // Returns the number of valid registers for this tile
@@ -278,29 +506,29 @@ struct tile_register_t
     // some of the trailing registers may lie outside the valid range
     inline CUDA_CALLABLE int valid() const
     {
-        return (Size - threadIdx.x)/WP_TILE_BLOCK_DIM;
+        return (int)floor(float(Size - threadIdx.x - 1)/WP_TILE_BLOCK_DIM) + 1;
     }
 
-    inline CUDA_CALLABLE void assign(const tile_register_t<T, M, N>& tile)
+    inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
     {
-        for (int i=0; i < NumRegs; ++i)
+        for (int i=0; i < Layout::NumRegs; ++i)
            data[i] = tile.data[i];
     }
 
     inline CUDA_CALLABLE void zero()
     {
-        for (int i=0; i < NumRegs; ++i)
-            data[i] = T(0);
+        for (int i=0; i < Layout::NumRegs; ++i)
+            data[i] = T(0);
     }
 
     // extract a single tile element to a native type
-    inline CUDA_CALLABLE Type extract(int i, int j)
+    template <typename Coord>
+    inline CUDA_CALLABLE Type extract(const Coord& c)
     {
         // map from logical coords (i, j) -> (thread, reg)
-        const int linear = i*N + j;
-
-        const int thread = linear/NumRegs;
-        const int reg = linear%NumRegs;
+        const int linear = Layout::linear_from_coord(c);
+        const int thread = Layout::thread_from_linear(linear);
+        const int reg = Layout::register_from_linear(linear);
 
         WP_TILE_SHARED Type scratch;
 
@@ -320,13 +548,13 @@ struct tile_register_t
 
 
     // backward version of scalar extract
-    inline CUDA_CALLABLE void adj_extract(int i, int j, Type adj_ret)
+    template <typename Coord>
+    inline CUDA_CALLABLE void adj_extract(const Coord& c, Type adj_ret)
     {
         // map from logical coords (i, j) -> (thread, reg)
-        const int linear = i*N + j;
-
-        const int thread = linear/NumRegs;
-        const int reg = linear%NumRegs;
+        const int linear = Layout::linear_from_coord(c);
+        const int thread = Layout::thread_from_linear(linear);
+        const int reg = Layout::register_from_linear(linear);
 
         if (threadIdx.x == thread)
         {
@@ -348,6 +576,24 @@ struct tile_register_t
         return *this;
     }
 
+    // apply a lambda to all valid entries in the tile
+    // Op should be a functor that takes a register index and tile_coord_t as input
+    template <typename Op>
+    void apply(Op op)
+    {
+        WP_PRAGMA_UNROLL
+        for (int i=0; i < Layout::NumRegs; ++i)
+        {
+            int linear = Layout::linear_from_register(i);
+            if (!Layout::valid(linear))
+                break;
+
+            auto c = Layout::coord_from_linear(linear);
+            op(i, c);
+        }
+    }
+
+
     // in-place gradient zero
     inline CUDA_CALLABLE void grad_zero()
     {
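
Note: apply() is the primitive the copy/atomic methods in the next hunk are rewritten on top of: each thread walks its own register slots, maps each slot to a logical tile coordinate, and passes both to the functor, skipping slots that fall outside the tile. A minimal usage sketch, assuming a 2-d float register tile named t (names are illustrative):

    // fill each valid element with the sum of its coordinates
    t.apply([&](int reg, auto c) { t.data[reg] = float(c[0] + c[1]); });
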
@@ -355,118 +601,77 @@ struct tile_register_t
     }
 
     // accumulate gradients onto this tile
-    inline CUDA_CALLABLE void grad_add(const tile_register_t<T, M, N>& tile)
+    inline CUDA_CALLABLE void grad_add(const tile_register_t<T, Layout>& tile)
     {
-        for (int i=0; i < NumRegs; ++i)
+        for (int i=0; i < Layout::NumRegs; ++i)
            data[i] += tile.data[i];
     }
 
-    // copy shared tile to register
+    CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+    {
+        apply([&](int reg, auto c) {data[reg] = global.load_grad(c);});
+
+    }
+
     inline CUDA_CALLABLE auto& grad_to_register()
     {
+        // nop for register tiles
         return *this;
     }
 
-    void copy_to_global(array_t<T> dest, int x)
+    template <typename Global>
+    inline CUDA_CALLABLE void copy_to_global(const Global& dest)
     {
-        assert(dest.ndim == 1);
-
-        const int tile_i = x*N;
-
-        WP_PRAGMA_UNROLL
-        for (int i=0; i < NumRegs; ++i)
-        {
-            // handle case where tile size is not
-            // aligned to block dimensions
-            int linear = index(i);
-            if (!Aligned && linear >= Size)
-                break;
-
-            wp::index(dest, tile_i + linear) = data[i];
-        }
+        apply([&](int reg, auto c) { dest.store(c, data[reg]); });
     }
 
-    void copy_to_global(array_t<T> dest, int x, int y)
+    template <typename Global>
+    inline CUDA_CALLABLE void copy_from_global(const Global& src)
     {
-        assert(dest.ndim == 2);
-
-        const int tile_i = x*M;
-        const int tile_j = y*N;
-
-        // wp.array() indexing generates poor code due to char* casting
-        // here we unroll some of the ops, note this assumes byte strides are
-        // aligned to the element size
-        T* ptr = &wp::index(dest, tile_i, tile_j);
-        const int stride_i = dest.strides[0]/sizeof(T);
-        const int stride_j = dest.strides[1]/sizeof(T);
-
-        WP_PRAGMA_UNROLL
-        for (int i=0; i < NumRegs; ++i)
-        {
-            // handle case where tile size is not
-            // aligned to block dimensions
-            int linear = index(i);
-            if (!Aligned && linear >= Size)
-                break;
-
-            coord_t c = coord(linear);
-            ptr[c.i*stride_i + c.j*stride_j] = data[i];
-        }
+        apply([&](int reg, auto c) { data[reg] = src.load(c); });
     }
 
-    inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x)
+    // add a register tile to a global array
+    template <typename Global>
+    inline CUDA_CALLABLE auto atomic_add(const Global& dest)
     {
-        // todo: use async pipelines or TMA here
-        const int tile_i = x*N;
-
-        WP_PRAGMA_UNROLL
-        for (int i=0; i < NumRegs; ++i)
-        {
-            int linear = index(i);
-            if (!Aligned && linear >= Size)
-                break;
+        // allocate a tile to hold previous dest value
+        auto previous = *this;
 
-            data[i] = wp::index(src, tile_i + linear);
-        }
+        apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add(c, data[reg]); });
+        return previous;
     }
 
-    inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x, int y)
+    // add a register tile to the gradient of a global array
+    template <typename Global>
+    inline CUDA_CALLABLE auto atomic_add_grad(const Global& dest)
     {
-        // todo: use async pipelines or TMA here
-        const int tile_i = x*M;
-        const int tile_j = y*N;
-
-        // wp.array() indexing generates poor code due to char* casting
-        // here we unroll some of the ops, note this assumes array byte strides are
-        // aligned to the element size
-        const T* ptr = &wp::index(src, tile_i, tile_j);
-
-        assert(src.strides[0]%sizeof(T) == 0);
-        assert(src.strides[1]%sizeof(T) == 0);
-
-        const int stride_i = src.strides[0]/sizeof(T);
-        const int stride_j = src.strides[1]/sizeof(T);
+        // allocate a tile to hold previous dest value
+        auto previous = *this;
 
-        WP_PRAGMA_UNROLL
-        for (int i=0; i < NumRegs; ++i)
-        {
-            int linear = index(i);
-            if (!Aligned && linear >= Size)
-                break;
-
-            coord_t c = coord(linear);
-            data[i] = ptr[c.i*stride_i + c.j*stride_j];
-        }
-    }
+        apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add_grad(c, data[reg]); });
+        return previous;
+    }
 };
 
+
 // helper to allocate a register tile like another tile
+// users can either specify a template explicitly or
+// pass in another concrete instance
 template<typename Tile>
-auto tile_register_like()
+auto tile_register_like(Tile* t=NULL)
 {
     using T = typename Tile::Type;
+    using L = typename Tile::Layout;
+
+    return tile_register_t<T, tile_layout_register_t<typename L::Shape>>(T(0.0));
+}
 
-    return tile_register_t<T, Tile::M, Tile::N>(T(0.0));
+// helper to construct a register tile from a type and a list of dims
+template <typename T, int... Dims>
+auto tile_register()
+{
+    return tile_register_t<T, tile_layout_register_t<tile_shape_t<Dims...>>>();
 }
 
 inline CUDA_CALLABLE int tile_align(int num_bytes)
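
Note: with the shape now carried by the layout type, the full spelling of a register tile is tile_register_t<T, tile_layout_register_t<tile_shape_t<Dims...>>>, which the two helpers above abbreviate. A minimal usage sketch, assuming float tiles:

    auto a = tile_register<float, 16, 16>();  // 16x16 register tile, zero-initialized
    auto b = tile_register_like(&a);          // same element type and shape as a
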
@@ -474,7 +679,10 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
     // note this much match value in Python types.py
     const int alignment = 16;
 
-    return ((num_bytes + alignment - 1) / alignment) * alignment;
+    const int num_bytes_abs = num_bytes < 0 ? - num_bytes : num_bytes;
+    const int sign = num_bytes < 0 ? - 1 : 1;
+
+    return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
 }
 
 inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
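
Note: making tile_align() sign-preserving matters because the shared-memory allocator is also used for deallocation with negative byte counts (the tile_shared_t destructor below passes -Layout::Size*sizeof(T)). With alignment = 16: tile_align(20) = 32 and tile_align(-20) = -32, whereas the old expression under C++ truncating division gives ((-20 + 15)/16)*16 = 0, so the matching release would not balance the allocation.
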
@@ -502,20 +710,78 @@ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
 }
 
 
-template <typename T, int M_, int N_, int StrideM_=N_, int StrideN_=1, bool Owner_=true>
-struct tile_shared_t
+template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
+struct tile_layout_strided_t
 {
-    using Type = T;
-    static constexpr int M = M_;
-    static constexpr int N = N_;
-    static constexpr int Size = M*N;
+    using Shape = Shape_;
+    using Stride = Stride_;
+    using Coord = tile_coord_t<Shape::N>;
 
-    static constexpr int StrideM = StrideM_;
-    static constexpr int StrideN = StrideN_;
-
+    static constexpr int Size = Shape::size();
     static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
-    static constexpr bool Unique = (StrideM >= N) && (StrideN >= 1);
+
+    static inline CUDA_CALLABLE auto coord_from_linear(int linear)
+    {
+        assert(linear < Size);
+
+        Coord c;
+
+        WP_PRAGMA_UNROLL
+        for (int d=Shape::N-1; d >= 0; --d)
+        {
+            c[d] = linear%Shape::dim(d);
+            linear /= Shape::dim(d);
+        }
+
+        return c;
+    }
+
+    static inline CUDA_CALLABLE int index_from_coord(Coord c)
+    {
+        int index = 0;
+
+        WP_PRAGMA_UNROLL
+        for (int d=0; d < Shape::N; ++d)
+        {
+            assert(c[d] < Shape::dim(d));
+
+            index += c[d]*Stride::dim(d);
+        }
+
+        return index;
+    }
+
+    // checks whether a strided layout is unique, i.e.: if memory locations are only
+    // every referred to by one element in the tile, this is a basic test that only
+    // checks for broadcast dimensions, it would be possible to do the full check
+    // using sorted shape/strides in Python and add it as a template parameter to the type
+    static constexpr bool is_unique()
+    {
+        constexpr int N = Shape::N;
+
+        // check for any broadcast dimensions
+        for (int i=0; i < N; ++i)
+            if (Stride::dim(i) == 0)
+                return false;
+
+        return true;
+    }
+
+    static constexpr bool Unique = is_unique();
+
+    static inline CUDA_CALLABLE bool valid(int linear)
+    {
+        return linear < Size;
+    }
+
+};
+
+
+template <typename T, typename L, bool Owner_=true>
+struct tile_shared_t
+{
+    using Type = T;
+    using Layout = L;
     static constexpr bool Owner = Owner_;
 
     struct Storage
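
Note: the Unique flag computed here drives the atomic-versus-direct accumulation choice in grad_add() further down. A minimal sketch of the two cases, assuming the aliases defined above:

    // dense 4x8 tile: default strides (8, 1), every element owns its own location
    using Dense = tile_layout_strided_t<tile_shape_t<4, 8>>;
    static_assert(Dense::Unique, "no broadcast dims");

    // broadcast rows: stride 0 on dim 0, so all 4 rows alias one row of memory
    using Bcast = tile_layout_strided_t<tile_shape_t<4, 8>, tile_stride_t<0, 1>>;
    static_assert(!Bcast::Unique, "dim 0 is broadcast");
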
@@ -524,55 +790,60 @@ struct tile_shared_t
 
         Storage(T* p) : ptr(p) {}
 
-        inline CUDA_CALLABLE T& operator()(int i, int j)
+        inline CUDA_CALLABLE T& operator()(typename Layout::Coord c)
         {
-            assert(i < M);
-            assert(j < N);
+            assert(ptr);
 
-            return ptr[i*StrideM + j*StrideN];
+            int index = Layout::index_from_coord(c);
+            return ptr[index];
         }
 
-        inline CUDA_CALLABLE const T& operator()(int i, int j) const
-        {
-            assert(i < M);
-            assert(j < N);
+        inline CUDA_CALLABLE const T& operator()(typename Layout::Coord c) const
+        {
+            assert(ptr);
 
-            return ptr[i*StrideM + j*StrideN];
+            int index = Layout::index_from_coord(c);
+            return ptr[index];
         }
 
-        inline CUDA_CALLABLE T& operator()(int index)
+        inline CUDA_CALLABLE T& operator()(int linear)
         {
-            assert(index < M*N);
-
-            // unravel
-            int i = index/N;
-            int j = index%N;
+            assert(ptr);
+            assert(Layout::valid(linear));
 
-            return (*this)(i,j);
+            auto c = Layout::coord_from_linear(linear);
+            return (*this)(c);
         }
 
-        inline CUDA_CALLABLE const T& operator()(int index) const
+        inline CUDA_CALLABLE const T& operator()(int linear) const
         {
-            assert(index < M*N);
-
-            // unravel
-            int i = index/N;
-            int j = index%N;
+            assert(ptr);
+            assert(Layout::valid(linear));
 
-            return (*this)(i,j);
+            auto c = Layout::coord_from_linear(linear);
+            return (*this)(c);
         }
     };
 
     Storage data;
     Storage grad;
 
+    // we need to track whether or not this tile's data has been initialized.
+    // once true, any re-initialization of data that follows needs a WP_TILE_SYNC()
+    // call to precede it, to allow threads that are still reading from this tile
+    // to complete their work. e.g, in a dynamic loop:
+    // for i in range(x):
+    //     tile = wp.tile_load(arr, i, TILE_SIZE, storage="shared")
+    //     # read from tile...
+    bool initialized;
+
     // default initialization (non-initialized)
-    inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL)
+    inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL), initialized(false)
     {
     }
 
     // initialize from an existing tile's memory
-    inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL) : data(data), grad(grad)
+    inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL, bool initialized=true) : data(data), grad(grad), initialized(initialized)
     {
     }
 
@@ -582,10 +853,10 @@ struct tile_shared_t
        {
            // update our per-thread shared memory allocator
            if (data.ptr)
-               tile_alloc_shared(-M*N*int(sizeof(T)));
+               tile_alloc_shared(-Layout::Size*int(sizeof(T)));
 
            if (grad.ptr)
-               tile_alloc_shared(-M*N*int(sizeof(T)));
+               tile_alloc_shared(-Layout::Size*int(sizeof(T)));
        }
    }
 
@@ -597,12 +868,13 @@ struct tile_shared_t
         return *this;
     }
 
+
     // construct from another shared tile, this constructor
     // is invoked for reshape operations like `wp.tile_transpose()`
-    template <typename OtherT, int OtherM, int OtherN, int OtherStrideM, int OtherStrideN>
-    inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherM, OtherN, OtherStrideM, OtherStrideN>& rhs)
+    template <typename OtherT, typename OtherLayout>
+    inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout>& rhs)
     {
-        using OtherTile = tile_shared_t<OtherT, OtherM, OtherN, OtherStrideM, OtherStrideN>;
+        using OtherTile = tile_shared_t<OtherT, OtherLayout>;
 
         // check dimensions are compatible
         static_assert(Size == OtherTile::Size);
@@ -610,89 +882,89 @@ struct tile_shared_t
         // alias tile directly
         data = rhs.data;
         grad = rhs.grad;
+        initialized = rhs.initialized;
 
         return *this;
     }
 
     // assign from a global tile (load)
-    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T>& t)
+    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
     {
-        if (t.data.ndim == 1)
-            copy_from_global(t.data, t.x); // 1d load
-        else
-            copy_from_global(t.data, t.x, t.y); // 2d load
-
-        // synchronization happens in copy functions above
-
+        copy_from_global(t);
         return *this;
     }
 
     // assign from a constant value
     inline CUDA_CALLABLE auto& operator=(const T& x)
     {
-        for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
+        // sync if we are re-initializing data so that any threads that are still
+        // reading from this tile can complete their work, e.g.: if re-assigning
+        // to a tile during a dynamic loop
+        if (initialized)
+            WP_TILE_SYNC();
+
+        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             data(i) = x;
 
+        initialized = true;
         WP_TILE_SYNC();
         return *this;
     }
 
-
-    // compute tile coordinate from linear index
-    inline CUDA_CALLABLE coord_t coord(int index) const
-    {
-        return {index/N, index%N};
-    }
-
     // in-place zero
     inline CUDA_CALLABLE void zero()
     {
-        for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
+        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             data(i) = T(0);
 
         WP_TILE_SYNC();
     }
 
     // extract a single tile element to a native type
-    inline CUDA_CALLABLE Type extract(int i, int j)
+    inline CUDA_CALLABLE Type extract(const typename Layout::Coord& c)
     {
-        return data(i, j);
+        return data(c);
     }
 
     // backward of scalar extraction
-    inline CUDA_CALLABLE void adj_extract(int i, int j, Type adj_ret)
+    inline CUDA_CALLABLE void adj_extract(const typename Layout::Coord& c, Type adj_ret)
     {
-        if (threadIdx.x == 0)
-            data(i, j) += adj_ret;
-
-        WP_TILE_SYNC();
+        // since multiple threads may extract the same element
+        // we need to accumulate using atomic operations
+        wp::atomic_add(&grad(c), adj_ret);
+
+        WP_TILE_SYNC();
     }
 
 
     // copy register tile to shared
-    inline CUDA_CALLABLE void assign(const tile_register_t<T, M, N>& tile)
+    template <typename Tile>
+    inline CUDA_CALLABLE void assign(const Tile& tile)
     {
+        if (initialized)
+            WP_TILE_SYNC();
+
         WP_PRAGMA_UNROLL
-        for (int i=0; i < tile.NumRegs; ++i)
+        for (int i=0; i < Tile::Layout::NumRegs; ++i)
         {
-            const int linear = tile.index(i);
+            const int linear = Tile::Layout::linear_from_register(i);
 
             // handle case where tile size is not
             // aligned to block dimensions
-            if (!Aligned && linear >= Size)
-                break;
+            if (!Tile::Layout::valid(linear))
+                break;
 
             data(linear) = tile.data[i];
         }
 
+        initialized = true;
         WP_TILE_SYNC();
     }
 
     // in-place gradient zero
     inline CUDA_CALLABLE void grad_zero()
     {
-        // todo: make this subtile (stride aware)
-        for (int i=threadIdx.x; i < M*N; i+= WP_TILE_BLOCK_DIM)
+        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             grad(i) = T(0);
 
         WP_TILE_SYNC();
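
Note: the adj_extract() change above is a correctness fix as much as a refactor. In the forward pass every thread reads the same extracted element, so in the backward pass every thread arrives with its own adj_ret for that one coordinate. A sketch of the difference (this reading is an interpretation, not package commentary):

    // old: if (threadIdx.x == 0) data(i, j) += adj_ret;
    //      kept only thread 0's contribution, and accumulated into data
    // new: wp::atomic_add(&grad(c), adj_ret);
    //      sums every thread's contribution into the gradient storage
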
@@ -700,44 +972,73 @@ struct tile_shared_t
 
 
     // accumulate gradients onto this tile
-    inline CUDA_CALLABLE void grad_add(const tile_register_t<T, M, N>& tile)
+    template <typename Tile>
+    inline CUDA_CALLABLE void grad_add(const Tile& tile)
     {
         WP_PRAGMA_UNROLL
-        for (int i=0; i < tile.NumRegs; ++i)
+        for (int i=0; i < Tile::Layout::NumRegs; ++i)
         {
-            const int linear = tile.index(i);
+            const int linear = Tile::Layout::linear_from_register(i);
 
             // handle case where tile size is not
             // aligned to block dimensions
-            if (!Aligned && linear >= Size)
+            if (!Tile::Layout::valid(linear))
                 break;
 
-            if (Unique)
+            // if the destination layout is unique (no broadcast dimensions)
+            // then we can use regular non-atomic accmulation
+            if (Layout::Unique)
                 grad(linear) += tile.data[i];
             else
                 // use shared memory atomics to accumulate gradients
                 // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
                 // may map to a single location in shared memory
-                atomic_add(&grad(linear), tile.data[i]);
+                wp::atomic_add(&grad(linear), tile.data[i]);
 
         }
 
         WP_TILE_SYNC();
     }
 
+    // accumulate gradient onto this tile from a global array
+    CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+    {
+        WP_PRAGMA_UNROLL
+        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        {
+            auto c = Layout::coord_from_linear(i);
+            T g = global.load_grad(c);
+
+            if (Layout::Unique)
+            {
+                // if the destination layout is unique (no broadcast dimensions)
+                // then we can use regular non-atomic accumulation
+                grad(c) += g;
+            }
+            else
+            {
+                // use shared memory atomics to accumulate gradients
+                // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
+                // may map to a single location in shared memory
+                wp::atomic_add(&grad(c), g);
+            }
+        }
+
+        WP_TILE_SYNC();
+    }
+
     // copy shared tile to register
-    inline CUDA_CALLABLE tile_register_t<T, M, N> grad_to_register()
+    inline CUDA_CALLABLE auto grad_to_register()
     {
-        tile_register_t<T, M, N> out;
+        using Tile = tile_register_t<T, tile_layout_register_t<typename Layout::Shape>>;
+        Tile out;
 
         WP_PRAGMA_UNROLL
-        for (int i=0; i < out.NumRegs; ++i)
+        for (int i=0; i < Tile::Layout::NumRegs; ++i)
         {
-            const int linear = out.index(i);
+            const int linear = Tile::Layout::linear_from_register(i);
 
-            // handle case where tile size is not
-            // aligned to block dimensions
-            if (!Aligned && linear >= Size)
+            if (!Tile::Layout::valid(linear))
                 break;
 
             out(i) = grad(linear);
@@ -746,40 +1047,20 @@ struct tile_shared_t
         return out;
     }
 
-    inline CUDA_CALLABLE void print() const
-    {
-        if (threadIdx.x == 0)
-        {
-            printf("tile(m=%d, n=%d, storage=shared) = [", M, N);
-            for (int i=0; i < M; ++i)
-            {
-                printf("%*s[", i>0, "");
-                for (int j=0; j < N; ++j)
-                {
-                    printf("%g ", double(data(i, j)));
-                }
-
-                if (i == M-1)
-                    printf("]]\n");
-                else
-                    printf("]\n");
-            }
-        }
-    }
-
     // copy shared tile to register
-    inline CUDA_CALLABLE tile_register_t<T, M, N> copy_to_register() const
+    inline CUDA_CALLABLE auto copy_to_register() const
     {
-        tile_register_t<T, M, N> out;
+
+        auto out = tile_register_like(this);
+
+        using Layout = typename decltype(out)::Layout;
 
         WP_PRAGMA_UNROLL
-        for (int i=0; i < out.NumRegs; ++i)
+        for (int i=0; i < Layout::NumRegs; ++i)
         {
-            const int linear = out.index(i);
+            const int linear = Layout::linear_from_register(i);
 
-            // handle case where tile size is not
-            // aligned to block dimensions
-            if (!Aligned && linear >= Size)
+            if (!Layout::valid(linear))
                 break;
 
             out(i) = data(linear);
@@ -788,220 +1069,358 @@ struct tile_shared_t
         return out;
     }
 
-    inline CUDA_CALLABLE void copy_to_global(array_t<T> dest, int x) const
-    {
-        assert(dest.ndim == 1);
+    template <typename Global>
+    inline CUDA_CALLABLE void copy_to_global(const Global& dest)
+    {
+        // vectorized loads for specific input/output shapes
+        if constexpr (Layout::Shape::N == 2)
+        {
+            constexpr int lastdim = Layout::Shape::N-1;
+            constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
+            const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
+            const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]);
+            const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
+
+            float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
+            const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
+
+            if (contiguous_dest && contiguous_src && aligned_size && aligned_dst)
+            {
+                constexpr int M = Layout::Shape::dim(0);
+                constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+
+                // alias of shared tile with 128bit type
+                using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
+                tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
+
+                assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
+                assert(((uint64_t)(dest128))%sizeof(float4) == 0);
+
+                const int stride_i = dest.data.strides[0]/sizeof(float4);
+                const int stride_j = 1;
+
+                WP_PRAGMA_UNROLL
+                for (int i=threadIdx.x; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
+                {
+                    auto c = SrcLayout::coord_from_linear(i);
+
+                    dest128[stride_i*c[0] + stride_j*c[1]] = src128.data(i);
+                }
 
-        // todo: use TMA here
-        const int tile_i = x*N;
+                return;
+            }
+        }
 
+        // scalar bounds checked path
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
-            wp::index(dest, tile_i + i) = data(i);
+            auto c = Layout::coord_from_linear(i);
+            dest.store(c, data(i));
         }
     }
 
-    inline CUDA_CALLABLE void copy_to_global(array_t<T> dest, int x, int y)
+    __device__ __forceinline__
+    void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
     {
-        // todo: use TMA here
-        const int tile_i = x*M;
-        const int tile_j = y*N;
-
-        // check each row is contiguous and 128bit aligned
-        if (StrideN == 1 && dest.strides[1] == sizeof(T) && (N*sizeof(T))%sizeof(float4) == 0)
-        {
-            constexpr int num_rows = M;
-            constexpr int num_cols = (N*sizeof(T))/sizeof(float4);
-
-            tile_shared_t<float4, num_rows, num_cols> src128((float4*)data.ptr);
-
-            // alias of shared tile with 128bit type
-            float4* ptr = (float4*)&wp::index(dest, tile_i, tile_j);
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+
+        unsigned long long saddr = 0ULL;
+        unsigned long long gaddr = 0ULL;
+
+        asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(saddr) : "l"(shared_dest));
+        asm volatile("cvta.to.global.u64 %0, %1;" : "=l"(gaddr) : "l"(global_src));
+
+        // Use cp.async on newer architectures
+        asm volatile(
+            "cp.async.ca.shared.global [%0], [%1], 16;\n"
+            :
+            : "l"(saddr), "l"(gaddr)
+        );
+#else
+        // use regular load/store through register on older arches
+        *shared_dest = *global_src;
+#endif
+    }
 
-            assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
-            assert(((uint64_t)(ptr))%sizeof(float4) == 0);
+    __device__ __forceinline__
+    void cp_async_commit_and_wait_all_128()
+    {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+        asm volatile(
+            "cp.async.commit_group;\n"
+            "cp.async.wait_group 0;\n" ::);
+#endif
+    }
+
+    template <typename Global>
+    inline CUDA_CALLABLE void copy_from_global(const Global& src)
+    {
+        if (initialized)
+            WP_TILE_SYNC();
+
+        // vectorized loads for specific input/output shapes
+        if constexpr (Layout::Shape::N == 2)
+        {
+            constexpr int lastdim = Layout::Shape::N-1;
+            constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
+            const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
+            const int elements = (src.data.shape[lastdim] - src.offset[lastdim]);
+            const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
+
+            float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
+            const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
+
+            if (contiguous_dest && contiguous_src && aligned_size && aligned_src)
+            {
+                constexpr int M = Layout::Shape::dim(0);
+                constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+
+                // alias of shared tile with 128bit type
+                using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
+                tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
+
+                assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
+                assert(((uint64_t)(src128))%sizeof(float4) == 0);
+
+                const int stride_i = src.data.strides[0]/sizeof(float4);
+                const int stride_j = 1;
+
+                WP_PRAGMA_UNROLL
+                for (int i=threadIdx.x; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
+                {
+                    auto c = DestLayout::coord_from_linear(i);
+
+#if WP_USE_ASYNC_PIPELINE
+                    cp_async_global_to_shared_128(&dest128.data(i), &src128[stride_i*c[0] + stride_j*c[1]]);
+#else
+                    dest128.data(i) = src128[stride_i*c[0] + stride_j*c[1]];
+#endif // WP_USE_ASYNC_PIPELINE
+                }
 
-            const int stride_i = dest.strides[0]/sizeof(float4);
-            const int stride_j = 1;
+#if WP_USE_ASYNC_PIPELINE
+                cp_async_commit_and_wait_all_128();
+#endif // WP_USE_ASYNC_PIPELINE
 
-            WP_PRAGMA_UNROLL
-            for (int i=threadIdx.x; i < src128.Size; i += WP_TILE_BLOCK_DIM)
-            {
-                coord_t c = src128.coord(i);
-                ptr[c.i*stride_i + c.j*stride_j] = src128.data(i);
+                initialized = true;
+                WP_TILE_SYNC();
+                return;
             }
         }
-        else
-        {
-            // wp.array() indexing generates poor code due to char* casting
-            // here we unroll some of the ops, note this assumes byte strides are
-            // aligned to the element size
-            T* ptr = &wp::index(dest, tile_i, tile_j);
-            const int stride_i = dest.strides[0]/sizeof(T);
-            const int stride_j = dest.strides[1]/sizeof(T);
-
-            WP_PRAGMA_UNROLL
-            for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
-            {
-                coord_t c = coord(i);
-                ptr[c.i*stride_i + c.j*stride_j] = data(c.i, c.j);
-            }
-        }
-    }
-
-    inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x)
-    {
-        // todo: use async pipelines or TMA here
-        const int tile_i = x*N;
 
+        // scalar bounds checked path
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
-            data(i) = wp::index(src, tile_i + i);
+            auto c = Layout::coord_from_linear(i);
+            data(i) = src.load(c);
         }
 
+        initialized = true;
         WP_TILE_SYNC();
     }
 
-    inline CUDA_CALLABLE void copy_from_global(const array_t<T>& src, int x, int y)
+    template <typename Global>
+    inline CUDA_CALLABLE auto atomic_add(Global& dest)
     {
-        // todo: use async pipelines or TMA here
-        const int tile_i = x*M;
-        const int tile_j = y*N;
+        copy_to_register().atomic_add(dest);
+    }
 
-        // check each row is contiguous and 128bit aligned
-        if (StrideN == 1 && src.strides[1] == sizeof(T) && (N*sizeof(T))%sizeof(float4) == 0)
-        {
-            constexpr int num_rows = M;
-            constexpr int num_cols = (N*sizeof(T))/sizeof(float4);
+    template <typename Global>
+    inline CUDA_CALLABLE auto atomic_add_grad(Global& dest)
+    {
+        grad_to_register().atomic_add_grad(dest);
+    }
 
-            // alias of shared tile with 128bit type
-            tile_shared_t<float4, num_rows, num_cols> dest128((float4*)data.ptr);
+    // overload for integral types
+    inline CUDA_CALLABLE void print_value(int x) const
+    {
+        printf("%d", x);
+    }
 
-            const float4* ptr = (const float4*)&wp::index(src, tile_i, tile_j);
+    // overload for floating point types
+    template <typename ValueType>
+    inline CUDA_CALLABLE void print_value(ValueType x) const
+    {
+        printf("%g", x);
+    }
 
-            assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
-            assert(((uint64_t)(ptr))%sizeof(float4) == 0);
+    template <int Level = 0>
+    inline CUDA_CALLABLE void print_values(const Storage& storage, int index=0) const
+    {
+        using Shape = typename Layout::Shape;
 
-            const int stride_i = src.strides[0]/sizeof(float4);
-            //const int stride_j = 1;
+        if constexpr (Level < Shape::N)
+        {
+            if constexpr (Level == Shape::N - 1)
+            {
+                // Special handling for 1D case
+                printf("[");
+                for (int i = 0; i < Shape::dim(Level); ++i)
+                {
+                    print_value(storage(index + i));
 
-            WP_PRAGMA_UNROLL
-            for (int i=threadIdx.x; i < dest128.Size; i += WP_TILE_BLOCK_DIM)
-            {
-                coord_t c = dest128.coord(i);
-
-#if WP_USE_ASYNC_PIPELINE
-                __pipeline_memcpy_async(&dest128.data(i),
-                                        &ptr[c.i*stride_i + c.j],
-                                        sizeof(float4));
-#else
-                dest128.data(i) = ptr[c.i*stride_i + c.j];
-#endif // WP_USE_ASYNC_PIPELINE
+                    if (i < Shape::dim(Level) - 1)
+                    {
+                        printf(" ");
+                    }
+                }
+                printf("]");
+            }
+            else if constexpr (Level == Shape::N - 2)
+            {
+                // Special handling for 2D case
+                printf("[");
+                for (int i = 0; i < Shape::dim(Level); ++i)
+                {
+                    printf("[");
+                    for (int j=0; j < Shape::dim(Level+1); ++j)
+                    {
+                        print_value(storage(index));
+
+                        if (j < Shape::dim(Level+1) - 1)
+                        {
+                            printf(" ");
+                        }
+
+                        ++index;
+                    }
+
+                    printf("]");
+
+                    // next row
+                    if (i < Shape::dim(Level)-1)
+                    {
+                        printf("\n");
+
+                        // indent next row
+                        for (int i=0; i <= Shape::N-2; ++i)
+                            printf(" ");
+
+                    }
+                }
+                printf("]");
+            }
+            else
+            {
+                printf("[");
+                for (int i = 0; i < Shape::dim(Level); ++i)
+                {
+                    print_values<Level + 1>(storage, index + i * Shape::dim(Level));
+                    if (i < Shape::dim(Level) - 1)
+                    {
+                        printf("\n\n");
+
+                        // indent next row
+                        for (int i=0; i <= Level; ++i)
+                            printf(" ");
+                    }
+                }
+                printf("]");
             }
+        }
+    }
 
-#if WP_USE_ASYNC_PIPELINE
-            __pipeline_commit();
-#endif // WP_USE_ASYNC_PIPELINE
+    inline CUDA_CALLABLE void print(bool reverse=false) const
+    {
+        if (threadIdx.x != 0)
+            return;
 
-        }
+        if (reverse)
+            print_values(grad);
         else
+            print_values(data);
+
+        printf(" = tile(shape=(");
+        for (int i=0; i < Layout::Shape::N; ++i)
         {
-            // wp.array() indexing generates poor code due to char* casting
-            // here we unroll some of the ops, note this assumes array byte strides are
-            // aligned to the element size
-            const T* ptr = &wp::index(src, tile_i, tile_j);
-
-            assert(src.strides[0]%sizeof(T) == 0);
-            assert(src.strides[1]%sizeof(T) == 0);
-
-            const int stride_i = src.strides[0]/sizeof(T);
-            const int stride_j = src.strides[1]/sizeof(T);
-
-            WP_PRAGMA_UNROLL
-            for (int i=threadIdx.x; i < Size; i += WP_TILE_BLOCK_DIM)
-            {
-                coord_t c = coord(i);
-                data(c.i, c.j) = ptr[c.i*stride_i + c.j*stride_j];
-            }
+            printf("%d", Layout::Shape::dim(i));
+            if (i != Layout::Shape::N-1)
+                printf(",");
         }
 
-#if !WP_USE_ASYNC_PIPELINE
-        WP_TILE_SYNC();
-#endif
-
-    }
+        printf("), storage=shared)\n");
+    }
 };
 
-template <typename T, int M, int N>
-void tile_register_t<T, M, N>::print() const
+
+template <typename T, typename L>
+void tile_register_t<T, L>::print() const
 {
     // create a temporary shared tile so that
     // we can print it deterministically
-    WP_TILE_SHARED T smem[M*N];
-
-    tile_shared_t<T, M, N> scratch(smem, NULL);
+    WP_TILE_SHARED T smem[L::Size];
+    tile_shared_t<T, tile_layout_strided_t<typename L::Shape>> scratch(smem, NULL);
+
     scratch.assign(*this);
 
     WP_TILE_SYNC();
 
     if (threadIdx.x == 0)
     {
-        printf("tile(m=%d, n=%d, storage=register) = [", M, N);
-        for (int i=0; i < M; ++i)
-        {
-            printf("%*s[", i>0, "");
-            for (int j=0; j < N; ++j)
-            {
-                printf("%g ", double(scratch.data(i, j)));
-            }
+        scratch.print_values(scratch.data, 0);
 
-            if (i == M-1)
-                printf("]]\n");
-            else
-                printf("]\n");
+        printf(" = tile(shape=(");
+        for (int i=0; i < L::Shape::N; ++i)
+        {
+            printf("%d", L::Shape::dim(i));
+            if (i != L::Shape::N-1)
+                printf(",");
         }
+
+        printf("), storage=register)\n");
     }
 
     WP_TILE_SYNC();
 }
 
-template <typename T, int M, int N>
-inline CUDA_CALLABLE void print(const tile_register_t<T, M, N>& t)
+// print entry points
+template <typename T, typename L>
+inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
+
+template <typename T, typename L, bool O>
+inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
 {
-    t.print();
+    return Tile::Layout::Shape::dim(0);
 }
 
-template <typename T, int M, int N>
-inline CUDA_CALLABLE void adj_print(const tile_register_t<T, M, N>& t, const tile_register_t<T, M, N>& a)
+template <typename T, typename L, bool O, typename AdjTile>
+inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile& a, int& adj_ret)
 {
-    a.print();
 }
 
-template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
-inline CUDA_CALLABLE void print(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t)
+template <typename T, typename L>
+inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
 {
-    t.print();
+    return Tile::Layout::Shape::dim(0);
 }
 
-template <typename T, int M, int N, int StrideM, int StrideN, bool Owner>
-inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& t, const tile_shared_t<T, M, N, StrideM, StrideN, Owner>& a)
+template <typename T, typename L, typename AdjTile>
+inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile& a, int& adj_ret)
 {
-    a.print();
 }
 
+
+template <typename T, typename L>
+inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
+
+
+
 // helpers to allocate shared tiles
-template <typename T, int M, int N, bool RequiresGrad>
+template <typename T, typename Shape, bool RequiresGrad>
 inline CUDA_CALLABLE auto tile_alloc_empty()
 
-{   constexpr int Len = M*N;
-    T* data = (T*)tile_alloc_shared(Len*sizeof(T));
+{   constexpr int size = Shape::size();
+    T* data = (T*)tile_alloc_shared(size*sizeof(T));
     T* grad = NULL;
 
 #if FP_CHECK
 
-    for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
+    for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
         data[i] = T(nanf(""));
 
     WP_TILE_SYNC();
@@ -1011,15 +1430,15 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
1011
1430
 
1012
1431
  if (RequiresGrad)
1013
1432
  {
1014
- grad = (T*)tile_alloc_shared(Len*sizeof(T));
1433
+ grad = (T*)tile_alloc_shared(size*sizeof(T));
1015
1434
 
1016
- for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
1435
+ for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
1017
1436
  grad[i] = T(0);
1018
1437
 
1019
1438
  WP_TILE_SYNC();
1020
1439
  }
1021
1440
 
1022
- return tile_shared_t<T, M, N>(data, grad);
1441
+ return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
1023
1442
  }
1024
1443
 
1025
1444
  template <typename T, int M, int N, bool RequiresGrad>
@@ -1043,7 +1462,7 @@ inline CUDA_CALLABLE auto tile_alloc_zeros()
1043
1462
 
1044
1463
  WP_TILE_SYNC();
1045
1464
 
1046
- return tile_shared_t<T, M, N, StrideM, StrideN>(data, grad);
1465
+ return tile_shared_t<T, tile_layout_strided_t<tile_shape_t<M, N>>(data, grad);
1047
1466
  }
1048
1467
 
1049
1468
 
@@ -1054,9 +1473,10 @@ inline CUDA_CALLABLE auto tile_alloc_zeros()
1054
1473
  template <typename T>
1055
1474
  inline CUDA_CALLABLE auto tile(const T& x)
1056
1475
  {
1057
- tile_register_t<T, 1, WP_TILE_BLOCK_DIM> result;
1476
+ tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
1058
1477
 
1059
- static_assert(result.NumRegs == 1);
1478
+ using Layout = typename decltype(result)::Layout;
1479
+ static_assert(Layout::NumRegs == 1);
1060
1480
 
1061
1481
  result.data[0] = x;
1062
1482
  return result;
@@ -1066,9 +1486,10 @@ inline CUDA_CALLABLE auto tile(const T& x)
1066
1486
  template <typename T, unsigned Length>
1067
1487
  inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
1068
1488
  {
1069
- tile_register_t<T, Length, WP_TILE_BLOCK_DIM> result;
1489
+ tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
1070
1490
 
1071
- static_assert(result.NumRegs == Length);
1491
+ using Layout = typename decltype(result)::Layout;
1492
+ static_assert(Layout::NumRegs == Length);
1072
1493
 
1073
1494
  for (int i=0; i < Length; ++i)
1074
1495
  result.data[i] = x[i];
@@ -1080,8 +1501,8 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
1080
1501
  template <typename T, typename AdjTile>
1081
1502
  inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
1082
1503
  {
1083
- static_assert(AdjTile::M == 1);
1084
- static_assert(AdjTile::N == WP_TILE_BLOCK_DIM);
1504
+ static_assert(AdjTile::Layout::Shape::N == 1);
1505
+ static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM);
1085
1506
 
1086
1507
  auto adj_reg = adj_ret.copy_to_register();
1087
1508
 
@@ -1091,8 +1512,9 @@ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
1091
1512
  template <typename T, unsigned Length, typename AdjTile>
1092
1513
  inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
1093
1514
  {
1094
- static_assert(AdjTile::M == Length);
1095
- static_assert(AdjTile::N == WP_TILE_BLOCK_DIM);
1515
+ static_assert(AdjTile::Layout::Shape::N == 2);
1516
+ static_assert(AdjTile::Layout::Shape::dim(0) == Length);
1517
+ static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM);
1096
1518
 
1097
1519
  auto adj_reg = adj_ret.copy_to_register();
1098
1520
 
@@ -1108,17 +1530,20 @@ inline CUDA_CALLABLE auto untile(Tile& tile)
1108
1530
  // there is exactly one value per-thread
1109
1531
  auto reg = tile.copy_to_register();
1110
1532
 
1533
+ constexpr int N = Tile::Layout::Shape::N;
1534
+
1111
1535
  // scalar case
1112
- if constexpr(Tile::M == 1)
1536
+ if constexpr(N == 1)
1113
1537
  {
1114
1538
  return reg.data[0];
1115
1539
  }
1116
1540
 
1117
1541
  // vector case
1118
- if constexpr(Tile::M > 1)
1542
+ if constexpr(N == 2)
1119
1543
  {
1120
- wp::vec_t<Tile::M, typename Tile::Type> v;
1121
- for (int i=0; i < Tile::M; ++i)
1544
+ constexpr int Length = Tile::Layout::Shape::dim(0);
1545
+ wp::vec_t<Length, typename Tile::Type> v;
1546
+ for (int i=0; i < Length; ++i)
1122
1547
  v[i] = reg.data[i];
1123
1548
 
1124
1549
  return v;
@@ -1130,24 +1555,27 @@ inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
1130
1555
  {
1131
1556
  auto adj = adj_tile.copy_to_register();
1132
1557
 
1558
+ constexpr int N = Tile::Layout::Shape::N;
1559
+
1133
1560
  // scalar case
1134
- if constexpr(Tile::M == 1)
1561
+ if constexpr(N == 1)
1135
1562
  {
1136
1563
  adj.data[0] += adj_ret;
1137
1564
  }
1138
1565
 
1139
1566
  // vector case
1140
- if constexpr(Tile::M > 1)
1567
+ if constexpr(N == 2)
1141
1568
  {
1142
- for (int i=0; i < Tile::M; ++i)
1143
- adj.data[i] = adj_ret[i];
1569
+ constexpr int Length = Tile::Layout::Shape::dim(0);
1570
+ for (int i=0; i < Length; ++i)
1571
+ adj.data[i] += adj_ret[i];
1144
1572
  }
1145
1573
 
1146
1574
  adj_tile.assign(adj);
1147
1575
  }
1148
1576
 
1149
1577
  // zero initialized tile
1150
- template <typename T, int M, int N>
1578
+ template <typename T, unsigned... Shape>
1151
1579
  inline CUDA_CALLABLE auto tile_zeros()
1152
1580
  {
1153
1581
  // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
@@ -1155,7 +1583,7 @@ inline CUDA_CALLABLE auto tile_zeros()
1155
1583
  }
1156
1584
 
1157
1585
  // one-initialized tile
1158
- template <typename T, int M, int N>
1586
+ template <typename T, unsigned... Shape>
1159
1587
  inline CUDA_CALLABLE auto tile_ones()
1160
1588
  {
1161
1589
  // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
@@ -1163,19 +1591,21 @@ inline CUDA_CALLABLE auto tile_ones()
1163
1591
  }
1164
1592
 
1165
1593
  // tile with evenly spaced values
1166
- template <typename T, int M, int N>
1594
+ template <typename T, int Len>
1167
1595
  inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
1168
1596
  {
1169
- tile_register_t<T, M, N> out;
1597
+ auto out = tile_register<T, Len>();
1598
+
1599
+ using Layout = typename decltype(out)::Layout;
1170
1600
 
1171
1601
  WP_PRAGMA_UNROLL
1172
- for (int i=0; i < out.NumRegs; ++i)
1602
+ for (int i=0; i < Layout::NumRegs; ++i)
1173
1603
  {
1174
- const int linear = out.index(i);
1604
+ const int linear = Layout::linear_from_register(i);
1175
1605
 
1176
1606
  // handle case where tile size is not
1177
1607
  // aligned to block dimensions
1178
- if (!out.Aligned && linear >= out.Size)
1608
+ if (!Layout::valid(linear))
1179
1609
  break;
1180
1610
 
1181
1611
  out.data[i] = start + linear*step;
@@ -1188,191 +1618,106 @@ template <typename T, typename AdjTile>
1188
1618
  inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
1189
1619
  T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
1190
1620
 
1191
- // entry point for 1d load
1192
- template <typename T, int N>
1193
- inline CUDA_CALLABLE auto tile_load(array_t<T>& src, int x)
1621
+ // entry point for load operations, these just return a reference to a global memory array + coordinate
1622
+ template <unsigned... Shape, typename... Indices, typename T>
1623
+ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
1194
1624
  {
1195
- return tile_global_t<T>(src, x, 0);
1625
+ return tile_global_t<T, tile_shape_t<Shape...>>(src, tile_coord(offset...));
1196
1626
  }
1197
1627
 
1198
- // entry point for 2d load
1199
- template <typename T, int M, int N>
1200
- inline CUDA_CALLABLE auto tile_load(array_t<T>& src, int x, int y)
1201
- {
1202
- return tile_global_t<T>(src, x, y);
1203
- }
1628
+ // // entry point for tile store operations
1629
+ // template <typename... Indices, typename T, typename Tile>
1630
+ // inline CUDA_CALLABLE void tile_store(array_t<T>& dest, Tile& src, Indices... x)
1631
+ // {
1632
+ // src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x)));
1633
+ // }
1204
1634
 
1205
- // entry point for 1d store
1635
+ // entry point for tile store operations
1206
1636
  template <typename T, typename Tile>
1207
- inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src)
1208
- {
1209
- // dispatch to tile type
1210
- src.copy_to_global(dest, x);
1211
- }
1212
-
1213
- // entry point for 2d store
1637
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
1214
1638
  template <typename T, typename Tile>
1215
- inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src)
1216
- {
1217
- // dispatch to tile type
1218
- src.copy_to_global(dest, x, y);
1219
- }
1220
-
1639
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
1221
1640
  template <typename T, typename Tile>
1222
- inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src)
1223
- {
1224
- auto src_reg = src.copy_to_register();
1225
-
1226
- const int tile_i = x*src_reg.M;
1227
- const int tile_j = y*src_reg.N;
1228
-
1229
- tile_register_t<T, src_reg.M, src_reg.N> previous;
1230
-
1231
- WP_PRAGMA_UNROLL
1232
- for (int i=0; i < src_reg.NumRegs; ++i)
1233
- {
1234
- // handle case where tile size is not
1235
- // aligned to block dimensions
1236
- int linear = src_reg.index(i);
1237
- if (!src_reg.Aligned && linear >= src_reg.Size)
1238
- break;
1641
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
1642
+ template <typename T, typename Tile>
1643
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
1239
1644
 
1240
- coord_t c = src_reg.coord(linear);
1241
- previous.data[i] = atomic_add(dest, tile_i + c.i, tile_j + c.j, src_reg.data[i]);
1242
- }
1243
1645
 
1244
- return previous;
1245
- }
1246
1646
 
1647
+ template <typename T, typename Tile>
1648
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
1649
+ template <typename T, typename Tile>
1650
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y)));}
1651
+ template <typename T, typename Tile>
1652
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z)));}
1653
+ template <typename T, typename Tile>
1654
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w)));}
1247
1655
 
1248
1656
 
1249
1657
  //-------------------------------------
1250
1658
  // Adjoints
1251
1659
 
1252
- template <typename T, typename AdjTile>
1253
- inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x,
1254
- array_t<T>& adj_src, int adj_x,
1660
+ template <typename T, typename AdjTile, typename Coord>
1661
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
1662
+ array_t<T>& adj_src, Coord adj_c,
1255
1663
  AdjTile& adj_ret)
1256
1664
  {
1257
- // early out
1258
- // if (!src.grad)
1259
- // return;
1260
-
1261
- auto adj_reg = adj_ret.grad_to_register();
1262
-
1263
- const int tile_i = x*adj_reg.N;
1264
-
1265
- // add gradients to src array
1266
- WP_PRAGMA_UNROLL
1267
- for (int i=0; i < adj_reg.NumRegs; ++i)
1268
- {
1269
- int linear = adj_reg.index(i);
1270
- if (!adj_reg.Aligned && linear >= adj_reg.Size)
1271
- break;
1272
-
1273
- auto grad = adj_reg.data[i];
1665
+ tile_global_t<T, typename AdjTile::Layout::Shape> dest(src, c);
1666
+
1667
+ // we allow users to override grad of src
1668
+ if (adj_src.data)
1669
+ dest.data.grad = adj_src.data;
1274
1670
 
1275
- if (adj_src.data)
1276
- adj_atomic_add(&index(adj_src, tile_i + linear), grad);
1277
- else if (src.grad)
1278
- adj_atomic_add(&index_grad(src, tile_i + linear), grad);
1279
- }
1671
+ adj_ret.atomic_add_grad(dest);
1280
1672
  }
1281
1673
 
1282
- template <typename T, typename AdjTile>
1283
- inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y,
1284
- array_t<T>& adj_src, int adj_x, int adj_y,
1285
- AdjTile& adj_ret)
1286
- {
1287
- // early out
1288
- // if (!src.grad)
1289
- // return;
1290
-
1291
- auto adj_reg = adj_ret.grad_to_register();
1292
-
1293
- const int tile_i = x*adj_reg.M;
1294
- const int tile_j = y*adj_reg.N;
1295
-
1296
- // add gradients to src array
1297
- WP_PRAGMA_UNROLL
1298
- for (int i=0; i < adj_reg.NumRegs; ++i)
1299
- {
1300
- int linear = adj_reg.index(i);
1301
- if (!adj_reg.Aligned && linear >= adj_reg.Size)
1302
- break;
1303
-
1304
- coord_t coord = adj_reg.coord(linear);
1305
1674
 
1306
- auto grad = adj_reg.data[i];
1675
+ template <typename T, typename AdjTile>
1676
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
1677
+ template <typename T, typename AdjTile>
1678
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, array_t<T>& adj_src, int adj_x, int adj_y, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y), adj_src, tile_coord(0,0), adj_ret); }
1679
+ template <typename T, typename AdjTile>
1680
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z), adj_src, tile_coord(0,0,0), adj_ret); }
1681
+ template <typename T, typename AdjTile>
1682
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
1307
1683
 
1308
- if (adj_src.data)
1309
- adj_atomic_add(&index(adj_src, tile_i + coord.i, tile_j + coord.j), grad);
1310
- else if (src.grad)
1311
- adj_atomic_add(&index_grad(src, tile_i + coord.i, tile_j + coord.j), grad);
1312
- }
1313
- }
1314
1684
 
1315
1685
 
1316
- template <typename T, typename Tile, typename AdjTile>
1317
- inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t)
1686
+ template <typename T, typename Tile, typename AdjTile, typename Coord>
1687
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
1318
1688
  {
1319
- // convert to register if necessary
1320
- tile_register_t<T, AdjTile::M, AdjTile::N> adj_reg;
1321
-
1322
- const int tile_i = x*adj_reg.N;
1323
-
1324
- // load gradients from output
1325
- WP_PRAGMA_UNROLL
1326
- for (int i=0; i < adj_reg.NumRegs; ++i)
1327
- {
1328
- int linear = adj_reg.index(i);
1329
- if (!adj_reg.Aligned && linear >= adj_reg.Size)
1330
- break;
1689
+ tile_global_t<T, typename AdjTile::Layout::Shape> src(dest, c);
1690
+
1691
+ // we allow users to override grad of src
1692
+ if (adj_dest.data)
1693
+ src.data.grad = adj_dest.data;
1331
1694
 
1332
- if (adj_dest.data)
1333
- adj_reg.data[i] = index(adj_dest, tile_i + linear);
1334
- else if (dest.grad)
1335
- adj_reg.data[i] = index_grad(dest, tile_i + linear);
1336
- }
1695
+ if (src.data.grad == NULL)
1696
+ return;
1337
1697
 
1338
- // store adjoint back to tile
1339
- adj_t.grad_add(adj_reg);
1698
+ adj_t.grad_add(src);
1340
1699
  }
1341
1700
 
1342
1701
  template <typename T, typename Tile, typename AdjTile>
1343
- inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t)
1344
- {
1345
- // allocate register tile to load grads into
1346
- tile_register_t<T, AdjTile::M, AdjTile::N> adj_reg;
1347
-
1348
- const int tile_i = x*adj_reg.M;
1349
- const int tile_j = y*adj_reg.N;
1350
-
1351
- // load gradients from output
1352
- WP_PRAGMA_UNROLL
1353
- for (int i=0; i < adj_reg.NumRegs; ++i)
1354
- {
1355
- int linear = adj_reg.index(i);
1356
- if (!adj_reg.Aligned && linear >= adj_reg.Size)
1357
- break;
1358
-
1359
- coord_t coord = adj_reg.coord(linear);
1702
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(0), adj_t); }
1703
+ template <typename T, typename Tile, typename AdjTile>
1704
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(0,0), adj_t); }
1705
+ template <typename T, typename Tile, typename AdjTile>
1706
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(0,0,0), adj_t); }
1707
+ template <typename T, typename Tile, typename AdjTile>
1708
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
1360
1709
 
1361
- if (adj_dest.data)
1362
- adj_reg.data[i] = index(adj_dest, tile_i + coord.i, tile_j + coord.j);
1363
- else if (dest.grad)
1364
- adj_reg.data[i] = index_grad(dest, tile_i + coord.i, tile_j + coord.j);
1365
- }
1366
1710
 
1367
- // store adjoint back to tile
1368
- adj_t.grad_add(adj_reg);
1369
- }
1370
1711
 
1712
+ // adj_tile_atomic_add is an alias for adj_tile_store
1371
1713
  template <typename T, typename Tile, typename AdjTile, typename AdjRet>
1372
- inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret)
1373
- {
1374
- adj_tile_store(dest, x, y, t, adj_dest, adj_x, adj_y, adj_t);
1375
- }
1714
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(adj_x), adj_t); }
1715
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
1716
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(adj_x, adj_y), adj_t); }
1717
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
1718
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(adj_x, adj_y, adj_z), adj_t); }
1719
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
1720
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
1376
1721
 
1377
1722
 
1378
1723
  // unary map
@@ -1380,11 +1725,13 @@ template <typename Tile, typename Fwd>
1380
1725
  inline CUDA_CALLABLE auto tile_map(Fwd op,
1381
1726
  Tile &a)
1382
1727
  {
1383
- auto out = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
1728
+ auto out = tile_register_like<Tile>();
1384
1729
  auto a_reg = a.copy_to_register();
1730
+
1731
+ using Layout = typename decltype(out)::Layout;
1385
1732
 
1386
1733
  WP_PRAGMA_UNROLL
1387
- for (int i=0; i < out.NumRegs; ++i)
1734
+ for (int i=0; i < Layout::NumRegs; ++i)
1388
1735
  {
1389
1736
  out.data[i] = op(a_reg.data[i]);
1390
1737
  }
@@ -1404,8 +1751,10 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
1404
1751
  auto adj_a_reg = tile_register_like<Tile>();
1405
1752
  auto adj_ret_reg = adj_ret.grad_to_register();
1406
1753
 
1754
+ using Layout = typename decltype(a_reg)::Layout;
1755
+
1407
1756
  WP_PRAGMA_UNROLL
1408
- for (int i=0; i < a_reg.NumRegs; ++i)
1757
+ for (int i=0; i < Layout::NumRegs; ++i)
1409
1758
  {
1410
1759
  adj_op(a_reg.data[i], adj_a_reg.data[i], adj_ret_reg.data[i]);
1411
1760
  }
@@ -1420,14 +1769,18 @@ inline CUDA_CALLABLE auto tile_map(Fwd op,
1420
1769
  TileA& a,
1421
1770
  TileB& b)
1422
1771
  {
1423
- auto out = tile_register_t<typename TileA::Type, TileA::M, TileA::N>();
1772
+ auto out = tile_register_like<TileA>();
1424
1773
 
1425
1774
  auto a_reg = a.copy_to_register();
1426
1775
  auto b_reg = b.copy_to_register();
1427
1776
 
1777
+ using Layout = typename decltype(out)::Layout;
1778
+
1428
1779
  WP_PRAGMA_UNROLL
1429
- for (int i=0; i < out.NumRegs; ++i)
1780
+ for (int i=0; i < Layout::NumRegs; ++i)
1781
+ {
1430
1782
  out.data[i] = op(a_reg.data[i], b_reg.data[i]);
1783
+ }
1431
1784
 
1432
1785
  return out;
1433
1786
  }
@@ -1451,8 +1804,10 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
1451
1804
 
1452
1805
  auto adj_ret_reg = adj_ret.grad_to_register();
1453
1806
 
1807
+ using Layout = typename decltype(a_reg)::Layout;
1808
+
1454
1809
  WP_PRAGMA_UNROLL
1455
- for (int i=0; i < a_reg.NumRegs; ++i)
1810
+ for (int i=0; i < Layout::NumRegs; ++i)
1456
1811
  {
1457
1812
  adj_op(a_reg.data[i], b_reg.data[i], adj_a_reg.data[i], adj_b_reg.data[i], adj_ret_reg.data[i]);
1458
1813
  }
@@ -1485,26 +1840,32 @@ inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
1485
1840
  return tile_binary_map(add, a, b);
1486
1841
  }
1487
1842
 
1488
- // // tile + tile, we implement this
1489
- // template <typename TileA, typename TileB>
1490
- // inline CUDA_CALLABLE auto add(TileA& a, TileB& b)
1491
- // {
1492
- // return tile_binary_map(add, a, b);
1493
- // }
1494
-
1495
-
1496
1843
  template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
1497
1844
  inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
1498
1845
  {
1499
1846
  adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
1500
1847
  }
1501
1848
 
1849
+ // tile - tile
1850
+ template <typename TileA, typename TileB>
1851
+ inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
1852
+ {
1853
+ return tile_binary_map(sub, a, b);
1854
+ }
1855
+
1856
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
1857
+ inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
1858
+ {
1859
+ adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
1860
+ }
1861
+
1862
+
1502
1863
  // tile*scalar
1503
1864
  template <typename Tile>
1504
1865
  inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
1505
1866
  {
1506
1867
  // promote scalar to a constant tile
1507
- auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1868
+ auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
1508
1869
 
1509
1870
  return tile_binary_map(mul, a, s_tile);
1510
1871
  }
@@ -1514,12 +1875,17 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
1514
1875
  Tile& adj_a, typename Tile::Type& adj_s,
1515
1876
  AdjTile& adj_c)
1516
1877
  {
1517
- auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1518
- auto adj_s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
1878
+ auto s_tile = tile_register_like<Tile>();
1879
+ auto adj_s_tile = tile_register_like<Tile>();
1880
+
1881
+ using Layout = typename decltype(adj_s_tile)::Layout;
1882
+
1883
+ // initialize to constant
1884
+ s_tile = s;
1519
1885
 
1520
1886
  adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
1521
1887
 
1522
- for (int i=0; i < adj_s_tile.NumRegs; ++i)
1888
+ for (int i=0; i < Layout::NumRegs; ++i)
1523
1889
  {
1524
1890
  adj_s += adj_s_tile.data[i];
1525
1891
  }
@@ -1530,10 +1896,7 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
1530
1896
  template <typename Tile>
1531
1897
  inline CUDA_CALLABLE auto tile_mul(const typename Tile::Type& s, Tile& a)
1532
1898
  {
1533
- // promote scalar to a constant tile
1534
- auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1535
-
1536
- return tile_binary_map(mul, s_tile, a);
1899
+ return tile_mul(a, s);
1537
1900
  }
1538
1901
 
1539
1902
  template <typename Tile, typename AdjTile>
@@ -1541,36 +1904,30 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
1541
1904
  typename Tile::Type& adj_s, Tile& adj_a,
1542
1905
  AdjTile& adj_c)
1543
1906
  {
1544
- auto s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>(s);
1545
- auto adj_s_tile = tile_register_t<typename Tile::Type, Tile::M, Tile::N>();
1546
-
1547
- adj_tile_binary_map(mul, s_tile, a, adj_mul, adj_s_tile, adj_a, adj_c);
1548
-
1549
- for (int i=0; i < adj_s_tile.NumRegs; ++i)
1550
- {
1551
- adj_s += adj_s_tile.data[i];
1552
- }
1907
+ adj_tile_mul(a, s, adj_a, adj_s, adj_c);
1553
1908
  }
1554
1909
 
1555
1910
 
1556
-
1557
1911
  template<typename Tile>
1558
- typename Tile::Type tile_extract(Tile& t, int i, int j)
1559
- {
1560
- assert(i < Tile::M);
1561
- assert(j < Tile::N);
1912
+ typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
1913
+ template<typename Tile>
1914
+ typename Tile::Type tile_extract(Tile& t, int i, int j) { return t.extract(tile_coord(i,j)); }
1915
+ template<typename Tile>
1916
+ typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extract(tile_coord(i,j,k)); }
1917
+ template<typename Tile>
1918
+ typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
1562
1919
 
1563
- return t.extract(i, j);
1564
- }
1565
1920
 
1566
1921
  template<typename Tile, typename AdjTile>
1567
- void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret)
1568
- {
1569
- assert(i < Tile::M);
1570
- assert(j < Tile::N);
1922
+ void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
1923
+ template<typename Tile, typename AdjTile>
1924
+ void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j), adj_ret); }
1925
+ template<typename Tile, typename AdjTile>
1926
+ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k), adj_ret); }
1927
+ template<typename Tile, typename AdjTile>
1928
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
1571
1929
 
1572
- adj_t.adj_extract(i, j, adj_ret);
1573
- }
1930
+ #if WP_USE_REGISTER_GEMM
1574
1931
 
1575
1932
  namespace partitioned_gemm
1576
1933
  {
@@ -1592,7 +1949,7 @@ struct partition_t
1592
1949
  {
1593
1950
  static constexpr int M = PartitionM;
1594
1951
  static constexpr int N = PartitionN;
1595
- static constexpr int Stride = Tile::N;
1952
+ static constexpr int Stride = Tile::Layout::Shape::dim(1);
1596
1953
 
1597
1954
  using T = typename Tile::Type;
1598
1955
 
@@ -1601,8 +1958,8 @@ struct partition_t
1601
1958
  data = A.data.ptr;
1602
1959
 
1603
1960
  // todo: do ceil div for non-multiples of M,N
1604
- shape[0] = Tile::M/PartitionM;
1605
- shape[1] = Tile::N/PartitionN;
1961
+ shape[0] = Tile::Layout::Shape::dim(0)/PartitionM;
1962
+ shape[1] = Tile::Layout::Shape::dim(1)/PartitionN;
1606
1963
  }
1607
1964
 
1608
1965
  // underlying data
@@ -1640,7 +1997,7 @@ inline auto partition_load(const Partition& tile, int i, int j)
1640
1997
  WP_PRAGMA_UNROLL
1641
1998
  for (int j=0; j < Partition::N; ++j)
1642
1999
  {
1643
- out.data[i][j] = index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
2000
+ out.data[i][j] = partitioned_gemm::index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
1644
2001
  }
1645
2002
  }
1646
2003
 
@@ -1664,6 +2021,7 @@ inline void partition_store(const Partition& tile, int i, int j, const Value& va
1664
2021
  }
1665
2022
  }
1666
2023
 
2024
+
1667
2025
  template <typename TileA, typename TileB, typename TileC>
1668
2026
  inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
1669
2027
  {
@@ -1700,15 +2058,26 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
1700
2058
 
1701
2059
  } // namespace partition_gemm
1702
2060
 
2061
+ #endif // WP_USE_REGISTER_GEMM
2062
+
2063
+
1703
2064
  template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
1704
2065
  TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C)
1705
2066
  {
1706
- using T = typename TileA::Type;
2067
+ using ShapeA = typename TileA::Layout::Shape;
2068
+ using ShapeB = typename TileB::Layout::Shape;
2069
+ using ShapeC = typename TileC::Layout::Shape;
1707
2070
 
1708
- #if WP_USE_ASYNC_PIPELINE
1709
- __pipeline_wait_prior(0);
1710
- WP_TILE_SYNC();
1711
- #endif
2071
+ static_assert(ShapeA::N == 2);
2072
+ static_assert(ShapeB::N == 2);
2073
+ static_assert(ShapeC::N == 2);
2074
+
2075
+ static_assert(ShapeA::dim(1) == ShapeB::dim(0));
2076
+ static_assert(ShapeC::dim(0) == ShapeA::dim(0));
2077
+ static_assert(ShapeC::dim(1) == ShapeB::dim(1));
2078
+
2079
+
2080
+ using T = typename TileA::Type;
1712
2081
 
1713
2082
  #if WP_USE_REGISTER_GEMM
1714
2083
  partitioned_gemm::matmul(A, B, C);
@@ -1746,11 +2115,11 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
1746
2115
  }
1747
2116
 
1748
2117
  // TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
1749
- // TODO(lcambier): use dynamic smem
2118
+ // and remove the need for __align__(16) dtypes data[...]
1750
2119
  #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
1751
2120
  do { \
1752
2121
  void function_name(dtype*, dtype*); \
1753
- WP_TILE_SHARED __align__(16) char buffer[shared_memory_size]; \
2122
+ char* buffer = (char*)wp::tile_alloc_shared(shared_memory_size); \
1754
2123
  __align__(16) dtype data[ept]; \
1755
2124
  for(int b = 0; b < (int)batch_size; b++) { \
1756
2125
  dtype* inout = Xinout.data + (int)b * (int)ept; \
@@ -1759,6 +2128,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
1759
2128
  memcpy(inout, data, sizeof(dtype) * ept); \
1760
2129
  WP_TILE_SYNC(); \
1761
2130
  } \
2131
+ wp::tile_alloc_shared(-shared_memory_size); \
1762
2132
  } while (0)
1763
2133
 
1764
2134
  #define tile_ifft tile_fft
@@ -1779,12 +2149,78 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
1779
2149
  tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
1780
2150
  } while (0)
1781
2151
 
2152
+ template <typename Fwd, typename TileA, typename TileL>
2153
+ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2154
+ {
2155
+ // Copy to L
2156
+ L = A;
2157
+
2158
+ // Call cholesky on L
2159
+ WP_TILE_SYNC();
2160
+
2161
+ fun_forward(L.data.ptr, TileL::Layout::Shape::dim(0));
2162
+
2163
+ WP_TILE_SYNC();
2164
+
2165
+ // Zero-out the upper triangular part of L
2166
+
2167
+ WP_PRAGMA_UNROLL
2168
+ for (int i=threadIdx.x; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
2169
+ {
2170
+ auto c = TileL::Layout::coord_from_linear(i);
2171
+
2172
+ if(c[0] < c[1])
2173
+ L.data(c) = 0.0;
2174
+ }
2175
+
2176
+ WP_TILE_SYNC();
2177
+
2178
+ return L;
2179
+ }
2180
+
2181
+ #define adj_tile_cholesky(function_name, A, L, \
2182
+ adj_function_name, adj_A, adj_L, adj_ret) \
2183
+ do { \
2184
+ assert(false); \
2185
+ } while (0)
2186
+
2187
+ template <typename Fwd, typename TileL, typename TileX, typename TileY>
2188
+ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
2189
+ {
2190
+ // Copy x to y
2191
+
2192
+ Y = X;
2193
+
2194
+ // Call cholesky solve on L & y
2195
+
2196
+ WP_TILE_SYNC();
2197
+
2198
+ fun_forward(L.data.ptr, Y.data.ptr); \
2199
+
2200
+ WP_TILE_SYNC();
2201
+
2202
+ return Y;
2203
+ }
2204
+
2205
+ #define adj_tile_cholesky_solve(function_name, L, X, Y, \
2206
+ adj_function_name, adj_L, adj_X, adj_Y, adj_ret) \
2207
+ do { \
2208
+ assert(false); \
2209
+ } while (0)
1782
2210
 
1783
2211
  template <typename Tile>
1784
2212
  inline CUDA_CALLABLE auto tile_transpose(Tile& t)
1785
2213
  {
2214
+ static_assert(Tile::Layout::Shape::N == 2);
2215
+
1786
2216
  // alias incoming tile
1787
- return tile_shared_t<typename Tile::Type, Tile::N, Tile::M, Tile::StrideN, Tile::StrideM, false>(t.data.ptr, t.grad.ptr);
2217
+ constexpr int M = Tile::Layout::Shape::dim(0);
2218
+ constexpr int N = Tile::Layout::Shape::dim(1);
2219
+
2220
+ constexpr int StrideM = Tile::Layout::Stride::dim(0);
2221
+ constexpr int StrideN = Tile::Layout::Stride::dim(1);
2222
+
2223
+ return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N,M>, tile_stride_t<StrideN, StrideM>>, false>(t.data.ptr, t.grad.ptr);
1788
2224
  }
1789
2225
 
1790
2226
  template <typename Tile, typename AdjTile>
@@ -1800,55 +2236,144 @@ template <int M, int N, int StrideM, int StrideN, typename Tile>
1800
2236
  inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
1801
2237
  {
1802
2238
  // alias incoming tile with new strides
1803
- return tile_shared_t<typename Tile::Type, M, N, StrideM, StrideN, false>(t.data.ptr, t.grad.ptr);
2239
+ return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
1804
2240
  }
1805
2241
 
1806
2242
  template <typename Tile, typename AdjTile>
1807
2243
  inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
1808
2244
  {
1809
2245
  // nop, since memory is aliased grads already accumulated
2246
+ }
2247
+
2248
+ template <typename ReturnType, typename Tile, typename... Indices>
2249
+ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
2250
+ {
2251
+ auto c = tile_coord(indices...);
2252
+
2253
+ // return new tile with same strides
2254
+ typename Tile::Type* data_ptr = &t.data(c);
2255
+ typename Tile::Type* grad_ptr = NULL;
2256
+
2257
+ if (t.grad.ptr)
2258
+ grad_ptr = &t.grad(c);
1810
2259
 
2260
+ return ReturnType(data_ptr, grad_ptr);
1811
2261
  }
1812
2262
 
1813
- template <int M, int N, typename Tile>
1814
- inline CUDA_CALLABLE auto tile_view(Tile& t, int i, int j)
1815
- {
1816
- // alias incoming tile with new strides
1817
- return tile_shared_t<typename Tile::Type, M, N, Tile::StrideM, Tile::StrideN, false>(&t.data(i, j), &t.grad(i, j));
2263
+
2264
+ template <typename TileA, typename Scalar>
2265
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
2266
+ {
2267
+ dest.data(tile_coord(i)) = src;
2268
+ WP_TILE_SYNC();
1818
2269
  }
1819
2270
 
1820
- template <typename Tile, typename AdjTile>
1821
- inline CUDA_CALLABLE void adj_tile_view(Tile& t, int i, int j, Tile& adj_t, int adj_i, int adj_j, AdjTile& adj_ret)
2271
+ template <typename TileA, typename Scalar>
2272
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
1822
2273
  {
1823
- // nop, since memory is aliased grads already accumulated
2274
+ dest.data(tile_coord(i, j)) = src;
2275
+ WP_TILE_SYNC();
2276
+ }
1824
2277
 
2278
+ template <typename TileA, typename Scalar>
2279
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
2280
+ {
2281
+ dest.data(tile_coord(i, j, k)) = src;
2282
+ WP_TILE_SYNC();
1825
2283
  }
1826
2284
 
1827
- template <typename TileA, typename TileB>
1828
- inline CUDA_CALLABLE void tile_assign(TileA& dest, int i, int j, TileB& src)
2285
+ template <typename TileA, typename Scalar>
2286
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
2287
+ {
2288
+ dest.data(tile_coord(i, j, k, l)) = src;
2289
+ WP_TILE_SYNC();
2290
+ }
2291
+
2292
+
2293
+
2294
+
2295
+ template <typename TileA, typename TileB, typename Coord>
2296
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offset)
1829
2297
  {
1830
- for (int t=threadIdx.x; t < src.Size; t += WP_TILE_BLOCK_DIM)
2298
+ using Layout = typename TileB::Layout;
2299
+
2300
+ for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
1831
2301
  {
1832
- coord_t c = src.coord(t);
1833
- dest.data(i + c.i, j + c.j) = src.data(c.i, c.j);
2302
+ auto c = Layout::coord_from_linear(t);
2303
+ dest.data(c + offset) = src.data(c);
1834
2304
  }
1835
2305
 
1836
2306
  WP_TILE_SYNC();
1837
2307
  }
1838
2308
 
1839
- template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
1840
- inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, int i, int j, TileB& src,
1841
- AdjTileA& adj_dest, int adj_i, int adj_j, AdjTileB& adj_src)
2309
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename Coord, typename AdjCoord>
2310
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
2311
+ AdjTileA& adj_dest, AdjTileB& adj_src, AdjCoord adj_offset)
1842
2312
  {
1843
- for (int t=threadIdx.x; t < src.Size; t += WP_TILE_BLOCK_DIM)
2313
+ using Layout = typename TileB::Layout;
2314
+
2315
+ for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
1844
2316
  {
1845
- coord_t c = src.coord(t);
1846
- src.grad(c.i, c.j) += dest.grad(i + c.i, j + c.j);
2317
+ auto c = Layout::coord_from_linear(t);
2318
+ src.grad(c) += dest.grad(c + offset);
1847
2319
  }
1848
2320
 
1849
2321
  WP_TILE_SYNC();
1850
2322
  }
1851
2323
 
1852
2324
 
2325
+ // codegen entry points, which emit calls like `tile_assign(dest, src, i, j, k)`
2326
+ // a better approach here would be for codegen to just directly generate `tile_assign(dest, src, tile_coord(i, j, k))`
2327
+ // i.e.: call the above implementation methods directly, then we could remove these overloads
2328
+ template <typename TileA, typename TileB>
2329
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i) { tile_assign(dest, src, tile_coord(i)); }
2330
+ template <typename TileA, typename TileB>
2331
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j) { tile_assign(dest, src, tile_coord(i, j)); }
2332
+ template <typename TileA, typename TileB>
2333
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k) { tile_assign(dest, src, tile_coord(i, j, k)); }
2334
+ template <typename TileA, typename TileB>
2335
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l) { tile_assign(dest, src, tile_coord(i, j, k, l)); }
2336
+
2337
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2338
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, AdjTileA& adj_dest, AdjTileB& adj_src, int) { adj_tile_assign(dest, src, tile_coord(i), adj_dest, adj_src, tile_coord(0)); }
2339
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2340
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, AdjTileA& adj_dest, AdjTileB& adj_src, int, int) { adj_tile_assign(dest, src, tile_coord(i,j), adj_dest, adj_src, tile_coord(0)); }
2341
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2342
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k), adj_dest, adj_src, tile_coord(0)); }
2343
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2344
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k,l), adj_dest, adj_src, tile_coord(0)); }
2345
+
2346
+
2347
+ template <typename TileA, typename TileB, typename TileC>
2348
+ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
2349
+ {
2350
+ using ShapeA = typename TileA::Layout::Shape;
2351
+ using ShapeB = typename TileB::Layout::Shape;
2352
+ using ShapeC = typename TileC::Layout::Shape;
2353
+
2354
+ static_assert(ShapeA::dim(0) == ShapeA::dim(1));
2355
+ static_assert(ShapeB::dim(0) == ShapeA::dim(0));
2356
+ static_assert(ShapeC::dim(0) == ShapeA::dim(0));
2357
+ static_assert(ShapeC::dim(0) == ShapeC::dim(1));
2358
+
2359
+ c = a;
2360
+
2361
+ for (int t=threadIdx.x; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
2362
+ {
2363
+ c.data(tile_coord(t, t)) += b.data(tile_coord(t));
2364
+ }
2365
+
2366
+ WP_TILE_SYNC();
2367
+
2368
+ return c;
2369
+ }
2370
+
2371
+ template <typename TileA, typename TileB, typename TileC, typename AdjTileA, typename AdjTileB, typename AdjTileC>
2372
+ inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c, AdjTileC& adj_ret)
2373
+ {
2374
+ assert(false);
2375
+ }
2376
+
1853
2377
 
1854
2378
  } // namespace wp
2379
+